"""This module defines the `TensorMCMCSAEM` class."""
from random import shuffle
from leaspy.exceptions import LeaspyAlgoInputError
from leaspy.io.data import Dataset
from leaspy.models import McmcSaemCompatibleModel
from leaspy.variables.specs import (
IndividualLatentVariable,
LatentVariableInitType,
PopulationLatentVariable,
)
from leaspy.variables.state import State
from ..algo_with_annealing import AlgorithmWithAnnealingMixin
from ..algo_with_device import AlgorithmWithDeviceMixin
from ..algo_with_samplers import AlgorithmWithSamplersMixin
from ..base import AlgorithmName
from ..settings import AlgorithmSettings
from .base import FitAlgorithm
__all__ = ["TensorMcmcSaemAlgorithm"]
class TensorMcmcSaemAlgorithm(
AlgorithmWithDeviceMixin,
AlgorithmWithAnnealingMixin,
AlgorithmWithSamplersMixin,
FitAlgorithm[McmcSaemCompatibleModel, State],
):
"""Main algorithm for MCMC-SAEM.
Parameters
----------
settings : :class:`~leaspy.algo.AlgorithmSettings`
MCMC fit algorithm settings
Attributes
----------
samplers : :obj:`dict` [:obj:`str`, :class:`~leaspy.samplers.AbstractSampler` ]
Dictionary of samplers per each variable
random_order_variables : :obj:`bool` (default True)
This attribute controls whether we randomize the order of variables at each iteration.
`Article <https://proceedings.neurips.cc/paper/2016/hash/e4da3b7fbbce2345d7772b0674a318d5-Abstract.html>`_
gives a reason on why we should activate this flag.
temperature : :obj:`float`
temperature_inv : :obj:`float`
Temperature and its inverse are modified during algorithm if annealing is used
See Also
--------
:mod:`leaspy.samplers`
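
    Examples
    --------
    A minimal sketch of direct instantiation; the settings keys used below
    (``n_iter``, ``seed``) are assumptions for illustration, and in practice
    this algorithm is usually driven through the higher-level fit API.

    >>> settings = AlgorithmSettings("mcmc_saem", n_iter=1000, seed=0)  # hypothetical keys
    >>> algo = TensorMcmcSaemAlgorithm(settings)
    >>> 0.5 < algo.algo_parameters["burn_in_step_power"] <= 1
    True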
"""
name: AlgorithmName = AlgorithmName.FIT_MCMC_SAEM
def __init__(self, settings: AlgorithmSettings):
super().__init__(settings)
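        # SAEM convergence theory (Delyon, Lavielle & Moulines, 1999) requires step
        # sizes eps_k = k**(-power) with sum(eps_k) = inf and sum(eps_k**2) < inf,
        # which holds exactly when `power` lies in ]0.5, 1].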
if not (0.5 < self.algo_parameters["burn_in_step_power"] <= 1):
raise LeaspyAlgoInputError(
"The parameter `burn_in_step_power` should be in ]0.5, 1] in order to "
"have theoretical guarantees on convergence of MCMC-SAEM algorithm."
)
def _run(self, model: McmcSaemCompatibleModel, dataset: Dataset, **kwargs) -> State:
"""Main method to run the algorithm.
Basically, it initializes the :class:`~leaspy.variables.state.State` object,
updates it using the :meth:`~leaspy.algo.AbstractFitAlgo.iteration` method then returns it.
Parameters
----------
model : :class:`~leaspy.models.McmcSaemCompatibleModel`
The used model. It must be a subclass of :class:`~leaspy.models.McmcSaemCompatibleModel`.
dataset : :class:`~leaspy.io.data.Dataset`
Contains the subjects' observations in torch format to speed up computation.

        Returns
        -------
        :class:`~leaspy.variables.state.State`
The fitted state.
"""
with self._device_manager(model, dataset):
state = self._initialize_algo(model, dataset)
if self.algo_parameters["progress_bar"]:
self._display_progress_bar(
-1, self.algo_parameters["n_iter"], suffix="iterations"
)
for self.current_iteration in range(1, self.algo_parameters["n_iter"] + 1):
self._iteration(model, state)
if self.output_manager is not None:
# print/plot first & last iteration!
# <!> everything that will be printed/saved is AFTER iteration N
# (including temperature when annealing...)
self.output_manager.iteration(self, model, dataset)
if self.algo_parameters["progress_bar"]:
self._display_progress_bar(
self.current_iteration - 1,
self.algo_parameters["n_iter"],
suffix="iterations",
)
model.fit_metrics = self._get_fit_metrics()
model_state = state.clone()
with model_state.auto_fork(None):
# <!> At the end of the MCMC, population and individual latent variables
# may have diverged from final model parameters.
# Thus, we reset population latent variables to their mode
model_state.put_population_latent_variables(
LatentVariableInitType.PRIOR_MODE
)
model.state = model_state
return state
def _initialize_algo(
self,
model: McmcSaemCompatibleModel,
dataset: Dataset,
) -> State:
        # TODO? share this initialization logic with the personalization MCMC algorithm?
state = model.state
with state.auto_fork(None):
model.put_data_variables(state, dataset)
            # Initialize individual latent variables (population ones are expected to be initialized beforehand)
model.put_individual_parameters(state, dataset)
self._initialize_samplers(state, dataset)
self._initialize_annealing()
return state
def _iteration(
self,
model: McmcSaemCompatibleModel,
state: State,
) -> None:
"""
MCMC-SAEM iteration.
1. Sample : MC sample successively of the population and individual variables
2. Maximization step : update model parameters from current population/individual variables values.
Parameters
----------
model : :class:`~leaspy.models.McmcSaemCompatibleModel`
state : :class:`~leaspy.variables.state.State`
"""
variables = sorted(
list(state.dag.sorted_variables_by_type[PopulationLatentVariable])
+ list(state.dag.sorted_variables_by_type[IndividualLatentVariable])
)
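        # Sorting yields a deterministic base order; the optional shuffle below
        # then randomizes the sampling order at each iteration.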
if self.random_order_variables:
shuffle(variables)
for variable in variables:
self.samplers[variable].sample(state, temperature_inv=self.temperature_inv)
self._maximization_step(model, state)
self._update_temperature()
def _maximization_step(self, model: McmcSaemCompatibleModel, state: State):
"""Maximization step as in the EM algorithm.
In practice parameters are set to current state (burn-in phase), or as a barycenter with previous state.
Parameters
----------
model : :class:`~leaspy.models.McmcSaemCompatibleModel`
state : :class:`~leaspy.variables.state.State`
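
        Notes
        -----
        After the burn-in phase, sufficient statistics follow the stochastic
        approximation

        .. math::
            s_k = (1 - \gamma_k)\, s_{k-1} + \gamma_k\, S_k,
            \qquad \gamma_k = (k - k_{\mathrm{burn-in}})^{-\alpha},

        where :math:`\alpha` is ``burn_in_step_power`` and :math:`S_k` are the
        sufficient statistics computed from the current state.

        A tiny numerical sketch of this update on plain floats (purely
        illustrative, not part of the actual API):

        >>> prev, cur, power, k = 2.0, 4.0, 0.8, 5
        >>> gamma = k ** -power
        >>> round(prev * (1 - gamma) + gamma * cur, 3)
        2.552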
"""
# TODO/WIP: not 100% clear to me whether model methods should take a state param, or always use its internal state...
sufficient_statistics = model.compute_sufficient_statistics(state)
if (
self._is_burn_in()
or self.current_iteration == 1 + self.algo_parameters["n_burn_in_iter"]
):
            # the maximization step is memoryless (burn-in phase), or this is the first iteration with memory
self.sufficient_statistics = sufficient_statistics
else:
burn_in_step = (
self.current_iteration - self.algo_parameters["n_burn_in_iter"]
) # min = 2, max = n_iter - n_burn_in_iter
burn_in_step **= -self.algo_parameters["burn_in_step_power"]
            # this formulation (instead of v + burn_in_step * (sufficient_statistics[k] - v))
            # makes it possible to keep `inf` deltas
self.sufficient_statistics = {
k: v * (1.0 - burn_in_step) + burn_in_step * sufficient_statistics[k]
for k, v in self.sufficient_statistics.items()
}
# TODO: use the same method in both cases (<!> very minor differences that might break
# exact reproducibility in tests)
model.update_parameters(
state, self.sufficient_statistics, burn_in=self._is_burn_in()
)
def log_current_iteration(self, state: State):
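        """Save the current state to the parameter convergence logs if the current iteration should be logged."""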
if (
self.is_current_iteration_in_last_n()
or self.should_current_iteration_be_saved()
):
state.save(
self.logs.parameter_convergence_path,
iteration=self.current_iteration,
)
def is_current_iteration_in_last_n(self):
"""Return True if current iteration is within the last n realizations defined in logging settings."""
return (
self.current_iteration
> self.algo_parameters["n_iter"] - self.logs.save_last_n_realizations
)
def should_current_iteration_be_saved(self):
"""Return True if current iteration should be saved based on log saving periodicity."""
return (
self.logs.save_periodicity
and self.current_iteration % self.logs.save_periodicity == 0
)