leaspy.algo.fit.base

This module defines the FitAlgorithm class, the base class for fitting algorithms.

Classes

FitAlgorithm

Abstract class containing common methods for all fit algorithm classes.

Module Contents

class FitAlgorithm(settings)

Bases: leaspy.algo.base.IterativeAlgorithm[leaspy.algo.base.ModelType, leaspy.algo.base.ReturnType]

Abstract class containing common methods for all fit algorithm classes.

The algorithm is proven to converge if the sequence of step sizes \((\epsilon_k)_k\) (burn_in_step) is positive, with an infinite sum \(\sum_k \epsilon_k = +\infty\) and a finite sum of squares \(\sum_k \epsilon_k^2 < +\infty\) (see the following paper).

Construction of Bayesian Deformable Models via a Stochastic Approximation Algorithm: A Convergence Study
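To make the two summability conditions concrete, here is a minimal sketch assuming a polynomial schedule \(\epsilon_k = k^{-a}\). This schedule is a standard textbook choice and is not necessarily the one leaspy uses; `step_size` and `partial_sums` are hypothetical helpers. The sum of \(k^{-a}\) diverges exactly when \(a \le 1\), and the sum of \(k^{-2a}\) converges exactly when \(a > 0.5\), so any exponent in \((0.5, 1]\) satisfies both conditions.

```python
from typing import Tuple


def step_size(k: int, exponent: float = 0.6) -> float:
    """Hypothetical schedule epsilon_k = k ** (-exponent), defined for k >= 1."""
    if k < 1:
        raise ValueError("k must be >= 1")
    # sum k^(-a) diverges iff a <= 1; sum k^(-2a) converges iff a > 0.5
    if not 0.5 < exponent <= 1.0:
        raise ValueError("exponent must lie in (0.5, 1] for both conditions to hold")
    return k ** (-exponent)


def partial_sums(n: int, exponent: float = 0.6) -> Tuple[float, float]:
    """Return (sum of epsilon_k, sum of epsilon_k ** 2) for k = 1..n."""
    steps = [step_size(k, exponent) for k in range(1, n + 1)]
    return sum(steps), sum(e * e for e in steps)


if __name__ == "__main__":
    for n in (1_000, 100_000):
        total, total_sq = partial_sums(n)
        # The first sum keeps growing with n; the second stays bounded
        print(f"n={n}: sum eps={total:.1f}, sum eps^2={total_sq:.3f}")
```

With exponent 0.6, the partial sums of \(\epsilon_k\) grow roughly like \(n^{0.4}\), while the partial sums of \(\epsilon_k^2\) stay below \(\zeta(1.2) \approx 5.59\).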

Parameters:
settings : AlgorithmSettings

The specifications of the algorithm, as an AlgorithmSettings instance.

Attributes:
algorithm_device : str

Name of a valid torch.device.

current_iteration : int, default 0

The number of the current iteration. The first iteration is 1 and the last one is n_iter.

sufficient_statistics : dict [str, torch.Tensor] or None

Sufficient statistics of the previous step. It is None throughout the burn-in phase.

output_manager : FitOutputManager

Optional output manager of the algorithm.
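The interplay between current_iteration and sufficient_statistics can be illustrated with a toy loop. This is a minimal sketch, not leaspy's implementation: `ToyFitLoop`, `n_burn_in_iter`, `compute_stats` and the step-size exponent are hypothetical, plain floats stand in for torch.Tensor values, and the update \(S_k = S_{k-1} + \epsilon_k (s_k - S_{k-1})\) is the generic stochastic-approximation rule, assumed here rather than taken from the library.

```python
from typing import Callable, Dict, Optional

Stats = Dict[str, float]  # floats stand in for torch.Tensor values


class ToyFitLoop:
    """Hypothetical stand-in mirroring the documented attributes; not leaspy code."""

    def __init__(self, n_iter: int, n_burn_in_iter: int, exponent: float = 0.6) -> None:
        self.n_iter = n_iter
        self.n_burn_in_iter = n_burn_in_iter
        self.exponent = exponent
        self.current_iteration = 0  # 0 before the first iteration, as documented
        self.sufficient_statistics: Optional[Stats] = None

    def _update(self, new_stats: Stats) -> None:
        if self.current_iteration <= self.n_burn_in_iter:
            # Burn-in: nothing is memorised, so the attribute stays None
            self.sufficient_statistics = None
            return
        if self.sufficient_statistics is None:
            # First post-burn-in step: start from the current statistics
            self.sufficient_statistics = dict(new_stats)
            return
        k = self.current_iteration - self.n_burn_in_iter
        eps = k ** (-self.exponent)
        # Stochastic approximation: S_k = S_{k-1} + eps_k * (s_k - S_{k-1})
        self.sufficient_statistics = {
            name: old + eps * (new_stats[name] - old)
            for name, old in self.sufficient_statistics.items()
        }

    def run(self, compute_stats: Callable[[int], Stats]) -> None:
        # Iterations are numbered 1..n_iter
        for iteration in range(1, self.n_iter + 1):
            self.current_iteration = iteration
            self._update(compute_stats(iteration))
```

For example, `ToyFitLoop(n_iter=3, n_burn_in_iter=5)` ends with `current_iteration == 3` and `sufficient_statistics is None`, since every iteration falls within the burn-in phase.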

Inherited attributes

From AbstractAlgo


See also

leaspy.api.Leaspy.fit()
family
logs
sufficient_statistics: DictParamsTorch | None = None
set_output_manager(output_settings)

Set a FitOutputManager object for the run of the algorithm.

Parameters:
output_settings : OutputsSettings

Contains the logging settings for the computation run (console print periodicity, plot periodicity …).


Return type:

None

Examples

>>> from leaspy.algo import AlgorithmSettings, algorithm_factory, OutputsSettings
>>> algo_settings = AlgorithmSettings("mcmc_saem")
>>> my_algo = algorithm_factory(algo_settings)
>>> settings = {
...     'path': 'brouillons',
...     'print_periodicity': 50,
...     'plot_periodicity': 100,
...     'save_periodicity': 50,
... }
>>> my_algo.set_output_manager(OutputsSettings(settings))
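The periodicity entries above can be read as "act every p iterations". The sketch below illustrates that reading under an assumed semantics (an action is due when the iteration number is a multiple of its periodicity, and None disables it); `due_actions` is a hypothetical helper, not part of the FitOutputManager API.

```python
from typing import List, Optional


def due_actions(
    iteration: int,
    print_periodicity: Optional[int] = None,
    plot_periodicity: Optional[int] = None,
    save_periodicity: Optional[int] = None,
) -> List[str]:
    """Hypothetical helper: list the output actions due at a given iteration.

    Assumed semantics: a periodicity p triggers its action when the iteration
    is a multiple of p; None (or a non-positive value) disables the action.
    """
    schedule = {
        "print": print_periodicity,
        "plot": plot_periodicity,
        "save": save_periodicity,
    }
    return [
        action
        for action, period in schedule.items()
        if period is not None and period > 0 and iteration % period == 0
    ]
```

With the values from the example (50, 100, 50), iteration 150 would print and save but not plot, while iteration 100 would trigger all three actions.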