Source code for leaspy.models.obs_models._factory

"""Defines the noise model factory."""

from enum import Enum
from typing import Dict, Type, Union

from leaspy.exceptions import LeaspyModelInputError

from ._base import ObservationModel
from ._bernoulli import BernoulliObservationModel
from ._gaussian import FullGaussianObservationModel
from ._weibull import (
    WeibullRightCensoredObservationModel,
    WeibullRightCensoredWithSourcesObservationModel,
)

__all__ = [
    "ObservationModelNames",
    "ObservationModelFactoryInput",
    "OBSERVATION_MODELS",
    "observation_model_factory",
]


[docs] class ObservationModelNames(Enum): """Enumeration defining the possible names for observation models.""" GAUSSIAN_DIAGONAL = "gaussian-diagonal" GAUSSIAN_SCALAR = "gaussian-scalar" BERNOULLI = "bernoulli" WEIBULL_RIGHT_CENSORED = "weibull-right-censored" WEIBULL_RIGHT_CENSORED_WITH_SOURCES = "weibull-right-censored-with-sources"
[docs] @classmethod def from_string(cls, model_name: str): try: return cls(model_name.lower().replace("_", "-")) except ValueError: raise NotImplementedError( f"The requested ObservationModel {model_name} is not implemented. " f"Valid observation model names are: {[elt.value for elt in cls]}." )
ObservationModelFactoryInput = Union[str, ObservationModelNames, ObservationModel] OBSERVATION_MODELS: Dict[ObservationModelNames, Type[ObservationModel]] = { ObservationModelNames.GAUSSIAN_DIAGONAL: FullGaussianObservationModel, ObservationModelNames.GAUSSIAN_SCALAR: FullGaussianObservationModel, ObservationModelNames.BERNOULLI: BernoulliObservationModel, ObservationModelNames.WEIBULL_RIGHT_CENSORED: WeibullRightCensoredObservationModel, ObservationModelNames.WEIBULL_RIGHT_CENSORED_WITH_SOURCES: WeibullRightCensoredWithSourcesObservationModel, }
[docs] def observation_model_factory( model: ObservationModelFactoryInput, **kwargs ) -> ObservationModel: """ Factory for observation models. Parameters ---------- model : :obj:`str` or :class:`.ObservationModel` or :obj:`dict` [ :obj:`str`, ...] - If an instance of a subclass of :class:`.ObservationModel`, returns the instance. - If a string, then returns a new instance of the appropriate class (with optional parameters `kws`). - If a dictionary, it must contain the 'name' key and other initialization parameters. **kwargs Optional parameters for initializing the requested observation model when a string. Returns ------- :class:`.ObservationModel` : The desired observation model. Raises ------ :exc:`.LeaspyModelInputError` : If `model` is not supported. """ dimension = kwargs.pop("dimension", None) n_clusters = kwargs.pop("n_clusters", None) if isinstance(model, ObservationModel): return model if isinstance(model, str): model = ObservationModelNames.from_string(model) if isinstance(model, ObservationModelNames): if model == ObservationModelNames.GAUSSIAN_DIAGONAL: if dimension is None: raise NotImplementedError( "WIP: dimension / features should be provided to " f"init the obs_model = {ObservationModelNames.GAUSSIAN_DIAGONAL}." ) return FullGaussianObservationModel.with_noise_std_as_model_parameter( dimension ) if model == ObservationModelNames.GAUSSIAN_SCALAR: return FullGaussianObservationModel.with_noise_std_as_model_parameter(1) if model == ObservationModelNames.WEIBULL_RIGHT_CENSORED: return WeibullRightCensoredObservationModel.default_init(kwargs=kwargs) if model == ObservationModelNames.WEIBULL_RIGHT_CENSORED_WITH_SOURCES: return WeibullRightCensoredWithSourcesObservationModel.default_init( kwargs=kwargs ) return OBSERVATION_MODELS[model](**kwargs) raise LeaspyModelInputError( "The provided `model` should be a valid instance of `ObservationModel`, " f"or a string among {[c.value for c in ObservationModelNames]}." f"Instead, {model} of type {type(model)} was provided." )