Source code for leaspy.models.utils.attributes.logistic_attributes
import torch
from .abstract_manifold_model_attributes import AbstractManifoldModelAttributes
__all__ = ["LogisticAttributes"]
[docs]
class LogisticAttributes(AbstractManifoldModelAttributes):
"""
Attributes of leaspy logistic models.
Contains the common attributes & methods to update the logistic model's attributes.
Parameters
----------
name : str
dimension : int
source_dimension : int
Attributes
----------
name : str (default 'logistic')
Name of the associated leaspy model.
dimension : int
source_dimension : int
univariate : bool
Whether model is univariate or not (i.e. dimension == 1)
has_sources : bool
Whether model has sources or not (not univariate and source_dimension >= 1)
update_possibilities : set[str] (default {'all', 'g', 'v0', 'betas'} )
Contains the available parameters to update. Different models have different parameters.
positions : :class:`torch.Tensor` [dimension] (default None)
positions = exp(realizations['g']) such that "p0" = 1 / (1 + positions)
velocities : :class:`torch.Tensor` [dimension] (default None)
Always positive: exp(realizations['v0'])
orthonormal_basis : :class:`torch.Tensor` [dimension, dimension - 1] (default None)
betas : :class:`torch.Tensor` [dimension - 1, source_dimension] (default None)
mixing_matrix : :class:`torch.Tensor` [dimension, source_dimension] (default None)
Matrix A such that w_i = A * s_i.
See Also
--------
:class:`~leaspy.models.univariate_model.UnivariateModel`
:class:`~leaspy.models.multivariate_model.MultivariateModel`
"""
def __init__(self, name, dimension, source_dimension):
super().__init__(name, dimension, source_dimension)
[docs]
def update(self, names_of_changed_values, values):
"""
Update model group average parameter(s).
Parameters
----------
names_of_changed_values : set[str]
Elements of set must be either:
* ``all`` (update everything)
* ``g`` correspond to the attribute :attr:`positions`.
* ``v0`` (only for multivariate models) correspond to the attribute :attr:`velocities`.
When we are sure that the v0 change is only a scalar multiplication
(in particular, when we reparametrize log(v0) <- log(v0) + mean(xi)),
we may update velocities using ``v0_collinear``, otherwise
we always assume v0 is NOT collinear to previous value
(no need to perform the verification it is - would not be really efficient)
* ``betas`` correspond to the linear combination of columns from the orthonormal basis so
to derive the :attr:`mixing_matrix`.
values : dict [str, `torch.Tensor`]
New values used to update the model's group average parameters
Raises
------
:exc:`.LeaspyModelInputError`
If `names_of_changed_values` contains unknown parameters.
"""
self._check_names(names_of_changed_values)
compute_betas = False
compute_positions = False
compute_velocities = False
dgamma_t0_not_collinear_to_previous = False
if "all" in names_of_changed_values:
# make all possible updates
names_of_changed_values = self.update_possibilities
if "betas" in names_of_changed_values:
compute_betas = True
if "g" in names_of_changed_values:
compute_positions = True
if ("v0" in names_of_changed_values) or (
"v0_collinear" in names_of_changed_values
):
compute_velocities = True
dgamma_t0_not_collinear_to_previous = "v0" in names_of_changed_values
if compute_positions:
self._compute_positions(values)
if compute_velocities:
self._compute_velocities(values)
# only for models with sources beyond this point
if not self.has_sources:
return
if compute_betas:
self._compute_betas(values)
# do not recompute orthonormal basis when we know dgamma_t0 is collinear
# to previous velocities to avoid useless computations!
recompute_ortho_basis = compute_positions or dgamma_t0_not_collinear_to_previous
if recompute_ortho_basis:
self._compute_orthonormal_basis()
if recompute_ortho_basis or compute_betas:
self._compute_mixing_matrix()
def _compute_positions(self, values):
"""
Update the attribute ``positions``.
Parameters
----------
values : dict [str, `torch.Tensor`]
"""
self.positions = torch.exp(values["g"])
def _compute_orthonormal_basis(self):
"""
Compute the attribute ``orthonormal_basis`` which is an orthonormal basis, w.r.t the canonical inner product,
of the sub-space orthogonal, w.r.t the inner product implied by the metric, to the time-derivative of the geodesic at initial time.
"""
# Compute the diagonal of metric matrix (cf. `_compute_Q`)
G_metric = (1 + self.positions).pow(4) / self.positions.pow(
2
) # = "1/(p0 * (1-p0))**2"
dgamma_t0 = self.velocities
# Householder decomposition in non-Euclidean case, updates `orthonormal_basis` in-place
self._compute_Q(dgamma_t0, G_metric)