Source code for leaspy.models.shared_speed_logistic

import torch

from leaspy.io.data.dataset import Dataset
from leaspy.utils.functional import Exp, OrthoBasis
from leaspy.utils.weighted_tensor import (
    TensorOrWeightedTensor,
    WeightedTensor,
    unsqueeze_right,
)
from leaspy.variables.distributions import Normal
from leaspy.variables.specs import (
    Hyperparameter,
    LinkedVariable,
    ModelParameter,
    NamedVariables,
    PopulationLatentVariable,
    VariableNameToValueMapping,
)

from .logistic import LogisticInitializationMixin
from .time_reparametrized import TimeReparametrizedModel

__all__ = ["SharedSpeedLogisticModel"]


[docs] class SharedSpeedLogisticModel(LogisticInitializationMixin, TimeReparametrizedModel): """ Logistic model for multiple variables of interest, imposing same average evolution pace for all variables (logistic curves are only time-shifted). Parameters ---------- name : :obj:`str` The name of the model. **kwargs Hyperparameters of the model. """ def __init__(self, name: str, **kwargs): super().__init__(name, **kwargs) def _compute_initial_values_for_model_parameters( self, dataset: Dataset, ) -> VariableNameToValueMapping: parameters = super()._compute_initial_values_for_model_parameters(dataset) parameters["log_g_mean"] = parameters["log_g_mean"].mean() parameters["xi_mean"] = parameters["log_v0_mean"].mean() del parameters["log_v0_mean"] parameters["deltas_mean"] = torch.zeros((self.dimension - 1,)) return parameters
[docs] @staticmethod def metric(*, g_deltas_exp: torch.Tensor) -> torch.Tensor: """ Compute the metric term for the logistic model. This scaling term modulates the curvature of the logistic trajectory. Parameters ---------- g_deltas_exp : :class:`torch.Tensor` Product of slope and exp(-deltas). Returns ------- :class:`torch.Tensor` Metric value, computed as: .. math:: \\frac{(g \\cdot e^{-\\delta} + 1)^2}{g \\cdot e^{-\\delta}} """ return (g_deltas_exp + 1) ** 2 / g_deltas_exp
[docs] @staticmethod def deltas_exp(*, deltas_padded: torch.Tensor) -> torch.Tensor: """Compute the exponential of the negative deltas. Parameters: deltas_padded (:class:`torch.Tensor`): Padded deltas. Returns: :class:`torch.Tensor`: Exponential of the negative deltas. """ return torch.exp(-1 * deltas_padded)
[docs] @staticmethod def g_deltas_exp(*, g: torch.Tensor, deltas_exp: torch.Tensor) -> torch.Tensor: return g * deltas_exp
[docs] @staticmethod def pad_deltas(*, deltas: torch.Tensor) -> torch.Tensor: """Prepend deltas with a zero as delta_1 is set to zero in the equations. . Parameters: deltas (:class:`torch.Tensor`): Deltas tensor. Returns: :class:`torch.Tensor`: Padded deltas tensor. """ return torch.cat((torch.tensor([0.0]), deltas))
[docs] @staticmethod def denom(*, g_deltas_exp: torch.Tensor) -> torch.Tensor: r""" Compute the denominator for the :math:`\gamma_{t_0}` calculation. """ return 1 + g_deltas_exp
[docs] @staticmethod def gamma_t0(*, denom: torch.Tensor) -> torch.Tensor: r""" Compute the :math:`\gamma_{t_0}` value, which is the inverse of the denominator. Parameters ---------- denom : :class:`torch.Tensor` Denominator term. Returns ------- :class:`torch.Tensor` :math:`\gamma_{t_0}` value. """ return 1 / denom
[docs] @staticmethod def g_metric(*, gamma_t0: torch.Tensor) -> torch.Tensor: """ Compute the g_metric value, which is the square of gamma_t0. Parameters ---------- gamma_t0 : :class:`torch.Tensor` :math:`\\gamma_{t_0}` value. Returns ------- :class:`torch.Tensor` :math:`g_{metric}` value, computed as: .. math:: g\\_metric = \\frac{1}{(\\gamma_{t0} \\cdot (1 - \\gamma_{t0}))^2} """ return 1 / (gamma_t0 * (1 - gamma_t0)) ** 2
[docs] @staticmethod def collin_to_d_gamma_t0( *, deltas_exp: torch.Tensor, denom: torch.Tensor ) -> torch.Tensor: """Compute the collinear term to d_gamma_t0. Parameters ---------- deltas_exp : :class:`torch.Tensor` Exponential of the negative deltas. denom : :class:`torch.Tensor` Denominator term. Returns ------- :class:`torch.Tensor` Collinear term to d_gamma_t0 """ return deltas_exp / denom**2
[docs] @classmethod def model_with_sources( cls, *, rt: TensorOrWeightedTensor[float], space_shifts: TensorOrWeightedTensor[float], metric: TensorOrWeightedTensor[float], deltas_padded: TensorOrWeightedTensor[float], log_g: TensorOrWeightedTensor[float], ) -> torch.Tensor: """Returns a model with sources.""" # Shape: (Ni, Nt, Nfts) pop_s = (None, None, ...) rt = unsqueeze_right(rt, ndim=1) # .filled(float('nan')) w_model_logit = ( metric[pop_s] * space_shifts[:, None, ...] + rt + deltas_padded - log_g[pop_s] ) model_logit, weights = WeightedTensor.get_filled_value_and_weight( w_model_logit, fill_value=0.0 ) return WeightedTensor(torch.sigmoid(model_logit), weights).weighted_value
[docs] @classmethod def model_no_sources( cls, *, rt: TensorOrWeightedTensor[float], metric: TensorOrWeightedTensor[float], deltas_padded: TensorOrWeightedTensor[float], log_g: TensorOrWeightedTensor[float], ) -> torch.Tensor: """Returns a model without sources.""" return cls.model_with_sources( rt=rt, metric=metric, deltas_padded=deltas_padded, log_g=log_g, space_shifts=torch.zeros((1, 1)), )
[docs] def get_variables_specs(self) -> NamedVariables: """ Get the specifications of the variables used in the model. """ d = super().get_variables_specs() d.update( log_g_mean=ModelParameter.for_pop_mean("log_g", shape=(1,)), log_g_std=Hyperparameter(0.01), log_g=PopulationLatentVariable(Normal("log_g_mean", "log_g_std")), g=LinkedVariable(Exp("log_g")), xi_mean=ModelParameter.for_ind_mean("xi", shape=(1,)), deltas_mean=ModelParameter.for_pop_mean( "deltas", shape=(self.dimension - 1,), ), deltas_std=Hyperparameter(0.01), deltas=PopulationLatentVariable( Normal("deltas_mean", "deltas_std"), sampling_kws={"scale": 1.0}, ), deltas_padded=LinkedVariable(self.pad_deltas), deltas_exp=LinkedVariable(self.deltas_exp), g_deltas_exp=LinkedVariable(self.g_deltas_exp), metric=LinkedVariable(self.metric), ) if self.source_dimension >= 1: d.update( denom=LinkedVariable(self.denom), gamma_t0=LinkedVariable(self.gamma_t0), collin_to_d_gamma_t0=LinkedVariable(self.collin_to_d_gamma_t0), g_metric=LinkedVariable(self.g_metric), orthonormal_basis=LinkedVariable( OrthoBasis("collin_to_d_gamma_t0", "g_metric"), ), model=LinkedVariable(self.model_with_sources), ) else: d["model"] = LinkedVariable(self.model_no_sources) return d