import warnings
from abc import abstractmethod
from collections.abc import Iterable
from typing import Optional
import torch
from leaspy.exceptions import LeaspyModelInputError
from leaspy.io.data.dataset import Dataset
from leaspy.utils.typing import DictParamsTorch, KwargsType
from leaspy.utils.weighted_tensor import WeightedTensor
from leaspy.variables.dag import VariablesDAG
from leaspy.variables.specs import (
Hyperparameter,
IndividualLatentVariable,
LatentVariableInitType,
ModelParameter,
NamedVariables,
PopulationLatentVariable,
VariableName,
VariableNameToValueMapping,
)
from leaspy.variables.state import State, StateForkType
from .base import BaseModel
__all__ = ["StatefulModel"]
class StatefulModel(BaseModel):
"""Stateful models have an internal :class:`~leaspy.variables.State` to handle parameters and variables.
Parameters
----------
name : :obj:`str`
The name of the model.
Attributes
----------
state : :class:`~leaspy.variables.State`
The internal state of the model, which contains the variables and their values.
tracked_variables : :obj:`set` [:obj:`str`]
Set of variable names that are tracked by the model. These variables are not necessarily part of
the model's state but are monitored for changes or updates. This can include variables that are
relevant for the model's operation but not directly stored in the state.
"""
def __init__(self, name: str, **kwargs):
super().__init__(name, **kwargs)
self._state: Optional[State] = None
self.tracked_variables: set[str] = set()
def track_variable(self, variable: VariableName) -> None:
"""Track a variable by its name.
Parameters
        ----------
variable : :class:`~leaspy.variables.specs.VariableName`
The name of the variable to track. This variable will be monitored for changes or updates.
"""
self.tracked_variables.add(variable)
def track_variables(self, variables: Iterable[VariableName]) -> None:
"""
Track multiple variables by their names.
Parameters
----------
variables : :obj:`Iterable` [:class:`~leaspy.variables.specs.VariableName`]
An iterable containing the names of the variables to track. Each variable will be monitored
for changes or updates.
"""
for variable in variables:
self.track_variable(variable)
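    # Usage sketch (with a hypothetical concrete subclass ``MyModel``): tracked
    # names are kept in a plain set, so tracking is idempotent and order-free.
    #
    #     model = MyModel("my_model")
    #     model.track_variables(["xi_mean", "tau_mean"])
    #     model.track_variable("tau_mean")   # no-op: already tracked
    #     model.untrack_variable("xi_mean")  # KeyError if it was not tracked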
def untrack_variable(self, variable: VariableName) -> None:
"""Untrack a variable by its name.
Parameters
        ----------
variable : :class:`~leaspy.variables.specs.VariableName`
The name of the variable to untrack. This variable will no longer be monitored for changes or updates.
"""
self.tracked_variables.remove(variable)
[docs]
def untrack_variables(self, variables: Iterable[VariableName]) -> None:
"""Untrack multiple variables by their names.
Parameters
----------
variables : :obj:`Iterable` [:class:`~leaspy.variables.specs.VariableName`]
An iterable containing the names of the variables to untrack. Each variable will no longer be monitored
for changes or updates.
"""
for variable in variables:
self.untrack_variable(variable)
@property
def state(self) -> State:
"""Get the internal state of the model.
Returns
-------
        : :class:`~leaspy.variables.State`
The internal state of the model, which contains the variables and their values.
Raises
------
:exc:`.LeaspyModelInputError`
If the model's state is not initialized yet.
"""
if self._state is None:
raise LeaspyModelInputError("Model state is not initialized yet")
return self._state
@state.setter
def state(self, s: State) -> None:
"""Set the internal state of the model.
This method allows to set the internal state of the model, which is an instance of
Parameters
----------
s : :class:`~leaspy.variables.State`
The new state to set for the model.
Raises
------
        :exc:`.LeaspyModelInputError`
If the provided state does not match the previous state in terms of DAG structure.
"""
assert isinstance(s, State), "Provided state should be a valid State instance"
if self._state is not None and s.dag is not self._state.dag:
raise LeaspyModelInputError(
"DAG of new state does not match with previous one"
)
# TODO? perform some clean-up steps for provided state (cf. `_terminate_algo` of MCMC algo)
self._state = s
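    # Setter sketch: replacing the state is only allowed when the new state is
    # built on the very same DAG object (identity check, not equality):
    #
    #     fresh = State(model.dag, auto_fork_type=StateForkType.REF)
    #     model.state = fresh    # OK: same DAG instance
    #     other = State(VariablesDAG.from_dict(model.get_variables_specs()),
    #                   auto_fork_type=StateForkType.REF)
    #     model.state = other    # raises LeaspyModelInputError (new DAG object)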
@property
def dag(self) -> VariablesDAG:
"""Get the underlying DAG of the model's state.
Returns
-------
: :class:`~leaspy.variables.dag.VariablesDAG`
The directed acyclic graph (DAG) representing the model's variables and their relationships
"""
return self.state.dag
@property
def hyperparameters_names(self) -> tuple[VariableName, ...]:
"""Get the names of the model's hyperparameters.
Returns
-------
        : :obj:`tuple` [:class:`~leaspy.variables.specs.VariableName`, ...]
A tuple containing the names of the model's hyperparameters.
"""
return tuple(self.dag.sorted_variables_by_type[Hyperparameter])
@property
def parameters_names(self) -> tuple[VariableName, ...]:
"""Get the names of the model's parameters.
Returns
-------
        : :obj:`tuple` [:class:`~leaspy.variables.specs.VariableName`, ...]
A tuple containing the names of the model's parameters.
"""
return tuple(self.dag.sorted_variables_by_type[ModelParameter])
@property
def population_variables_names(self) -> tuple[VariableName, ...]:
"""Get the names of the population latent variables.
Returns
-------
: :obj:`tuple` [:class:`~leaspy.variables.specs.VariableName`, ...]
A tuple containing the names of the population latent variables.
"""
return tuple(self.dag.sorted_variables_by_type[PopulationLatentVariable])
@property
def individual_variables_names(self) -> tuple[VariableName, ...]:
"""Get the names of the individual latent variables.
Returns
-------
: :obj:`tuple` [:class:`~leaspy.variables.specs.VariableName`, ...]
A tuple containing the names of the individual latent variables.
"""
return tuple(self.dag.sorted_variables_by_type[IndividualLatentVariable])
@property
def parameters(self) -> DictParamsTorch:
"""Dictionary of values for model parameters.
Returns
-------
: :class:`~leaspy.utils.typing.DictParamsTorch`
A dictionary mapping parameter names to their values (as tensors).
"""
return {p: self._state[p] for p in self.parameters_names}
@property
def hyperparameters(self) -> DictParamsTorch:
"""Dictionary of values for model hyperparameters.
Returns
-------
: :class:`~leaspy.utils.typing.DictParamsTorch`
A dictionary mapping hyperparameter names to their values (as tensors).
"""
return {p: self._state[p] for p in self.hyperparameters_names}
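    # Reading sketch: both properties resolve plain dicts of tensors from the
    # current state (the variable names below are model-dependent examples):
    #
    #     model.parameters        # e.g. {"log_v0_mean": tensor(...), ...}
    #     model.hyperparameters   # e.g. {"dimension": tensor(2), ...}
    #     model.parameters_names  # tuple of the DAG's ModelParameter names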
def initialize(self, dataset: Optional[Dataset] = None) -> None:
"""Overloads base model initialization (in particular to handle internal model State).
<!> We do not put data variables in internal model state at this stage (done in algorithm)
Parameters
----------
dataset : :class:`~leaspy.io.data.dataset.Dataset`, optional
Input dataset from which to initialize the model.
"""
super().initialize(dataset=dataset)
self._initialize_state()
        if dataset is None:
            return
        # WIP: the design of this step could probably be improved
with self._state.auto_fork(None):
self._initialize_model_parameters(dataset)
self._state.put_population_latent_variables(
LatentVariableInitType.PRIOR_MODE
)
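    # Initialization flow sketch (``MyModel`` is a hypothetical subclass):
    # with a dataset, the state/DAG is built, model parameters get their
    # initial values, and population latent variables are set at prior mode;
    # without a dataset, only the state/DAG is built.
    #
    #     model = MyModel("my_model")
    #     model.initialize(dataset)      # given a Dataset instance
    #     MyModel("other").initialize()  # state only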
def _initialize_state(self) -> None:
"""Initialize the internal state of model, as well as the underlying DAG.
Note that all model hyperparameters (dimension, source_dimension, ...) should be defined
in order to be able to do so.
"""
if self._state is not None:
raise LeaspyModelInputError("Trying to initialize the model's state again")
self.state = State(
VariablesDAG.from_dict(self.get_variables_specs()),
auto_fork_type=StateForkType.REF,
)
self.state.track_variables(self.tracked_variables)
def _initialize_model_parameters(self, dataset: Dataset) -> None:
"""Initialize model parameters (in-place, in `_state`).
The method also checks that the model parameters whose initial values
were computed from the dataset match the expected model parameters from
the specifications (i.e. the nodes of the DAG of type 'ModelParameter').
If there is a mismatch, the method raises a ValueError because there is
an inconsistency between the definition of the model and the way it computes
the initial values of its parameters from a dataset.
Parameters
----------
dataset : :class:`~leaspy.io.data.dataset.Dataset`
The dataset to use to compute initial values for the model parameters.
"""
model_parameters_initialization = (
self._compute_initial_values_for_model_parameters(dataset)
)
model_parameters_spec = self.dag.sorted_variables_by_type[ModelParameter]
if set(model_parameters_initialization.keys()) != set(model_parameters_spec):
raise ValueError(
"Model parameters created at initialization are different "
"from the expected model parameters from the specs:\n"
f"- From initialization: {sorted(list(model_parameters_initialization.keys()))}\n"
f"- From Specs: {sorted(list(model_parameters_spec))}\n"
)
for (
model_parameter_name,
model_parameter_variable,
) in model_parameters_spec.items():
model_parameter_initial_value = model_parameters_initialization[
model_parameter_name
]
if not isinstance(
model_parameter_initial_value, (torch.Tensor, WeightedTensor)
):
try:
model_parameter_initial_value = torch.tensor(
model_parameter_initial_value, dtype=torch.float
)
                # torch.tensor raises TypeError or ValueError on data it cannot convert
                except (TypeError, ValueError) as e:
                    raise ValueError(
                        f"The initial value for model parameter '{model_parameter_name}' "
                        "should be a tensor or a weighted tensor.\nInstead, "
                        f"{model_parameter_initial_value} of type {type(model_parameter_initial_value)} "
                        "was received and cannot be cast to a tensor.\nPlease verify this parameter's "
                        "initialization code."
                    ) from e
self._state[model_parameter_name] = model_parameter_initial_value.expand(
model_parameter_variable.shape
)
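    # Casting sketch: scalar initial values are accepted and broadcast to the
    # declared variable shape (a 0-d tensor expands to any shape), e.g.:
    #
    #     t = torch.tensor(0.5, dtype=torch.float)  # 0-d tensor
    #     t.expand((4,))                            # tensor([0.5, 0.5, 0.5, 0.5])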
def load_parameters(self, parameters: KwargsType) -> None:
"""Instantiate or update the model's parameters.
It assumes that all model hyperparameters are defined.
Parameters
----------
        parameters : :class:`~leaspy.utils.typing.KwargsType`
Contains the model's parameters.
"""
from .utilities import val_to_tensor
if self._state is None:
self._initialize_state()
# TODO: a bit dirty due to hyperparams / params mix (cf. `.parameters` property note)
params_names = self.parameters_names
        missing_params = set(params_names).difference(parameters)
        if missing_params:
            warnings.warn(f"Missing some model parameters: {missing_params}")
        extra_vars = set(parameters).difference(self.dag)
        if extra_vars:
            raise LeaspyModelInputError(f"Unknown model variables: {extra_vars}")
# TODO: check no DataVariable provided???
# extra_params = set(parameters).difference(cur_params)
# if len(extra_params):
# # e.g. mixing matrix, which is a derived variable - checking their values only
# warnings.warn(f"Ignoring some provided values that are not model parameters: {extra_params}")
# update parameters first (to be able to check values of derived variables afterwards)
provided_params = {
p: val_to_tensor(parameters[p], self.dag[p].shape)
for p in params_names
if p in parameters
}
for p, val in provided_params.items():
# TODO: WeightedTensor? (e.g. batched `deltas`)
self._state[p] = val
# derive the population latent variables from model parameters
# e.g. to check value of `mixing_matrix` we need `v0` and `betas` (not just `log_v0` and `betas_mean`)
self._state.put_population_latent_variables(LatentVariableInitType.PRIOR_MODE)
# check equality of other values (hyperparameters or linked variables)
for parameter_name, parameter_value in parameters.items():
if parameter_name in provided_params:
continue
# TODO: a bit dirty due to hyperparams / params mix (cf. `.parameters` property note)
try:
current_value = self._state[parameter_name]
except Exception as e:
                raise LeaspyModelInputError(
                    f"Impossible to compare the provided value for '{parameter_name}': "
                    "it is not computable given the current state"
                ) from e
parameter_value = val_to_tensor(
parameter_value, getattr(self.dag[parameter_name], "shape", None)
)
            assert parameter_value.shape == current_value.shape, (
                parameter_name,
                parameter_value.shape,
                current_value.shape,
            )
            # TODO: WeightedTensor? (e.g. batched `deltas`)
            assert torch.allclose(parameter_value, current_value, atol=1e-4), (
                parameter_name,
                parameter_value,
                current_value,
            )
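    # Round-trip sketch: `load_parameters` restores model parameter values,
    # re-derives population latent variables, and only *checks* any extra
    # provided entries (hyperparameters / derived variables) for consistency:
    #
    #     model.load_parameters(saved["parameters"])  # `saved` is hypothetical,
    #                                                 # e.g. read from a model file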
@abstractmethod
def get_variables_specs(self) -> NamedVariables:
"""Return the specifications of the variables (latent variables,
derived variables, model 'parameters') that are part of the model.
Returns
-------
        : :class:`~leaspy.variables.specs.NamedVariables`
The specifications of the model's variables.
"""
raise NotImplementedError()
@abstractmethod
def _compute_initial_values_for_model_parameters(
self, dataset: Dataset
) -> VariableNameToValueMapping:
"""Compute initial values for model parameters.
Parameters
----------
dataset : :class:`~leaspy.io.data.dataset.Dataset`
The dataset to use to compute initial values for the model parameters.
Returns
-------
        : :class:`~leaspy.variables.specs.VariableNameToValueMapping`
A dictionary mapping variable names to their initial values.
"""
raise NotImplementedError()
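    # Minimal subclass sketch (hypothetical, for illustration only): a concrete
    # model declares its DAG variables and how to initialize its parameters.
    #
    #     class MyModel(StatefulModel):
    #         def get_variables_specs(self) -> NamedVariables:
    #             ...  # NamedVariables describing hyperparameters, ModelParameter
    #                  # nodes, latent variables and derived variables
    #
    #         def _compute_initial_values_for_model_parameters(
    #             self, dataset: Dataset
    #         ) -> VariableNameToValueMapping:
    #             ...  # exactly one entry per ModelParameter node of the DAG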
def move_to_device(self, device: torch.device) -> None:
"""Move a model and its relevant attributes to the specified :class:`torch.device`.
Parameters
----------
device : :obj:`torch.device`
The device to which the model and its attributes should be moved.
"""
if self._state is None:
return
self._state.to_device(device)
for hp in self.hyperparameters_names:
self._state.dag[hp].to_device(device)
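    # Device sketch: moving is a no-op before initialization; afterwards both the
    # state values and the hyperparameter variables stored on the DAG are moved.
    #
    #     model.move_to_device(torch.device("cuda:0"))  # assuming CUDA is available
    #     model.move_to_device(torch.device("cpu"))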