Source code for leaspy.models.obs_models._bernoulli

import torch

from leaspy.io.data.dataset import Dataset
from leaspy.utils.weighted_tensor import WeightedTensor
from leaspy.variables.distributions import Bernoulli
from leaspy.variables.specs import VariableInterface

from ._base import ObservationModel

__all__ = ["BernoulliObservationModel"]


[docs] class BernoulliObservationModel(ObservationModel): """ Observation model for binary outcomes using a Bernoulli distribution. This model expects binary-valued observations and uses a Bernoulli distribution to define the likelihood. It assumes the response variable is named `"y"`. Parameters ---------- **extra_vars : VariableInterface Optional extra variables required by the model. These are passed to the parent `ObservationModel` class and can be used for conditioning the likelihood. Attributes ---------- string_for_json : :obj:`str` A static string identifier used for serialization. """ string_for_json = "bernoulli" def __init__( self, **extra_vars: VariableInterface, ): super().__init__( name="y", getter=self.y_getter, dist=Bernoulli("model"), extra_vars=extra_vars, )
[docs] @staticmethod def y_getter(dataset: Dataset) -> WeightedTensor: """ Extracts and validates the observation values and associated mask from a dataset. Parameters ---------- dataset : :class:`.Dataset` A dataset object containing `values` and `mask` attributes. Returns ------- :class:`.WeightedTensor` A tensor containing the observed binary values along with a boolean mask indicating which entries are valid. Raises ------ ValueError If either `dataset.values` or `dataset.mask` is `None`, indicating that the dataset is improperly initialized. """ if dataset.values is None or dataset.mask is None: raise ValueError( "Provided dataset is not valid. " "Both values and mask should be not None." ) return WeightedTensor(dataset.values, weight=dataset.mask.to(torch.bool))