Source code for leaspy.models.mcmc_saem_compatible

from abc import abstractmethod
from typing import Iterable, Optional, Union

import torch

from leaspy.exceptions import LeaspyIndividualParamsInputError, LeaspyModelInputError
from leaspy.io.data.dataset import Dataset
from leaspy.utils.typing import DictParams, DictParamsTorch, KwargsType
from leaspy.utils.weighted_tensor import TensorOrWeightedTensor, WeightedTensor
from leaspy.variables.specs import (
    LVL_FT,
    DataVariable,
    LatentVariableInitType,
    ModelParameter,
    NamedVariables,
    SuffStatsRO,
    SuffStatsRW,
)
from leaspy.variables.state import State

from .obs_models import ObservationModel
from .stateful import StatefulModel

__all__ = ["McmcSaemCompatibleModel"]


class McmcSaemCompatibleModel(StatefulModel):
    """Defines probabilistic models compatible with an MCMC-SAEM estimation.

    Parameters
    ----------
    name : :obj:`str`
        The name of the model.
    obs_models : :class:`~leaspy.models.obs_models` or :class:`~typing.Iterable` [:class:`~leaspy.models.obs_models`]
        The noise model for observations (keyword-only parameter).
    fit_metrics : :obj:`dict`
        Metrics that should be measured during the fit of the model and reported back to the user.
    **kwargs
        Hyperparameters for the model.

    Attributes
    ----------
    is_initialized : :obj:`bool`
        Indicates if the model is initialized.
    name : :obj:`str`
        The model's name.
    features : :obj:`list` [:obj:`str`]
        Names of the model features.
    parameters : :obj:`dict`
        Contains the model's parameters.
    obs_models : :obj:`tuple` [:class:`~leaspy.models.obs_models`, ...]
        The observation model(s) associated with the model.
    fit_metrics : :obj:`dict`
        Contains the metrics that are measured during the fit of the model and reported to the user.
    _state : :class:`~leaspy.variables.state.State`
        Private instance holding all values for model variables and their derived variables.
    """

    # Base parameter categories for summary display (override in subclasses)
    _individual_prior_params: tuple[str, ...] = (
        "tau_mean",
        "tau_std",
        "xi_mean",
        "xi_std",
        "sources_mean",
        "sources_std",
        "zeta_mean",
    )
    _noise_params: tuple[str, ...] = ("noise_std",)

    # Explicit axis labels for multi-dimensional parameters.
    # Maps param_name -> tuple of axis names, e.g. ("feature",) or ("feature", "source").
    # Subclasses can extend this with: _param_axes = {**ParentClass._param_axes, "new_param": ("axis",)}
    _param_axes: dict[str, tuple[str, ...]] = {
        "log_g_mean": ("feature",),
        "log_g_std": ("feature",),
        "log_v0_mean": ("feature",),
        "betas_mean": ("basis", "source"),  # basis vectors, not features (dim-1)
        "mixing_matrix": ("source", "feature"),
        "noise_std": ("feature",),
    }

    @property
    def _param_categories(self) -> dict[str, list[str]]:
        """Categorize parameters for summary display."""
        ind_priors = set(self._individual_prior_params)
        noise = set(self._noise_params)
        all_params = set(self.parameters.keys()) if self.parameters else set()
        pop = all_params - ind_priors - noise

        def sort_key(name: str) -> tuple[int, str, str]:
            # Sort by number of columns (ascending), then primary axis, then name
            val = self.parameters[name]
            axes = self._param_axes.get(name, ())
            primary_axis = axes[0] if axes else ""
            n_cols = 1
            if val.ndim == 1 and axes:
                # Check if this axis produces labeled columns
                if self._get_axis_labels(primary_axis, len(val)) is not None:
                    n_cols = len(val)
            elif val.ndim == 2:
                n_cols = val.shape[1]
            return (n_cols, primary_axis, name)

        return {
            "population": sorted((k for k in pop if k in all_params), key=sort_key),
            "individual_priors": sorted(
                (k for k in ind_priors if k in all_params), key=sort_key
            ),
            "noise": sorted((k for k in noise if k in all_params), key=sort_key),
        }

    def __init__(
        self,
        name: str,
        *,
        # TODO? if we allowed passing a state here, there should be a whole bunch of checks, I guess?
        # Is only "equality" of the DAG OK? (WIP: cf. comment regarding inclusion of state here)
        # state: Optional[State] = None,
        # TODO? Factory of `ObservationModel` instead? (typically one would need the dimension
        # to instantiate the `noise_std` variable of the right shape...)
        obs_models: Union[ObservationModel, Iterable[ObservationModel]],
        fit_metrics: Optional[dict[str, float]] = None,
        **kwargs,
    ):
        super().__init__(name, **kwargs)
        if isinstance(obs_models, ObservationModel):
            obs_models = (obs_models,)
        self.obs_models = tuple(obs_models)
        # load hyperparameters
        # <!> some may still be missing at this point (e.g. `dimension`, `source_dimension`, ...)
        # (thus we should/could NOT instantiate the DAG right now!)
        self._load_hyperparameters(kwargs)
        # TODO: dirty hack for now, cf. AbstractFitAlgo
        self.fit_metrics = fit_metrics

    @property
    def observation_model_names(self) -> list[str]:
        """Get the names of the observation models.

        Returns
        -------
        :obj:`list` [:obj:`str`] :
            The names of the observation models.
        """
        return [model.to_string() for model in self.obs_models]
    def has_observation_model_with_name(self, name: str) -> bool:
        """Check if the model has an observation model with the given name.

        Parameters
        ----------
        name : :obj:`str`
            The name of the observation model to check.

        Returns
        -------
        :obj:`bool` :
            True if the model has an observation model with the given name, False otherwise.
        """
        return name in self.observation_model_names
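    # Usage sketch (illustrative; assumes `model` is an instance of a concrete
    # subclass, and the observation-model names shown are hypothetical):
    #
    #     model.observation_model_names                             # e.g. ["gaussian-scalar"]
    #     model.has_observation_model_with_name("gaussian-scalar")  # -> True
    #     model.has_observation_model_with_name("bernoulli")        # -> False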
    def to_dict(self, **kwargs) -> KwargsType:
        """Export the model as a dictionary ready for serialization.

        Returns
        -------
        :class:`~leaspy.utils.typing.KwargsType`
            The model instance serialized as a dictionary.
        """
        d = super().to_dict()
        d.update(
            {
                "obs_models": {
                    obs_model.name: obs_model.to_string()
                    for obs_model in self.obs_models
                },
                "fit_metrics": self.fit_metrics,  # TODO improve
            }
        )
        return d
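    # Serialization sketch (illustrative; the exact keys and values depend on the
    # concrete subclass and its observation models):
    #
    #     d = model.to_dict()
    #     # d["obs_models"]  -> e.g. {"y": "gaussian-scalar"}
    #     # d["fit_metrics"] -> the metrics dict passed at construction (or None)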
    @abstractmethod
    def _load_hyperparameters(self, hyperparameters: KwargsType) -> None:
        """Load model's hyperparameters.

        Parameters
        ----------
        hyperparameters : :obj:`dict` [:obj:`str`, :obj:`Any`]
            Contains the model's hyperparameters.
        """

    @classmethod
    def _raise_if_unknown_hyperparameters(
        cls, known_hps: Iterable[str], given_hps: KwargsType
    ) -> None:
        """Check if the given hyperparameters are known for the model.

        Parameters
        ----------
        known_hps : :obj:`Iterable` [:obj:`str`]
            The known hyperparameters for the model.
        given_hps : :class:`~leaspy.utils.typing.KwargsType`
            The hyperparameters provided to the model.

        Raises
        ------
        :exc:`.LeaspyModelInputError`
            If any unknown hyperparameter is provided to the model.
        """
        # TODO: replace with better logic from GenericModel in the future
        unexpected_hyperparameters = set(given_hps.keys()).difference(known_hps)
        if len(unexpected_hyperparameters) > 0:
            raise LeaspyModelInputError(
                f"Only {known_hps} are valid hyperparameters for {cls.__qualname__}. "
                f"Unknown hyperparameters provided: {unexpected_hyperparameters}."
            )

    @abstractmethod
    def _audit_individual_parameters(
        self, individual_parameters: DictParams
    ) -> KwargsType:
        """Perform various consistency and compatibility (with the current model) checks
        on an individual parameters dict and output qualified information about it.

        TODO? move to IndividualParameters class?

        Parameters
        ----------
        individual_parameters : :class:`~leaspy.utils.typing.DictParams`
            Contains some individual parameters.
            If representing only one individual (in a multivariate model) it could be:
                * {'tau': 0.1, 'xi': -0.3, 'sources': [0.1, ...]}

            Or for multiple individuals:
                * {'tau': [0.1, 0.2, ...], 'xi': [-0.3, 0.2, ...], 'sources': [[0.1, ...], [0, ...], ...]}

            In particular, a sources vector (if present) should always be an array_like, even if it is 1D.

        Returns
        -------
        ips_info : :class:`~leaspy.utils.typing.KwargsType`
            * ``'nb_inds'`` : :obj:`int` >= 0
                Number of individuals present.
            * ``'tensorized_ips'`` : :obj:`dict` [:obj:`str`, :class:`torch.Tensor`]
                Tensorized version of individual parameters.
            * ``'tensorized_ips_gen'`` : generator
                Generator providing tensorized individual parameters for all individuals present (ordered as is).

        Raises
        ------
        :exc:`NotImplementedError`
        """
        raise NotImplementedError

    def _get_tensorized_inputs(
        self,
        timepoints: torch.Tensor,
        individual_parameters: DictParamsTorch,
        *,
        skip_ips_checks: bool = False,
    ) -> tuple[torch.Tensor, DictParamsTorch]:
        """Convert the timepoints and individual parameters to tensors.

        Parameters
        ----------
        timepoints : :obj:`torch.Tensor`
            Contains the timepoints (age(s) of the subject).
        individual_parameters : :class:`~leaspy.utils.typing.DictParamsTorch`
            Contains the individual parameters.
        skip_ips_checks : :obj:`bool` (default: ``False``)
            Flag to skip consistency/compatibility checks and tensorization
            of ``individual_parameters`` when it was done earlier (speed-up).

        Returns
        -------
        :obj:`tuple` [:class:`torch.Tensor`, :class:`~leaspy.utils.typing.DictParamsTorch`]
            The timepoints and individual parameters converted to tensors.

        Raises
        ------
        :exc:`.LeaspyModelInputError`
            If computation is tried on more than one individual.
""" from .utilities import tensorize_2D if not skip_ips_checks: individual_parameters_info = self._audit_individual_parameters( individual_parameters ) individual_parameters = individual_parameters_info["tensorized_ips"] if (n_individual_parameters := individual_parameters_info["nb_inds"]) != 1: raise LeaspyModelInputError( "Only one individual computation may be performed at a time. " f"{n_individual_parameters} was provided." ) # Convert the timepoints (list of numbers, or single number) to a 2D torch tensor timepoints = tensorize_2D(timepoints, unsqueeze_dim=0) # 1 individual return timepoints, individual_parameters def _check_individual_parameters_provided( self, individual_parameters_keys: Iterable[str] ) -> None: """Check consistency of individual parameters keys provided. Parameters ---------- individual_parameters_keys : :obj:`Iterable` [:obj:`str`] The keys of the individual parameters provided. Raises ------ :exc:`.LeaspyIndividualParamsInputError` If any of the individual parameters keys are unknown or missing. """ ind_vars = set(self.individual_variables_names) unknown_ips = set(individual_parameters_keys).difference(ind_vars) missing_ips = ind_vars.difference(individual_parameters_keys) errs = [] if len(unknown_ips): errs.append(f"Unknown individual latent variables: {unknown_ips}") if len(missing_ips): errs.append(f"Missing individual latent variables: {missing_ips}") if len(errs): raise LeaspyIndividualParamsInputError(". ".join(errs))
    def compute_individual_trajectory(
        self,
        timepoints: list[float],
        individual_parameters: DictParams,
        *,
        skip_ips_checks: bool = False,
    ) -> torch.Tensor:
        """Compute score values at the given time-point(s), given a subject's individual parameters.

        .. note::
            The model uses its current internal state.

        Parameters
        ----------
        timepoints : :obj:`scalar` or :obj:`array_like` [:obj:`scalar`] (:obj:`list`, :obj:`tuple`, :class:`numpy.ndarray`)
            Contains the age(s) of the subject.
        individual_parameters : :class:`~leaspy.utils.typing.DictParams`
            Contains the individual parameters.
            Each individual parameter should be a scalar or array_like.
        skip_ips_checks : :obj:`bool` (default: ``False``)
            Flag to skip consistency/compatibility checks and tensorization
            of ``individual_parameters`` when it was done earlier (speed-up).

        Returns
        -------
        :class:`torch.Tensor`
            Contains the subject's scores computed at the given age(s).
            The shape of the tensor is ``(1, n_tpts, n_features)``.
        """
        self._check_individual_parameters_provided(individual_parameters.keys())
        timepoints, individual_parameters = self._get_tensorized_inputs(
            timepoints, individual_parameters, skip_ips_checks=skip_ips_checks
        )
        # TODO? ability to revert back after **several** assignments?
        # instead of cloning the state for this op?
        local_state = self.state.clone(disable_auto_fork=True)
        self._put_data_timepoints(local_state, timepoints)
        for (
            individual_parameter_name,
            individual_parameter_value,
        ) in individual_parameters.items():
            local_state[individual_parameter_name] = individual_parameter_value
        return local_state["model"]
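    # Usage sketch (hypothetical parameter values; assumes a fitted multivariate
    # model with 3 sources):
    #
    #     ips = {"tau": 75.2, "xi": -0.1, "sources": [0.3, -0.2, 0.1]}
    #     traj = model.compute_individual_trajectory([70.0, 72.5, 75.0], ips)
    #     traj.shape  # -> torch.Size([1, 3, model.dimension])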
    def compute_prior_trajectory(
        self,
        timepoints: torch.Tensor,
        prior_type: LatentVariableInitType,
        *,
        n_individuals: Optional[int] = None,
    ) -> TensorOrWeightedTensor[float]:
        """Compute the trajectory of the model for the prior mode or mean of individual parameters.

        Parameters
        ----------
        timepoints : :obj:`torch.Tensor` [1, n_timepoints]
            Contains the timepoints (age(s) of the subject).
        prior_type : :class:`~leaspy.variables.specs.LatentVariableInitType`
            The type of prior to use for the individual parameters.
        n_individuals : :obj:`int`, optional
            The number of individuals.

        Returns
        -------
        :class:`torch.Tensor` [1, n_timepoints, dimension]
            The group-average values at the given timepoints.

        Raises
        ------
        :exc:`.LeaspyModelInputError`
            If `n_individuals` is provided but is not a positive integer, or if it is
            provided while `prior_type` is not `LatentVariableInitType.PRIOR_SAMPLES`.
        """
        exc_n_ind_iff_prior_samples = LeaspyModelInputError(
            "You should provide n_individuals (int >= 1) if, "
            "and only if, prior_type is `PRIOR_SAMPLES`"
        )
        if n_individuals is None:
            if prior_type is LatentVariableInitType.PRIOR_SAMPLES:
                raise exc_n_ind_iff_prior_samples
            n_individuals = 1
        elif prior_type is not LatentVariableInitType.PRIOR_SAMPLES or not (
            isinstance(n_individuals, int) and n_individuals >= 1
        ):
            raise exc_n_ind_iff_prior_samples
        local_state = self.state.clone(disable_auto_fork=True)
        self._put_data_timepoints(local_state, timepoints)
        local_state.put_individual_latent_variables(
            prior_type, n_individuals=n_individuals
        )
        return local_state["model"]
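    # Usage sketch (illustrative): either sample several trajectories from the
    # priors, or compute the single trajectory at the prior mode:
    #
    #     tpts = torch.arange(65.0, 85.0, 0.5).unsqueeze(0)  # shape (1, n_timepoints)
    #     sampled = model.compute_prior_trajectory(
    #         tpts, LatentVariableInitType.PRIOR_SAMPLES, n_individuals=10
    #     )
    #     modal = model.compute_prior_trajectory(tpts, LatentVariableInitType.PRIOR_MODE)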
    def compute_mean_traj(
        self, timepoints: torch.Tensor
    ) -> TensorOrWeightedTensor[float]:
        """Trajectory for the mean of individual parameters (not really meaningful for non-linear models).

        Parameters
        ----------
        timepoints : :obj:`torch.Tensor` [1, n_timepoints]

        Returns
        -------
        :class:`torch.Tensor` [1, n_timepoints, dimension]
            The group-average values at the given timepoints.
        """
        # TODO/WIP: keep this in BaseModel interface? or only provide `compute_prior_trajectory`,
        # or `compute_mode|typical_traj` instead?
        return self.compute_prior_trajectory(
            timepoints, LatentVariableInitType.PRIOR_MEAN
        )
    def compute_mode_traj(
        self, timepoints: torch.Tensor
    ) -> TensorOrWeightedTensor[float]:
        """Most typical individual trajectory.

        Parameters
        ----------
        timepoints : :obj:`torch.Tensor` [1, n_timepoints]

        Returns
        -------
        :class:`torch.Tensor` [1, n_timepoints, dimension]
            The group-average values at the given timepoints.
        """
        return self.compute_prior_trajectory(
            timepoints, LatentVariableInitType.PRIOR_MODE
        )
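    # The two wrappers above are shorthands for `compute_prior_trajectory` with
    # `PRIOR_MEAN` and `PRIOR_MODE` respectively (sketch, illustrative):
    #
    #     mean_traj = model.compute_mean_traj(tpts)  # trajectory at the prior mean
    #     mode_traj = model.compute_mode_traj(tpts)  # most typical trajectory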
    def compute_jacobian_tensorized(
        self,
        state: State,
    ) -> DictParamsTorch:
        """Compute the Jacobian of the model w.r.t. each individual parameter, given the input state.

        This function aims to be used in :class:`.ScipyMinimize` to speed up optimization.

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`
            Instance holding values for all model variables (including latent individual variables).

        Returns
        -------
        :class:`~leaspy.utils.typing.DictParamsTorch`
            Tensors are of shape ``(n_individuals, n_timepoints, n_features, n_dims_param)``.

        Raises
        ------
        :exc:`NotImplementedError`
        """
        # return {
        #     ip: state[f"model_jacobian_{ip}"]
        #     for ip in self.get_individual_variable_names()
        # }
        raise NotImplementedError("This method is currently not implemented.")
    @classmethod
    def compute_sufficient_statistics(cls, state: State) -> SuffStatsRW:
        """Compute sufficient statistics from the state.

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`

        Returns
        -------
        :class:`~leaspy.variables.specs.SuffStatsRW`
            Contains the sufficient statistics computed from the state.
        """
        suff_stats = {}
        for mp_var in state.dag.sorted_variables_by_type[ModelParameter].values():
            mp_var: ModelParameter  # type-hint only
            suff_stats.update(mp_var.suff_stats(state))

        # We add some fake sufficient statistics that are in fact convergence metrics (summed over individuals)
        # TODO: proper handling of metrics
        # We do not account for regularization of pop. vars since we do NOT have true Bayesian priors on them (for now)
        for k in ("nll_attach", "nll_regul_ind_sum"):
            suff_stats[k] = state[k]
        suff_stats["nll_tot"] = (
            suff_stats["nll_attach"] + suff_stats["nll_regul_ind_sum"]
        )  # "nll_regul_all_sum"
        return suff_stats
    @classmethod
    def update_parameters(
        cls,
        state: State,
        sufficient_statistics: SuffStatsRO,
        *,
        burn_in: bool,
    ) -> None:
        """Update the model parameters of the provided state.

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`
            Instance holding values for all model variables (including latent individual variables).
        sufficient_statistics : :class:`~leaspy.variables.specs.SuffStatsRO`
            Contains the sufficient statistics computed from the state.
        burn_in : :obj:`bool`
            If True, the parameters are updated in a burn-in phase.
        """
        # <!> we should wait before updating the state, since some updating rules may depend on the OLD state
        # (i.e. no sequential update of state, but batched updates once all updated values were retrieved)
        # (+ it would be inefficient, since we could recompute some derived values between updates!)
        params_updates = {}
        for mp_name, mp_var in state.dag.sorted_variables_by_type[
            ModelParameter
        ].items():
            mp_var: ModelParameter  # type-hint only
            params_updates[mp_name] = mp_var.compute_update(
                state=state, suff_stats=sufficient_statistics, burn_in=burn_in
            )
        # mass update at the end
        for mp, mp_updated_val in params_updates.items():
            state[mp] = mp_updated_val
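    # Sketch of how a fit algorithm would chain the two classmethods above during
    # the maximization step of one MCMC-SAEM iteration (simplified; the actual
    # algorithm also performs the MCMC sampling of latent variables beforehand):
    #
    #     suff_stats = McmcSaemCompatibleModel.compute_sufficient_statistics(state)
    #     McmcSaemCompatibleModel.update_parameters(state, suff_stats, burn_in=True)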
    def get_variables_specs(self) -> NamedVariables:
        """Get the specifications of the variables used in the model.

        Returns
        -------
        :class:`~leaspy.variables.specs.NamedVariables`
            Specifications of the variables used in the model, including timepoints and observation models.
        """
        specifications = NamedVariables({"t": DataVariable()})
        single_obs_model = len(self.obs_models) == 1
        for obs_model in self.obs_models:
            specifications.update(
                obs_model.get_variables_specs(named_attach_vars=not single_obs_model)
            )
        return specifications
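    # Subclassing sketch (hypothetical `my_param` variable): concrete models are
    # expected to extend the base specifications with their own variables:
    #
    #     def get_variables_specs(self) -> NamedVariables:
    #         specs = super().get_variables_specs()
    #         specs.update({"my_param": ModelParameter(...)})
    #         return specs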
    @abstractmethod
    def put_individual_parameters(self, state: State, dataset: Dataset):
        """Put the individual parameters inside the provided state (in-place).

        Raises
        ------
        :exc:`NotImplementedError`
        """
        raise NotImplementedError()
    def _put_data_timepoints(
        self, state: State, timepoints: TensorOrWeightedTensor[float]
    ) -> None:
        """Put the timepoints variables inside the provided state (in-place).

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`
            Instance holding values for all model variables (including latent individual variables), as well as:

            - timepoints : :class:`torch.Tensor` of shape (n_individuals, n_timepoints)
        timepoints : :class:`~leaspy.utils.weighted_tensor.WeightedTensor` or :class:`torch.Tensor`
            Contains the timepoints (age(s) of the subject).

        Raises
        ------
        :exc:`TypeError`
            If the provided timepoints are not of type :class:`torch.Tensor`
            or :class:`~leaspy.utils.weighted_tensor.WeightedTensor`.
        """
        # TODO/WIP: we use a regular tensor with 0 for times so that 'model' is a regular tensor
        # (to avoid having to cope with `StatelessDistributionFamily` having some `WeightedTensor` as parameters)
        if isinstance(timepoints, WeightedTensor):
            state["t"] = timepoints
        elif isinstance(timepoints, torch.Tensor):
            state["t"] = WeightedTensor(timepoints)
        else:
            raise TypeError(
                f"Time points should be either torch Tensors or WeightedTensors. "
                f"Instead, a {type(timepoints)} was provided."
            )
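    # Sketch (illustrative): both plain and weighted tensors are accepted; a plain
    # tensor is simply wrapped into a `WeightedTensor`:
    #
    #     t = torch.tensor([[70.0, 72.5, 75.0]])  # shape (n_individuals, n_timepoints)
    #     model._put_data_timepoints(state, t)    # same as passing WeightedTensor(t)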
    def put_data_variables(self, state: State, dataset: Dataset) -> None:
        """Put all the needed data variables inside the provided state (in-place).

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`
            Instance holding values for all model variables (including latent individual variables), as well as:

            - timepoints : :class:`torch.Tensor` of shape (n_individuals, n_timepoints)
        dataset : :class:`~leaspy.io.data.dataset.Dataset`
            The dataset containing the data to be put in the state.
        """
        self._put_data_timepoints(
            state,
            WeightedTensor(
                dataset.timepoints, dataset.mask.to(torch.bool).any(dim=LVL_FT)
            ),
        )
        for obs_model in self.obs_models:
            state[obs_model.name] = obs_model.getter(dataset)
    def reset_data_variables(self, state: State) -> None:
        """Reset all data variables inside the provided state (in-place).

        Parameters
        ----------
        state : :class:`~leaspy.variables.state.State`
            Instance holding values for all model variables (including latent individual variables), as well as:

            - timepoints : :class:`torch.Tensor` of shape (n_individuals, n_timepoints)
        """
        state["t"] = None
        for obs_model in self.obs_models:
            state[obs_model.name] = None
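    # Wiring sketch (illustrative): a fit algorithm typically feeds a `Dataset`
    # into the state before estimation, then clears the data afterwards:
    #
    #     model.put_data_variables(state, dataset)  # sets "t" and each obs. variable
    #     ...                                       # run estimation on `state`
    #     model.reset_data_variables(state)         # detach data from the state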