Source code for leaspy.models.riemanian_manifold

from abc import abstractmethod
from typing import Iterable, Optional

import torch

from leaspy.utils.functional import Exp, OrthoBasis, Sqr
from leaspy.variables.distributions import Normal
from leaspy.variables.specs import (
    Hyperparameter,
    LinkedVariable,
    ModelParameter,
    NamedVariables,
    PopulationLatentVariable,
    SuffStatsRW,
    VariableName,
)
from leaspy.variables.state import State

from .time_reparametrized import TimeReparametrizedModel

# TODO: refactor? Implement a single function
# compute_individual_tensorized(..., with_jacobian: bool) returning either the
# model values alone, or the model values together with their jacobians with
# respect to the individual parameters.

# TODO: refactor? Use subclassing (or another suitable technique) to isolate the
# model's concrete formulation depending on whether it is linear, logistic,
# mixed log-lin, ...


# Public API of this module: the names exported by a star-import
# (``from <module> import *``).
__all__ = [
    "RiemanianManifoldModel",
    "LinearInitializationMixin",
    "LinearModel",
]


class RiemanianManifoldModel(TimeReparametrizedModel):
    """Manifold model for multiple variables of interest (logistic or linear formulation).

    Parameters
    ----------
    name : :obj:`str`
        The name of the model.
    variables_to_track : :obj:`Iterable` of `VariableName`, optional
        Names of the variables handed to ``track_variables``. When omitted (or
        empty), a default selection is used; that selection additionally
        contains the sources-related variables when ``source_dimension`` is
        truthy.
    **kwargs
        Hyperparameters of the model (including `noise_model`)

    Raises
    ------
    :exc:`.LeaspyModelInputError`
        * If hyperparameters are inconsistent
    """

    def __init__(
        self,
        name: str,
        variables_to_track: Optional[Iterable[VariableName]] = None,
        **kwargs,
    ):
        super().__init__(name, **kwargs)
        default_variables_to_track = [
            "g",
            "v0",
            "noise_std",
            "tau_mean",
            "tau_std",
            "xi_mean",
            "xi_std",
            "nll_attach",
            "nll_regul_log_g",
            "nll_regul_log_v0",
            "xi",
            "tau",
            "nll_regul_pop_sum",
            "nll_regul_all_sum",
            "nll_tot",
        ]
        if self.source_dimension:
            # These variables are only part of the default selection when the
            # model has sources.
            default_variables_to_track += [
                "sources",
                "betas",
                "mixing_matrix",
                "space_shifts",
            ]
        # A falsy `variables_to_track` (None or empty) selects the defaults.
        self.track_variables(variables_to_track or default_variables_to_track)

    @classmethod
    def _center_xi_realizations(cls, state: State) -> None:
        """
        Center the ``xi`` realizations in place.

        The mean of ``xi`` is subtracted from ``xi`` and added to ``log_v0``.

        Parameters
        ----------
        state : :class:`.State`
            The dictionary-like object representing current model state, which
            contains keys such as ``"xi"`` and ``"log_v0"``.

        Notes
        -----
        This transformation preserves the orthonormal basis since the new ``v0``
        remains collinear to the previous one. It is a purely internal operation
        meant to reduce redundancy in the parameter space (i.e., improve
        identifiability and stabilize inference).
        """
        mean_xi = torch.mean(state["xi"])
        state["xi"] = state["xi"] - mean_xi
        state["log_v0"] = state["log_v0"] + mean_xi

        # TODO: find a way to prevent re-computation of orthonormal basis since it should
        #  not have changed (v0_collinear update)
        # self.update_MCMC_toolbox({'v0_collinear'}, realizations)

    @classmethod
    def compute_sufficient_statistics(cls, state: State) -> SuffStatsRW:
        """
        Compute the model's :term:`sufficient statistics`.

        The ``xi`` realizations are centered first (see
        :meth:`_center_xi_realizations`), then the computation is delegated to
        the parent class.

        Parameters
        ----------
        state : :class:`.State`
            The state to pick values from. Its ``"xi"`` and ``"log_v0"`` entries
            are modified by the centering step.

        Returns
        -------
        SuffStatsRW :
            The computed sufficient statistics.
        """
        # <!> modify 'xi' and 'log_v0' realizations in-place
        # TODO: what theoretical guarantees for this custom operation?
        cls._center_xi_realizations(state)

        return super().compute_sufficient_statistics(state)

    def get_variables_specs(self) -> NamedVariables:
        """
        Return the specifications of the variables (latent variables, derived variables,
        model 'parameters') that are part of the model.

        The specifications of the parent class are extended with the ``log_v0``
        prior, latent variable and derived ``v0``, the ``metric``, and the
        ``model`` variable (with or without sources depending on
        ``source_dimension``).

        Returns
        -------
        NamedVariables :
            A dictionary-like object mapping variable names to their specifications.
            These include `ModelParameter`, `Hyperparameter`, `PopulationLatentVariable`,
            and `LinkedVariable` instances.
        """
        d = super().get_variables_specs()
        d.update(
            # PRIORS
            log_v0_mean=ModelParameter.for_pop_mean(
                "log_v0",
                shape=(self.dimension,),
            ),
            log_v0_std=Hyperparameter(0.01),
            xi_mean=Hyperparameter(0.0),
            # LATENT VARS
            log_v0=PopulationLatentVariable(
                Normal("log_v0_mean", "log_v0_std"),
            ),
            # DERIVED VARS
            v0=LinkedVariable(
                Exp("log_v0"),
            ),
            # for linear model: metric & metric_sqr are fixed = 1.
            metric=LinkedVariable(self.metric),
        )

        if self.source_dimension >= 1:
            # The squared metric and the orthonormal basis are only registered
            # when the model has sources.
            d.update(
                model=LinkedVariable(self.model_with_sources),
                metric_sqr=LinkedVariable(Sqr("metric")),
                orthonormal_basis=LinkedVariable(OrthoBasis("v0", "metric_sqr")),
            )
        else:
            d["model"] = LinkedVariable(self.model_no_sources)

        # TODO: WIP
        # variables_info.update(self.get_additional_ordinal_population_random_variable_information())
        # self.update_ordinal_population_random_variable_information(variables_info)

        return d

    @staticmethod
    @abstractmethod
    def metric(*, g: torch.Tensor) -> torch.Tensor:
        """
        Compute the metric from the population parameter ``g``.

        Abstract hook to be implemented by concrete models; it is registered as
        the ``metric`` linked variable in :meth:`get_variables_specs`.

        Parameters
        ----------
        g : :class:`torch.Tensor`
            The values of the population parameter `g` (keyword-only).

        Returns
        -------
        :class:`torch.Tensor`
            The metric.
        """

    @classmethod
    def model_no_sources(cls, *, rt: torch.Tensor, metric, v0, g) -> torch.Tensor:
        """
        Return the model output when sources (spatial components) are not present.

        Parameters
        ----------
        rt : :class:`torch.Tensor`
            The reparametrized time.
        metric : Any
            The metric tensor used for computing the spatial/temporal influence.
        v0 : Any
            The values of the population parameter `v0` for each feature.
        g : Any
            The values of the population parameter `g` for each feature.

        Returns
        -------
        :class:`torch.Tensor`
            The model output without contribution from source shifts.

        Notes
        -----
        This implementation delegates to `model_with_sources` with `space_shifts`
        set to a zero tensor of shape (1, 1), effectively removing source effects.
        The zero tensor is allocated on the device of `rt`.
        """
        # Allocate the neutral shift on the device of `rt` so that it can be
        # combined with the other tensors even when they do not live on the
        # default device. If `rt` exposes no `device` attribute, `None` makes
        # torch fall back to the default device (the historical behavior).
        no_space_shifts = torch.zeros((1, 1), device=getattr(rt, "device", None))
        return cls.model_with_sources(
            rt=rt,
            metric=metric,
            v0=v0,
            g=g,
            space_shifts=no_space_shifts,
        )

    @classmethod
    @abstractmethod
    def model_with_sources(
        cls,
        *,
        rt: torch.Tensor,
        space_shifts: torch.Tensor,
        metric,
        v0,
        g,
    ) -> torch.Tensor:
        """
        Return the model output when sources (spatial components) are present.

        Abstract hook to be implemented by concrete models. It is registered as
        the ``model`` linked variable when the model has sources, and it is also
        called by :meth:`model_no_sources` with zero `space_shifts`.

        Parameters
        ----------
        rt : :class:`torch.Tensor`
            The reparametrized time.
        space_shifts : :class:`torch.Tensor`
            The space shifts induced by the sources.
        metric : Any
            The metric tensor.
        v0 : Any
            The values of the population parameter `v0` for each feature.
        g : Any
            The values of the population parameter `g` for each feature.

        Returns
        -------
        :class:`torch.Tensor`
            The model output.
        """