Source code for leaspy.samplers.factory
from typing import Type, Union
from leaspy.exceptions import LeaspyAlgoInputError
from leaspy.variables.specs import IndividualLatentVariable, PopulationLatentVariable
from .base import (
AbstractIndividualSampler,
AbstractPopulationSampler,
AbstractSampler,
)
from .gibbs import (
IndividualGibbsSampler,
PopulationFastGibbsSampler,
PopulationGibbsSampler,
PopulationMetropolisHastingsSampler,
)
__all__ = [
"SamplerFactoryInput",
"INDIVIDUAL_SAMPLERS",
"POPULATION_SAMPLERS",
"sampler_factory",
]
SamplerFactoryInput = Union[str, AbstractSampler]
INDIVIDUAL_SAMPLERS = {"gibbs": IndividualGibbsSampler}
POPULATION_SAMPLERS = {
"gibbs": PopulationGibbsSampler,
"fastgibbs": PopulationFastGibbsSampler,
"metropolis-hastings": PopulationMetropolisHastingsSampler,
}
[docs]
def sampler_factory(
sampler: SamplerFactoryInput, variable_type, **kwargs
) -> AbstractSampler:
"""
Factory for Samplers.
Parameters
----------
sampler : :class:`.AbstractSampler` or :obj:`str`
If an instance of a subclass of :class:`.AbstractSampler`, returns the instance.
If a string, returns a new instance of the appropriate class (with optional parameters `kwargs`).
variable_type : :class:`.VariableType`
The type of random variable that the sampler is supposed to sample.
**kwargs
Optional parameters for initializing the requested Sampler
(not used if input is a subclass of :class:`.AbstractSampler`).
Returns
-------
:class:`.AbstractSampler` :
The desired sampler.
Raises
------
:exc:`.LeaspyAlgoInputError`:
If the sampler provided is not supported.
"""
if isinstance(sampler, AbstractSampler):
return sampler
if isinstance(sampler, str):
kls = _get_sampler_class(sampler, variable_type)
return kls(**kwargs)
raise LeaspyAlgoInputError(
"The provided `sampler` should be a valid instance of `AbstractSampler`, or a string "
f"among {set(INDIVIDUAL_SAMPLERS).union(POPULATION_SAMPLERS)}."
)
def _get_sampler_class(sampler_name: str, variable_type):
if variable_type == IndividualLatentVariable:
return _get_individual_sampler_class(sampler_name)
if variable_type == PopulationLatentVariable:
return _get_population_sampler_class(sampler_name)
def _get_individual_sampler_class(sampler_name: str) -> Type[AbstractIndividualSampler]:
sampler_name = sampler_name.lower().replace("_", "-")
kls = INDIVIDUAL_SAMPLERS.get(sampler_name, None)
if kls is None:
raise LeaspyAlgoInputError(
f"Individual sampler '{sampler_name}' is not supported. "
f"Supported samplers for individual variables are {set(INDIVIDUAL_SAMPLERS)}"
)
return kls
def _get_population_sampler_class(sampler_name: str) -> Type[AbstractPopulationSampler]:
sampler_name = sampler_name.lower().replace("_", "-")
kls = POPULATION_SAMPLERS.get(sampler_name, None)
if kls is None:
raise LeaspyAlgoInputError(
f"Population sampler '{sampler_name}' is not supported. "
f"Supported samplers for population variables are {set(POPULATION_SAMPLERS)}"
)
return kls