"""Device-management mixin for leaspy algorithms (module ``leaspy.algo.algo_with_device``)."""

import contextlib
import warnings

import torch

from leaspy.io.data import Dataset
from leaspy.models import McmcSaemCompatibleModel

from .settings import AlgorithmSettings

__all__ = ["AlgorithmWithDeviceMixin"]


class AlgorithmWithDeviceMixin:
    """Mixin class containing common attributes & methods for algorithms with a torch device.

    Parameters
    ----------
    settings : :class:`.AlgorithmSettings`
        The specifications of the algorithm as a :class:`.AlgorithmSettings` instance.
        Its ``device`` attribute is stored as :attr:`algorithm_device`.

    Attributes
    ----------
    algorithm_device : :obj:`str`
        Valid torch device
    """

    def __init__(self, settings: "AlgorithmSettings"):
        super().__init__(settings)
        self.algorithm_device = settings.device
        # Device and (legacy) tensor type that model/dataset are restored to
        # once the algorithm has finished.
        self._default_algorithm_device = torch.device("cpu")
        self._default_algorithm_tensor_type = "torch.FloatTensor"

    @staticmethod
    def _set_default_tensor_type(tensor_type: str) -> None:
        """Call :func:`torch.set_default_tensor_type` without emitting its deprecation warning.

        The warning filter is confined to this call only, so the caller's
        ``warnings`` configuration is left untouched.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*torch.set_default_tensor_type.*")
            torch.set_default_tensor_type(tensor_type)

    @contextlib.contextmanager
    def _device_manager(self, model: "McmcSaemCompatibleModel", dataset: "Dataset"):
        """
        Context-manager to handle the "ambient device" (i.e. the device used to
        instantiate tensors and perform computations).

        The provided model and dataset will be moved to the device specified for
        the execution at the beginning of the algorithm and moved back to the
        original ('cpu') device at the end of the algorithm, even if the
        algorithm (or one of the transfers) raises.
        The default tensor type will also be set accordingly, and reset to
        ``torch.FloatTensor`` on exit.

        Parameters
        ----------
        model : :class:`~.models.abstract_model.McmcSaemCompatibleModel`
            The used model.
        dataset : :class:`.Dataset`
            Contains the subjects' observations in torch format to speed up computation.

        Yields
        ------
        None
        """
        # Decide once so that model/dataset are moved back iff they were moved away.
        use_custom_device = self.algorithm_device != self._default_algorithm_device.type
        algorithm_tensor_type = self._default_algorithm_tensor_type
        try:
            if use_custom_device:
                algorithm_device = torch.device(self.algorithm_device)
                model.move_to_device(algorithm_device)
                dataset.move_to_device(algorithm_device)
                # NOTE(review): the CUDA tensor type is used for *any* non-CPU
                # device; presumably only CUDA devices are expected here — confirm.
                algorithm_tensor_type = "torch.cuda.FloatTensor"
            self._set_default_tensor_type(algorithm_tensor_type)
            # Yield outside of any warnings.catch_warnings() block: holding one
            # open across the algorithm body would revert every warning filter
            # the body installs when the context exits.
            yield
        finally:
            if use_custom_device:
                model.move_to_device(self._default_algorithm_device)
                dataset.move_to_device(self._default_algorithm_device)
            self._set_default_tensor_type(self._default_algorithm_tensor_type)