"""Weibull right-censored observation models (``leaspy.models.obs_models._weibull``)."""

from leaspy.io.data.dataset import Dataset
from leaspy.utils.weighted_tensor import WeightedTensor
from leaspy.variables.distributions import (
    WeibullRightCensored,
    WeibullRightCensoredWithSources,
)
from leaspy.variables.specs import (
    LinkedVariable,
    VariableInterface,
    VariableName,
)

from ._base import ObservationModel

# Public API of this module: the abstract base and its two concrete
# Weibull right-censored observation models.
__all__ = [
    "AbstractWeibullRightCensoredObservationModel",
    "WeibullRightCensoredObservationModel",
    "WeibullRightCensoredWithSourcesObservationModel",
]


class AbstractWeibullRightCensoredObservationModel(ObservationModel):
    """Common base for the Weibull right-censored event observation models.

    Concrete subclasses choose the distribution. This base provides the
    dataset ``getter`` and adds a ``predictions_<name>`` linked variable to
    the variable specifications.
    """

    @staticmethod
    def getter(dataset: Dataset) -> WeightedTensor:
        """Extract the observed events from ``dataset``.

        Parameters
        ----------
        dataset : Dataset
            Dataset whose ``event_time`` and ``event_bool`` attributes are read.

        Returns
        -------
        WeightedTensor
            ``WeightedTensor(dataset.event_time, dataset.event_bool)``.

        Raises
        ------
        ValueError
            If ``dataset.event_time`` or ``dataset.event_bool`` is None.
        """
        if dataset.event_time is None or dataset.event_bool is None:
            raise ValueError(
                "Provided dataset is not valid. "
                "Both `event_time` and `event_bool` should be not None."
            )
        return WeightedTensor(dataset.event_time, dataset.event_bool)

    def get_variables_specs(
        self,
        named_attach_vars: bool = True,
    ) -> dict[VariableName, VariableInterface]:
        """Return the variable specifications for this observation model.

        Starts from the parent :class:`ObservationModel` specifications and
        adds a ``predictions_<name>`` :class:`LinkedVariable`. That variable
        wraps the distribution's ``compute_predictions`` function for this
        model's name.

        Parameters
        ----------
        named_attach_vars : bool, optional
            Forwarded unchanged to the parent implementation.

        Returns
        -------
        dict[VariableName, VariableInterface]
            The parent specifications plus the predictions variable.
        """
        specs = super().get_variables_specs(named_attach_vars)
        specs[f"predictions_{self.name}"] = LinkedVariable(
            self.dist.get_func("compute_predictions", self.name)
        )
        return specs
class WeibullRightCensoredObservationModel(
    AbstractWeibullRightCensoredObservationModel
):
    """Event observation model backed by a :class:`WeibullRightCensored` distribution."""

    string_for_json = "weibull-right-censored"

    def __init__(
        self,
        nu: VariableName,
        rho: VariableName,
        xi: VariableName,
        tau: VariableName,
        **extra_vars: VariableInterface,
    ):
        """Create the observation model, always named ``"event"``.

        Parameters
        ----------
        nu, rho, xi, tau : VariableName
            Names of the variables passed, in this order, to
            :class:`WeibullRightCensored`.
        **extra_vars : VariableInterface
            Additional variables forwarded to :class:`ObservationModel`.
        """
        super().__init__(
            name="event",
            getter=self.getter,
            dist=WeibullRightCensored(nu, rho, xi, tau),
            extra_vars=extra_vars,
        )

    @classmethod
    def default_init(cls, **kwargs):
        """Build the model with default variable names.

        ``nu``, ``rho``, ``xi`` and ``tau`` are read from ``kwargs`` when
        present. Otherwise each defaults to a variable of the same name.
        Any other keyword argument is ignored; it is not forwarded as an
        extra variable.
        """
        return cls(
            nu=kwargs.get("nu", "nu"),
            rho=kwargs.get("rho", "rho"),
            xi=kwargs.get("xi", "xi"),
            tau=kwargs.get("tau", "tau"),
        )
class WeibullRightCensoredWithSourcesObservationModel(
    AbstractWeibullRightCensoredObservationModel
):
    """Event observation model backed by a :class:`WeibullRightCensoredWithSources` distribution."""

    string_for_json = "weibull-right-censored-with-sources"

    def __init__(
        self,
        nu: VariableName,
        rho: VariableName,
        xi: VariableName,
        tau: VariableName,
        survival_shifts: VariableName,
        **extra_vars: VariableInterface,
    ):
        """Create the observation model, always named ``"event"``.

        Parameters
        ----------
        nu, rho, xi, tau, survival_shifts : VariableName
            Names of the variables passed, in this order, to
            :class:`WeibullRightCensoredWithSources`.
        **extra_vars : VariableInterface
            Additional variables forwarded to :class:`ObservationModel`.
        """
        super().__init__(
            name="event",
            getter=self.getter,
            dist=WeibullRightCensoredWithSources(nu, rho, xi, tau, survival_shifts),
            extra_vars=extra_vars,
        )

    @classmethod
    def default_init(cls, **kwargs):
        """Build the model with default variable names.

        ``nu``, ``rho``, ``xi``, ``tau`` and ``survival_shifts`` are read
        from ``kwargs`` when present. Otherwise each defaults to a variable
        of the same name. Any other keyword argument is ignored; it is not
        forwarded as an extra variable.
        """
        return cls(
            nu=kwargs.get("nu", "nu"),
            rho=kwargs.get("rho", "rho"),
            xi=kwargs.get("xi", "xi"),
            tau=kwargs.get("tau", "tau"),
            survival_shifts=kwargs.get("survival_shifts", "survival_shifts"),
        )