leaspy.variables.state
This module defines the State of stateful models.
A state contains 2 main components:
1. The relationships between the variables through a VariablesDAG
2. The values of each variable of the DAG as a mapping between variable names and their values
The State class is crucial for stateful models with its logic for efficiently retrieving variable values. This class relies on a caching mechanism that enables quick queries.
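The lazy retrieval with caching described above can be illustrated with a small, self-contained sketch. It does not use leaspy: `LazyValues`, its `parents` and `functions` arguments, and the `n_computations` counter are hypothetical names introduced only for this illustration, and the actual State implementation may differ.

```python
from typing import Callable, Dict, Optional, Tuple


class LazyValues:
    """Toy stand-in for a state: a fixed dependency graph plus a cache of values."""

    def __init__(
        self,
        parents: Dict[str, Tuple[str, ...]],
        functions: Dict[str, Callable[..., float]],
    ) -> None:
        self.parents = parents      # node -> names of its parents (empty for leaves)
        self.functions = functions  # derived node -> function of its parents' values
        self.values: Dict[str, Optional[float]] = {name: None for name in parents}
        self.n_computations = 0     # counts real computations, to show the cache at work

    def __getitem__(self, name: str) -> float:
        if self.values[name] is None:
            # Cache miss: compute from the parents (recursively), then store.
            args = [self[parent] for parent in self.parents[name]]
            self.values[name] = self.functions[name](*args)
            self.n_computations += 1
        return self.values[name]

    def __setitem__(self, name: str, value: float) -> None:
        self.values[name] = value
        # Invalidate every descendant so the cache stays self-consistent.
        stack = [name]
        while stack:
            current = stack.pop()
            for child, child_parents in self.parents.items():
                if current in child_parents:
                    self.values[child] = None
                    stack.append(child)


state = LazyValues(
    parents={"a": (), "b": (), "total": ("a", "b"), "double": ("total",)},
    functions={"total": lambda a, b: a + b, "double": lambda t: 2 * t},
)
state["a"] = 1.0
state["b"] = 2.0
first = state["double"]   # computes "total", then "double"
second = state["double"]  # served from the cache, no new computation
state["a"] = 10.0         # invalidates "total" and "double"
third = state["double"]   # recomputed from the new leaf value
```

In this sketch a value of None marks "not computed yet", which mirrors the convention the documentation uses for unset values.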
Classes
StateForkType | The strategy used to cache forked values in State.
State | Dictionary of cached values corresponding to the stateless DAG instance.
Module Contents
- class StateForkType(*args, **kwds)
Bases: enum.Enum
The strategy used to cache forked values in State.
- REF: Reference-based caching
Cached values are stored by reference, meaning no copying occurs. Mutating the original variables after caching will affect the cached version.
- COPY: Deep copy-based caching
Cached values are stored via copy.deepcopy, ensuring they are independent of the originals.
Notes
Use REF for efficiency when you’re certain the original values will not be mutated after caching.
Use COPY to ensure isolation between the cached values and any subsequent modifications.
With REF, values are NOT copied (only references to them are kept), so do NOT mutate them in place; otherwise the cached values will change as well.
- REF
- COPY
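The difference between the two strategies can be seen with plain Python containers. This is a minimal sketch rather than leaspy code: a shallow `dict(...)` copy stands in for reference-based caching (an assumption about how references are kept), while `copy.deepcopy` is the call named above for COPY.

```python
import copy

values = {"a": [1.0, 2.0]}

cached_by_ref = dict(values)            # REF-like: new mapping, same underlying objects
cached_by_copy = copy.deepcopy(values)  # COPY-like: fully independent objects

values["a"][0] = 99.0                   # in-place mutation of the original value

ref_sees_mutation = cached_by_ref["a"][0] == 99.0  # the "cached" value changed too
copy_is_isolated = cached_by_copy["a"][0] == 1.0   # the deep copy is unaffected
```

Reference-based caching avoids the cost of copying, which is why it is only safe when values are replaced rather than mutated in place.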
- class State(dag, *, auto_fork_type=None)
Bases: collections.abc.MutableMapping
Dictionary of cached values corresponding to the stateless DAG instance.
- Parameters:
- dag
VariablesDAG: The stateless DAG for which the state will hold values.
- auto_fork_type
StateForkType or None (default): Refer to the StateForkType class and the auto_fork_type attribute.
- Attributes:
- dag
VariablesDAG: The stateless DAG for which the state instance will hold values.
- auto_fork_type
StateForkType or None: If not None, each dictionary assignment leads to the partial caching of the previous value and of all its children, so that they can be reverted without computation. The exact caching strategy depends on the flag (caching by reference or by copy). Can be set manually or via the auto_fork context manager.
- _tracked_variables
set[VariableName, …]
- _values
VariablesLazyValuesRW: Private cache for values (computations are lazy, thus some values may be None). All values that are not None are always self-consistent with respect to DAG dependencies.
- _last_fork
VariablesLazyValuesRO or None: If not None, holds the previous partial state values so that they can be restored with .revert(). Automatically populated on assignment operations as soon as auto_fork_type is not None. Example: if you set a new value for a, then the values of a and of all its children just before the assignment are held until either a reversion or a new assignment.
- Parameters:
dag (VariablesDAG)
auto_fork_type (Optional[StateForkType])
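The fork-on-assignment behaviour described for auto_fork_type and _last_fork can be sketched without leaspy. Everything below is hypothetical (`ForkingValues`, `assign`, `by_copy`, the `children` mapping); it only illustrates the idea of snapshotting a node and its children before an assignment so that they can be restored without recomputation.

```python
import copy
from typing import Dict, List, Optional, Tuple


class ForkingValues:
    """Toy values container that snapshots a node and its descendants on assignment."""

    def __init__(self, children: Dict[str, Tuple[str, ...]], values: Dict[str, object]) -> None:
        self.children = children  # node -> names of its direct children
        self.values = dict(values)
        self.last_fork: Optional[Dict[str, object]] = None

    def _descendants(self, name: str) -> List[str]:
        found: List[str] = []
        stack = [name]
        while stack:
            for child in self.children.get(stack.pop(), ()):
                if child not in found:
                    found.append(child)
                    stack.append(child)
        return found

    def assign(self, name: str, value: object, *, by_copy: bool = False) -> None:
        touched = [name] + self._descendants(name)
        snapshot = {n: self.values[n] for n in touched}
        # by_copy=False keeps references (REF-like), by_copy=True deep-copies (COPY-like).
        self.last_fork = copy.deepcopy(snapshot) if by_copy else snapshot
        self.values[name] = value
        for n in touched[1:]:
            self.values[n] = None  # children are now stale and must be recomputed lazily

    def revert(self) -> None:
        if self.last_fork is None:
            raise ValueError("no forked state to revert from")
        self.values.update(self.last_fork)
        self.last_fork = None  # the forked state is reset after reversion


state = ForkingValues(
    children={"a": ("total",), "b": ("total",)},
    values={"a": 1.0, "b": 2.0, "total": 3.0},
)
state.assign("a", 10.0)
during = dict(state.values)  # "a" is updated and "total" is unset
state.revert()               # restores "a" and "total" without recomputing anything
```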
- dag
- auto_fork_type = None
- property tracked_variables: set[VariableName, Ellipsis]
Get the set of variable names currently tracked by the State.
- Returns:
set[VariableName, …]: A set containing the names of the tracked variables.
- Return type:
set[VariableName, Ellipsis]
- track_variables(variable_names)
Add some variables to the tracked variables.
- Parameters:
- variable_names
Iterable of VariableName: The names of the variables to be added to the tracked variables.
- Parameters:
variable_names (Iterable[VariableName])
- Return type:
None
- track_variable(variable_name)
Add a single variable to the tracked variables.
- Parameters:
- variable_name
VariableName: The name of the variable to be added to the tracked variables.
- Parameters:
variable_name (VariableName)
- Return type:
None
- untrack_variables(variable_names)
Remove some variables from the tracked variables.
- Parameters:
- variable_names
Iterable of VariableName: The names of the variables to be removed from the tracked variables.
- Parameters:
variable_names (Iterable[VariableName])
- Return type:
None
- untrack_variable(variable_name)
Remove a single variable from the tracked variables.
- Parameters:
- variable_name
VariableName: The name of the variable to be removed from the tracked variables.
- Parameters:
variable_name (VariableName)
- Return type:
None
- clear()
Reset last forked state and reset all values to their canonical values.
- Return type:
None
- clone(*, disable_auto_fork=False, keep_last_fork=False)
Clone current state without copying the DAG.
- auto_fork(type=StateForkType.REF)
Context manager that temporarily sets auto_fork_type to type.
- Parameters:
- type
StateForkType or None, optional: The temporary auto-forking strategy to use within the context. Defaults to StateForkType.REF.
- Yields:
- None
Control returns to the caller with the temporary forking strategy applied.
- Parameters:
type (Optional[StateForkType])
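A context manager of this kind is commonly written with `contextlib.contextmanager`. The sketch below is not the leaspy implementation: `ForkType` and `TinyState` are hypothetical stand-ins, and restoring the previous setting on exit is an assumption drawn from the word "temporary" in the description.

```python
from contextlib import contextmanager
from enum import Enum, auto
from typing import Iterator, Optional


class ForkType(Enum):  # hypothetical stand-in for StateForkType
    REF = auto()
    COPY = auto()


class TinyState:  # hypothetical stand-in holding only the flag
    def __init__(self) -> None:
        self.auto_fork_type: Optional[ForkType] = None

    @contextmanager
    def auto_fork(self, type: Optional[ForkType] = ForkType.REF) -> Iterator[None]:
        previous = self.auto_fork_type
        self.auto_fork_type = type
        try:
            yield  # control returns to the caller with the temporary strategy applied
        finally:
            self.auto_fork_type = previous  # restored even if the body raises


state = TinyState()
with state.auto_fork(ForkType.COPY):
    inside = state.auto_fork_type
after = state.auto_fork_type
```

The `try`/`finally` ensures that an exception raised inside the `with` body does not leave the temporary strategy enabled.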
- is_variable_set(name)
Return True if the variable is in the DAG and its value is not None.
- Parameters:
- name
VariableName: The name of the variable to check.
- Returns:
bool: True if the variable exists in the DAG and its value has been set (i.e., is not None). False otherwise.
- Parameters:
name (VariableName)
- Return type:
bool
- are_variables_set(variable_names)
Return True if all the variables are in the DAG and their values are not None.
- Parameters:
- variable_names
Iterable of VariableName: A collection of variable names to check.
- Returns:
bool: True if all variables exist in the DAG and their values are set (i.e., not None). False otherwise.
- Parameters:
variable_names (Iterable[VariableName])
- Return type:
bool
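The two checks above reduce to a membership test combined with a None test. A minimal sketch over a plain dictionary follows; the free functions and the `values` mapping are illustrative stand-ins for the State methods and its internal cache.

```python
from typing import Dict, Iterable, Optional


def is_variable_set(values: Dict[str, Optional[float]], name: str) -> bool:
    # dict.get returns None both for unknown names and for values that are unset.
    return values.get(name) is not None


def are_variables_set(values: Dict[str, Optional[float]], names: Iterable[str]) -> bool:
    return all(is_variable_set(values, name) for name in names)


values = {"a": 1.0, "b": None}
```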
- put(variable_name, variable_value, *, indices=(), accumulate=False)
Smart and protected assignment of a variable value, potentially on a subset of indices only and optionally adding to (accumulating on) the previous value. The operation is always OUT-OF-PLACE.
- Parameters:
- variable_name
VariableName: The name of the variable.
- variable_value
VariableValue: The new value to put in the variable.
- indices
tuple of int, optional: If set, the operation will happen on a subset of indices. Default=()
- accumulate
bool, optional: If set to True, the new variable value will be added to the old value. Otherwise, it will be assigned. Default=False
- Parameters:
variable_name (VariableName)
variable_value (VariableValue)
accumulate (bool)
- Return type:
None
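The combination of subset assignment, accumulation and out-of-place semantics can be sketched on flat Python lists. This is a simplification, not the leaspy implementation: `put_out_of_place` is a hypothetical helper, and reading `indices` as positions in a flat list is an assumption made for the sketch (indexing on tensors may follow different rules).

```python
from typing import List, Sequence


def put_out_of_place(
    old: List[float],
    new: Sequence[float],
    *,
    indices: Sequence[int] = (),
    accumulate: bool = False,
) -> List[float]:
    """Return a new list with `new` assigned (or added) at `indices`; `old` is untouched."""
    if not indices:
        # No indices: the operation applies to the whole value.
        if accumulate:
            return [o + n for o, n in zip(old, new)]
        return list(new)
    result = list(old)  # copy first, so the previous value is never modified in place
    for position, value in zip(indices, new):
        result[position] = result[position] + value if accumulate else value
    return result


old = [1.0, 2.0, 3.0]
assigned = put_out_of_place(old, [10.0], indices=(1,))
accumulated = put_out_of_place(old, [10.0], indices=(1,), accumulate=True)
whole = put_out_of_place(old, [0.5, 0.5, 0.5], accumulate=True)
```

Working out-of-place keeps any previously cached reference to the old value valid, which matters when values are cached by reference.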
- precompute_all()
Pre-compute all values of the graph (assuming leaves already have valid values).
- Return type:
None
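Computing every node once the leaves are set amounts to evaluating the graph in topological order. The sketch below uses the standard-library `graphlib` module (Python 3.9+) on a hypothetical four-node graph; the `parents`, `functions` and `values` dictionaries are illustrative only, and leaspy may order its computations differently.

```python
from graphlib import TopologicalSorter

# node -> names of its parents (predecessors); leaves have no parents
parents = {"a": (), "b": (), "total": ("a", "b"), "double": ("total",)}
functions = {"total": lambda a, b: a + b, "double": lambda t: 2 * t}
values = {"a": 1.0, "b": 2.0}  # leaves already hold valid values

# static_order() yields every node after all of its predecessors.
for name in TopologicalSorter(parents).static_order():
    if name not in values:
        values[name] = functions[name](*(values[p] for p in parents[name]))
```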
- revert(subset=None, *, right_broadcasting=True)
Revert state to previous forked state. Forked state is then reset.
- Parameters:
- subset
VariableValue or None: If not None, the reversion is only partial:
* subset = True <=> revert previous state for those indices
* subset = False <=> keep current state for those indices
<!> The user is responsible for having tensor values that are consistent with the subset shape (i.e. valid broadcasting) for the forked node and all of its children.
<!> When the current OR forked state is not set (value = None) on a particular node of the forked DAG, then the reverted result is always None.
- right_broadcasting
bool, optional: If True and if subset is not None, then the subset of indices to revert uses right-broadcasting, instead of the standard left-broadcasting. Default=True.
- Raises:
LeaspyInputError: If no forked state exists to revert from (i.e., .auto_fork() context was not used).
- Parameters:
subset (Optional[VariableValue])
right_broadcasting (bool)
- Return type:
None
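A partial reversion can be pictured as an element-wise choice between the forked and the current value, driven by the boolean subset. The sketch below works on lists of rows rather than tensors; `partial_revert` is a hypothetical helper, and applying one boolean per row (that is, aligning the mask with the leading dimension) is this sketch's reading of right-broadcasting, stated here as an assumption.

```python
from typing import List, Optional

Row = List[float]


def partial_revert(
    current: Optional[List[Row]],
    forked: Optional[List[Row]],
    subset: List[bool],
) -> Optional[List[Row]]:
    """Pick the forked row where subset is True, the current row where it is False."""
    if current is None or forked is None:
        # If either side is unset, the reverted result is unset as well.
        return None
    return [
        list(forked_row) if revert else list(current_row)
        for revert, current_row, forked_row in zip(subset, current, forked)
    ]


forked = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]   # values held just before the assignment
current = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]  # values after the assignment
reverted = partial_revert(current, forked, [True, False, True])
unset = partial_revert(None, forked, [True, False, True])
```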
- to_device(device)
Move values to the specified device (in-place).
- Parameters:
- device
torch.device
- Parameters:
device (device)
- Return type:
None
- put_population_latent_variables(method)
Initialize all population latent variables in the state with predefined values.
- Parameters:
- method
str or LatentVariableInitType or None: The method used to initialize the variables. If None, all population latent variables will be unset (set to None). Otherwise, the corresponding initialization function will be called for each variable using the provided method.
- Parameters:
method (Optional[Union[str, LatentVariableInitType]])
- Return type:
None
- put_individual_latent_variables(method=None, *, n_individuals=None, df=None)
Initialize all individual latent variables in the state with predefined values.
- Parameters:
- method
str or LatentVariableInitType, optional: The method used to initialize the variables. If None, the variables will be unset (set to None). If provided, an initialization function will be called per variable. When method is not None, n_individuals must be specified.
- n_individuals
int, optional: Number of individuals to initialize. Required when method is not None and df is not provided.
- df
pandas.DataFrame, optional: A DataFrame from which to directly extract the individual latent variable values. It must contain columns named ‘tau’ and ‘xi’ for direct assignment of these variables. If the “sources” variable is present, the DataFrame should include columns named ‘sources_0’, ‘sources_1’, …, up to the expected number of source variables.
- Raises:
LeaspyInputError: If method is specified without n_individuals, or if required columns are missing in df.
- Parameters:
method (Optional[Union[str, LatentVariableInitType]])
n_individuals (Optional[int])
df (Optional[DataFrame])
- Return type:
None
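The column layout expected in df can be checked before calling the method. The helpers below are hypothetical (they are not part of leaspy) and use only the column names listed above; `n_sources` stands for the expected number of source variables and the input is any iterable of column names, such as `df.columns`.

```python
from typing import Iterable, List


def expected_columns(n_sources: int = 0) -> List[str]:
    """Column names described above: 'tau', 'xi', then 'sources_0', 'sources_1', ..."""
    return ["tau", "xi"] + [f"sources_{i}" for i in range(n_sources)]


def missing_columns(columns: Iterable[str], n_sources: int = 0) -> List[str]:
    """Return the expected column names that are absent from `columns`."""
    available = set(columns)
    return [name for name in expected_columns(n_sources) if name not in available]


two_sources = expected_columns(2)
missing = missing_columns(["tau", "sources_0"], n_sources=1)
```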
- save(output_folder, iteration=None)
Save the tracked variable values of the state.
- Parameters:
- Parameters:
- Return type:
None
- get_tensor_value(variable_name)
Return the value of the provided variable as a torch tensor.
- Parameters:
- variable_name
VariableName: The name of the variable for which to retrieve the value.
- Returns:
VariableValue: The value of the variable.
- Parameters:
variable_name (VariableName)
- Return type:
VariableValue
- get_tensor_values(variable_names)
Return the values of the provided variables as torch tensors.
- Parameters:
- variable_names
Iterable of VariableName: The names of the variables for which to retrieve the values.
- Returns:
tuple of VariableValue: The values of the variables.
- Parameters:
variable_names (Iterable[VariableName])
- Return type:
tuple[VariableValue, Ellipsis]