Source code for leaspy.variables.utilities

import torch

# Public names exported by a star-import of this module.
__all__ = [
    "compute_individual_parameter_std_from_sufficient_statistics",
]


def compute_individual_parameter_std_from_sufficient_statistics(
    state: dict[str, torch.Tensor],
    individual_parameter_values: torch.Tensor,
    individual_parameter_sqr_values: torch.Tensor,
    *,
    individual_parameter_name: str,
    dim: int,
    **kws,
):
    """
    Update the standard deviation of the Gaussian prior of an individual
    latent variable from its sufficient statistics (maximization rule).

    The variance is evaluated around the mean currently stored in ``state``
    under the key ``"<individual_parameter_name>_mean"``::

        mean(x ** 2) - 2 * old_mean * mean(x) + old_mean ** 2

    which is, algebraically, the mean squared deviation of the values from
    that stored mean. The result is then converted to a standard deviation by
    ``compute_std_from_variance``.

    Parameters
    ----------
    state : :obj:`dict`[:obj:`str`, :class:`torch.Tensor`]
        The current state object that holds all the variables. Only the entry
        ``"<individual_parameter_name>_mean"`` is read here.
    individual_parameter_values : :class:`torch.Tensor`
        Individual parameter values; averaged along ``dim`` to get the
        current empirical mean.
    individual_parameter_sqr_values : :class:`torch.Tensor`
        Squared individual parameter values; averaged along ``dim`` to get
        the empirical second moment.
    individual_parameter_name : :obj:`str`
        Name of the individual parameter whose std is computed. It is used to
        build both the state key (``"<name>_mean"``) and the ``varname``
        (``"<name>_std"``) handed to ``compute_std_from_variance``.
    dim : :obj:`int`
        The dimension along which both means are taken.
    **kws
        Extra keyword arguments forwarded unchanged to
        ``compute_std_from_variance``.

    Returns
    -------
    :class:`torch.Tensor`
        The updated standard deviation of the Gaussian prior for the
        individual parameter, as returned by ``compute_std_from_variance``.
    """
    # Function-scope import, as in the original code
    # (presumably to avoid a circular import -- TODO confirm).
    from leaspy.models.utilities import compute_std_from_variance

    previous_mean = state[f"{individual_parameter_name}_mean"]
    empirical_mean = individual_parameter_values.mean(dim=dim)
    empirical_second_moment = individual_parameter_sqr_values.mean(dim=dim)

    # Same evaluation order as (E[x^2] - 2 * mu_old * E[x]) + mu_old^2,
    # so floating-point results are unchanged.
    variance = (
        empirical_second_moment
        - 2 * previous_mean * empirical_mean
        + previous_mean**2
    )

    return compute_std_from_variance(
        variance,
        varname=f"{individual_parameter_name}_std",
        **kws,
    )