Source code for leaspy.algo.fit.base
"""This module defines the `FitAlgorithm` class, the common base of all fitting algorithms."""
from typing import Optional
from leaspy.utils.typing import DictParamsTorch
from ..base import AlgorithmType, IterativeAlgorithm, ModelType, ReturnType
from ..settings import AlgorithmSettings, OutputsSettings
__all__ = ["FitAlgorithm"]
[docs]
class FitAlgorithm(IterativeAlgorithm[ModelType, ReturnType]):
    r"""Abstract class containing the methods common to all `fit` algorithm classes.

    The algorithm is proven to converge if the sequence `burn_in_step` is positive, with an
    infinite sum :math:`\sum_k \epsilon_k = +\infty` and a finite sum of the squares
    :math:`\sum_k \epsilon_k^2 < \infty` (see the following paper).

    `Construction of Bayesian Deformable Models via a Stochastic Approximation Algorithm: A Convergence Study <https://arxiv.org/abs/0706.0787>`_

    Parameters
    ----------
    settings : :class:`~leaspy.algo.AlgorithmSettings`
        The specifications of the algorithm as a :class:`~leaspy.algo.AlgorithmSettings` instance.

    Attributes
    ----------
    algorithm_device : :obj:`str`
        Valid :class:`torch.device`
    current_iteration : :obj:`int`, default 0
        The number of the current iteration.
        The first iteration will be 1 and the last one `n_iter`.
    logs
        The `logs` attribute of the provided settings.
    sufficient_statistics : :obj:`dict` [:obj:`str`, :class:`torch.Tensor`] or None
        Sufficient statistics of the previous step (initialized to None).
        It is None during all the burn-in phase.
    output_manager : :class:`~leaspy.algo.fit.FitOutputManager`
        Optional output manager of the algorithm; assigned by :meth:`set_output_manager`
        when output settings are provided.

    Inherited attributes
        From :class:`IterativeAlgorithm` (the base class of this class).

    See Also
    --------
    :meth:`leaspy.api.Leaspy.fit`
    """

    family = AlgorithmType.FIT

    def __init__(self, settings: AlgorithmSettings):
        super().__init__(settings)
        self.logs = settings.logs
        # No sufficient statistics are available until the algorithm stores some.
        self.sufficient_statistics: Optional[DictParamsTorch] = None

    def set_output_manager(self, output_settings: Optional[OutputsSettings]) -> None:
        """Set a :class:`~leaspy.algo.fit.FitOutputManager` object for the run of the algorithm.

        Nothing is done when `output_settings` is None: the current `output_manager`
        attribute (if any) is left untouched.

        Parameters
        ----------
        output_settings : :class:`~leaspy.algo.OutputsSettings` or None
            Contains the logs settings for the computation run
            (console print periodicity, plot periodicity ...)

        Examples
        --------
        >>> from leaspy.algo import AlgorithmSettings, algorithm_factory, OutputsSettings
        >>> algo_settings = AlgorithmSettings("mcmc_saem")
        >>> my_algo = algorithm_factory(algo_settings)
        >>> settings = {
        ...     'path': 'brouillons',
        ...     'print_periodicity': 50,
        ...     'plot_periodicity': 100,
        ...     'save_periodicity': 50
        ... }
        >>> my_algo.set_output_manager(OutputsSettings(settings))
        """
        if output_settings is None:
            return
        # Function-scope import kept from the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from .fit_output_manager import FitOutputManager

        self.output_manager = FitOutputManager(output_settings)

    def _get_fit_metrics(self) -> Optional[dict[str, float]]:
        """Extract the scalar fit metrics stored among the sufficient statistics.

        Returns
        -------
        :obj:`dict` [:obj:`str`, :obj:`float`] or None
            The entries of `sufficient_statistics` whose key starts with ``"nll_"``,
            each converted to a Python scalar with ``.item()``;
            None when `sufficient_statistics` is None.
        """
        # TODO: finalize metrics handling, a bit dirty to place them in sufficient stats, only with a prefix...
        if self.sufficient_statistics is None:
            return None
        return {
            # (scalars only)
            k: v.item()
            for k, v in self.sufficient_statistics.items()
            if k.startswith("nll_")
        }

    def __str__(self) -> str:
        """Return the parent's string representation, followed by the fit metrics (if any)."""
        out = super().__str__()
        # add the fit metrics after iteration number (included the sufficient statistics for now...)
        fit_metrics = self._get_fit_metrics()
        if not fit_metrics:
            # Either no sufficient statistics yet (None) or no "nll_" entry (empty dict).
            return out
        lines = [out, "= Metrics ="]
        lines.extend(f" {m} : {v:.5g}" for m, v in fit_metrics.items())
        return "\n".join(lines)