Source code for leaspy.algo.fit.fit_output_manager

import re
import time
from pathlib import Path
from typing import Iterable, Optional

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import colormaps
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.lines import Line2D

from leaspy.io.data import Dataset
from leaspy.models import McmcSaemCompatibleModel

from .base import FitAlgorithm


class FitOutputManager:
    """
    Class used by :class:`~leaspy.algo.AbstractAlgo` (and its child classes) to display & save
    plots and statistics during algorithm execution.

    Parameters
    ----------
    outputs : :class:`~leaspy.algo.OutputsSettings`
        Initialize the `FitOutputManager` class attributes, like the logs paths,
        the console print periodicity and so forth.

    Attributes
    ----------
    path_output : :class:`~pathlib.Path` or None
        Path of the folder containing all the outputs.
        ``None`` when ``outputs.root_path`` is ``None``, in which case :meth:`iteration`
        returns without doing anything.
    path_plot : :class:`~pathlib.Path` or None
        Path of the subfolder of path_output containing the logs plots
    path_plot_convergence_model_parameters : :class:`~pathlib.Path` or None
        Path of the PDF holding the convergence plots of the model's parameters
        (``convergence_parameters.pdf`` in the subfolder path_plot)
    path_plot_patients : :class:`~pathlib.Path` or None
        Path of the subfolder of path_plot containing the plot of the reconstruction
        of the patients' longitudinal trajectory by the model
    nb_of_patients_to_plot : :obj:`int`
        Number of patients for whom the reconstructions will be plotted.
    path_save_model_parameters_convergence : :class:`~pathlib.Path` or None
        Path of the subfolder of path_output containing the progression of the
        model's parameters convergence
    periodicity_plot : :obj:`int` (default 100)
        Set the frequency of the display of the plots
    periodicity_print : :obj:`int`
        Set the frequency of the display of the statistics
    periodicity_save : :obj:`int`
        Set the frequency of the saves of the model's parameters
    periodicity_plot_patients : :obj:`int`
        Set the frequency of the saves of the patients' reconstructions
    plot_sourcewise : :obj:`bool`
        If True, plots will be generated for each source separately.
    time : :obj:`float`
        Timestamp (as returned by :func:`time.time`) of the last call to :meth:`print_time`
        (or of the instantiation if it was never called).
    """

    def __init__(self, outputs):
        self.periodicity_print = outputs.print_periodicity
        self.periodicity_save = outputs.save_periodicity
        self.periodicity_plot = outputs.plot_periodicity
        self.nb_of_patients_to_plot = outputs.nb_of_patients_to_plot
        self.periodicity_plot_patients = outputs.plot_patient_periodicity
        self.plot_sourcewise = outputs.plot_sourcewise

        # Always define the path attributes: `iteration` tests `self.path_output is None`,
        # which would raise an AttributeError if they only existed when a root path is given.
        self.path_output: Optional[Path] = None
        self.path_plot: Optional[Path] = None
        self.path_plot_patients: Optional[Path] = None
        self.path_save_model_parameters_convergence: Optional[Path] = None
        self.path_plot_convergence_model_parameters: Optional[Path] = None

        if outputs.root_path is not None:
            self.path_output = Path(outputs.root_path)
            self.path_plot = Path(outputs.plot_path)
            self.path_plot_patients = Path(outputs.patients_plot_path)
            self.path_save_model_parameters_convergence = Path(
                outputs.parameter_convergence_path
            )
            self.path_plot_convergence_model_parameters = (
                self.path_plot / "convergence_parameters.pdf"
            )

        self.time = time.time()

    @staticmethod
    def _is_due(iteration: int, periodicity: Optional[int]) -> bool:
        """Tell whether a periodic action must run at this iteration (always at iteration 0)."""
        if periodicity is None:
            return False
        return iteration == 0 or iteration % periodicity == 0

    @staticmethod
    def _natural_sort_key(file_path: Path) -> list:
        """Sort key ordering file names by text, with embedded numbers compared as integers."""
        # `re.split` with a capturing group alternates text / digits chunks, starting with text:
        # odd positions are always digits, so keys stay comparable position by position.
        return [
            int(chunk) if position % 2 else chunk
            for position, chunk in enumerate(re.split(r"(\d+)", file_path.name))
        ]

    def iteration(
        self, algo: FitAlgorithm, model: McmcSaemCompatibleModel, data: Dataset
    ) -> None:
        """
        Call methods to save state of the running computation, display statistics & plots
        if the current iteration is a multiple of `periodicity_print`, `periodicity_plot`
        or `periodicity_save`

        Nothing is done when `algo` has no `current_iteration` attribute or when no
        output folder was configured (`path_output` is None).

        Parameters
        ----------
        algo : :class:`~leaspy.algo.fit.FitAlgo`
            A fitting algorithm.
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model used by the computation.
        data : :class:`.Dataset`
            The data used by the computation
        """
        # <!> only `current_iteration` defined for AbstractFitAlgorithm... TODO -> generalize where possible?
        if not hasattr(algo, "current_iteration"):
            # emit a warning?
            return

        iteration = algo.current_iteration

        if self.path_output is None:
            return

        if self._is_due(iteration, self.periodicity_print):
            self.print_algo_statistics(algo)
            self.print_model_statistics(model)
            self.print_time()

        # Saving comes before the convergence plot, which reads the saved csv files.
        if self._is_due(iteration, self.periodicity_save):
            self.save_model_parameters_convergence(iteration, model)

        if self._is_due(iteration, self.periodicity_plot_patients):
            self.save_plot_patient_reconstructions(iteration, model, data)

        if self._is_due(iteration, self.periodicity_plot):
            self.save_plot_convergence_model_parameters(model)

    def print_time(self):
        """Prints the duration since the last periodic point."""
        current_time = time.time()
        print(f"Duration since last print: {current_time - self.time:.3f}s")
        self.time = current_time

    def print_model_statistics(self, model: McmcSaemCompatibleModel):
        """Prints model's statistics.

        Parameters
        ----------
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model used by the computation
        """
        print(model)

    def print_algo_statistics(self, algo: FitAlgorithm):
        """Prints algorithm's statistics

        Parameters
        ----------
        algo : :class:`~leaspy.algo.fit.FitAlgo`
            A fitting algorithm.
        """
        print(algo)

    def save_model_parameters_convergence(
        self, iteration: int, model: McmcSaemCompatibleModel
    ) -> None:
        """Saves the current state of the model's parameters

        Parameters
        ----------
        iteration : :obj:`int`
            The current iteration.
        model : :class:`~.models.abstract_model.McmcSaemCompatibleModel`
            The model used by the computation
        """
        model.state.save(
            str(self.path_save_model_parameters_convergence),
            iteration=iteration,
        )

    def save_plot_convergence_model_parameters(self, model: McmcSaemCompatibleModel):
        """Saves figures of the model parameters' convergence in multiple pages of a PDF.

        The csv files found in `path_save_model_parameters_convergence` are plotted
        six per page (3 rows x 2 columns).

        Parameters
        ----------
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model used by the computation
        """
        width = 10
        height_per_row = 3.5
        plots_per_page = 6  # 3 rows x 2 columns

        to_skip = {"betas", "sources", "space_shifts", "xi", "tau", "xi_mean"}
        if model.name == "ordinal":
            to_skip.add("deltas")
        params_with_feature_labels = ["g", "v0"]
        params_with_sources = ["mixing_matrix"]
        params_with_events = []
        if model.name == "joint":
            to_skip.add("survival_shifts")
            params_with_sources.append("zeta")
            params_with_events += ["nu", "rho"]

        params_to_plot = list(model.state.tracked_variables - to_skip)
        files_to_plot = self._get_files_related_to_parameters(params_to_plot)
        # To plot related parameters close to each other, we sort the list
        # (numbers compared as integers so that `*_10` comes after `*_9`)
        files_to_plot.sort(key=self._natural_sort_key)

        # If plot sourcewise is true, new sourcewise csv files will be created
        if self.plot_sourcewise and hasattr(model, "source_dimension"):
            files_to_plot = [
                file
                for file in files_to_plot
                if not any(file.name.startswith(param) for param in params_with_sources)
            ]
            files_to_plot.extend(
                self._compute_files_sourcewise(
                    params_with_sources, model.source_dimension
                )
            )

        n_plots = len(files_to_plot)
        with PdfPages(self.path_plot_convergence_model_parameters) as pdf:
            for page in range(0, n_plots, plots_per_page):
                fig, ax = plt.subplots(3, 2, figsize=(width, 3 * height_per_row))
                ax = ax.flatten()
                page_files = files_to_plot[page : page + plots_per_page]
                for i, file_path in enumerate(page_files):
                    parameter_name = file_path.name.split(".csv")[0]
                    parameter_name, index = self._extract_parameter_name_and_index(
                        parameter_name
                    )
                    df_convergence = pd.read_csv(file_path, index_col=0, header=None)
                    ax[i].plot(df_convergence)
                    ax = self._set_title_for_parameter(
                        ax, i, parameter_name, model, index
                    )
                    ax = self._set_legend_for_parameters(
                        ax,
                        i,
                        parameter_name,
                        params_with_feature_labels,
                        params_with_sources,
                        params_with_events,
                        model,
                    )

                fig.tight_layout()
                pdf.savefig(fig)
                plt.close(fig)

    def save_plot_patient_reconstructions(
        self,
        iteration: int,
        model: McmcSaemCompatibleModel,
        data: Dataset,
    ) -> None:
        """
        Saves a figure of real longitudinal values and their reconstructions computed by the model
        for at most `nb_of_patients_to_plot` patients.

        The figure is written to ``plot_patients_{iteration}.pdf`` in `path_plot_patients`.

        Parameters
        ----------
        iteration : :obj:`int`
            The current iteration
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model used by the computation
        data : :class:`~leaspy.io.data.Dataset`
            The dataset used by the computation
        """
        number_of_patient_plot = min(self.nb_of_patients_to_plot, data.n_individuals)
        individual_parameters_dict = {
            variable: model.state.get_tensor_value(variable)
            for variable in model.individual_variables_names
        }
        colors = colormaps["Dark2"](np.linspace(0, 1, number_of_patient_plot + 2))

        fig, ax = plt.subplots(1, 1)
        ax.set_title(f"Feature trajectory for {number_of_patient_plot} patients")
        ax.set_xlabel("Ages")
        ax.set_ylabel("Normalized Feature Value")

        for i in range(number_of_patient_plot):
            times_patient = data.get_times_patient(i).cpu().detach().numpy()
            true_values_patient = data.get_values_patient(i).cpu().detach().numpy()
            ip_patient = {pn: pv[i] for pn, pv in individual_parameters_dict.items()}
            # Same tensor -> numpy conversion as for the two tensors above
            reconstruction_values_patient = (
                model.compute_individual_trajectory(times_patient, ip_patient)
                .squeeze(0)
                .cpu()
                .detach()
                .numpy()
            )
            ax.plot(times_patient, reconstruction_values_patient, c=colors[i])
            ax.plot(
                times_patient,
                true_values_patient,
                c=colors[i],
                linestyle="--",
                marker="o",
            )
            # Label each patient next to the last point of his/her reconstruction
            last_time_point = times_patient[-1]
            last_reconstruction_value = reconstruction_values_patient.flatten()[-1]
            ax.text(
                last_time_point,
                last_reconstruction_value,
                data.indices[i],
                color=colors[i],
            )

        # Mean trajectory drawn between the 10th and 90th percentiles of strictly positive timepoints
        min_time, max_time = np.percentile(
            data.timepoints[data.timepoints > 0.0].cpu().detach().numpy(),
            [10, 90],
        )
        timepoints_np = np.linspace(min_time, max_time, 100)
        model_values_np = model.compute_mean_traj(
            torch.tensor(np.expand_dims(timepoints_np, 0))
        )
        for feature in range(model.dimension):
            ax.plot(
                timepoints_np,
                model_values_np[0, :, feature],
                c="gray",
                linewidth=3,
                alpha=0.3,
            )

        line_rec = Line2D([0], [0], label="Reconstructions", color="black")
        line_real = Line2D(
            [0], [0], label="Real feature values", color="black", linestyle="--"
        )
        line_avg = Line2D([0], [0], label="Global avg. features", color="gray")
        handles, _ = ax.get_legend_handles_labels()
        handles.extend([line_rec, line_real, line_avg])
        ax.legend(handles=handles)

        path_iteration = self.path_plot_patients / f"plot_patients_{iteration}.pdf"
        fig.savefig(path_iteration)
        plt.close(fig)

    def _get_files_related_to_parameters(self, parameters: Iterable[str]) -> list[Path]:
        """Retrieve the list of file paths related to the given parameters.

        A file is related to a parameter when its name starts with the parameter name;
        files are looked up in `path_save_model_parameters_convergence`.

        Parameters
        ----------
        parameters : Iterable[:obj:`str`]
            A list of parameter names.

        Returns
        -------
        list[Path]
            A list of file paths that match the parameter names.
        """
        # Materialized once: it is re-iterated for every file (and may be a one-shot iterable)
        prefixes = tuple(parameters)
        return [
            f
            for f in self.path_save_model_parameters_convergence.iterdir()
            if any(f.name.startswith(param) for param in prefixes)
        ]

    def _extract_parameter_name_and_index(
        self, parameter_name: str
    ) -> tuple[str, Optional[int]]:
        """Extract the parameter name and its corresponding index (if applicable) from the given parameter name.

        Parameters
        ----------
        parameter_name : :obj:`str`
            The name of the parameter to extract information from.

        Returns
        -------
        tuple[:obj:`str`, Optional[:obj:`int`]]
            A tuple where the first element is the parameter name and the second is its index (if applicable).

        Examples
        --------
        >>> _extract_parameter_name_and_index("mixing_matrix_10")
        ('mixing_matrix', 10)
        >>> _extract_parameter_name_and_index("mixing_matrix_1")
        ('mixing_matrix', 1)
        >>> _extract_parameter_name_and_index("mixing_matrix_")
        ('mixing_matrix_', None)
        >>> _extract_parameter_name_and_index("mixing_matrix_10_foo")
        ('mixing_matrix_10_foo', None)
        >>> _extract_parameter_name_and_index("v0")
        ('v0', None)
        """
        # The trailing digit of "v0" is part of the name, not an index
        if parameter_name == "v0":
            return parameter_name, None
        match = re.search(r"^(.*?)(\d+)$", parameter_name)
        if match is None:
            return parameter_name, None
        return match.group(1).strip("_"), int(match.group(2))

    def _set_title_for_parameter(
        self,
        ax: np.ndarray,
        i: int,
        parameter_name: str,
        model: McmcSaemCompatibleModel,
        index: Optional[int],
    ) -> np.ndarray:
        """Set the title for the plot based on the parameter name and the model's features or other information.

        Parameters
        ----------
        ax : array of matplotlib.axes.Axes
            The (flattened) axes of the current page.
        i : :obj:`int`
            The index of the current plot.
        parameter_name : :obj:`str`
            The name of the parameter being plotted.
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model containing features and event information.
        index : :obj:`int` or None
            The index of the specific feature or event, or None if not applicable
            (the plain parameter name is then used for `mixing_matrix` and `zeta`).

        Returns
        -------
        array of matplotlib.axes.Axes
            The axes with the title of the i-th one set.
        """
        if parameter_name == "mixing_matrix" and index is not None:
            title = parameter_name + " " + model.features[index]
        elif parameter_name == "zeta" and index is not None:
            title = parameter_name + " " + "event" + " " + str(index + 1)
        elif parameter_name.startswith("sourcewise"):
            title = (
                parameter_name.replace("sourcewise_", "")
                + " "
                + "source"
                + " "
                + str(index)
            )
        else:
            title = parameter_name
        ax[i].set_title(title)
        return ax

    def _set_legend_for_parameters(
        self,
        ax: np.ndarray,
        i: int,
        parameter_name: str,
        params_with_feature_labels: Iterable[str],
        params_with_sources: Iterable[str],
        params_with_events: Iterable[str],
        model: McmcSaemCompatibleModel,
    ) -> np.ndarray:
        """Set the legend for the plot based on the parameter name and the model's features, sources, or events.

        Parameters
        ----------
        ax : array of matplotlib.axes.Axes
            The (flattened) axes of the current page.
        i : :obj:`int`
            The index of the current plot.
        parameter_name : :obj:`str`
            The name of the parameter being plotted.
        params_with_feature_labels : list[:obj:`str`]
            A list of parameters associated with feature labels.
        params_with_sources : list[:obj:`str`]
            A list of parameters associated with sources.
        params_with_events : list[:obj:`str`]
            A list of parameters associated with events.
        model : :class:`~leaspy.models.McmcSaemCompatibleModel`
            The model containing the necessary information for legends.

        Returns
        -------
        array of matplotlib.axes.Axes
            The axes with the legend of the i-th one set (when applicable).
        """
        has_sources = hasattr(model, "source_dimension")
        has_events = hasattr(model, "nb_events")

        if parameter_name in params_with_feature_labels:
            ax[i].legend(model.features, loc="best")

        if has_sources and parameter_name in params_with_sources:
            sources = [f"Source {source + 1}" for source in range(model.source_dimension)]
            ax[i].legend(sources, loc="best")

        if has_events and parameter_name in params_with_events:
            events = [f"Event {event + 1}" for event in range(model.nb_events)]
            ax[i].legend(events, loc="best")

        if parameter_name.startswith("sourcewise"):
            # Sourcewise files hold one column per feature (mixing matrix) or per event (zeta)
            if "mixing_matrix" in parameter_name:
                ax[i].legend(model.features, loc="best")
            if has_events and "zeta" in parameter_name:
                events = [f"Event {event + 1}" for event in range(model.nb_events)]
                ax[i].legend(events, loc="best")

        return ax

    def _compute_files_sourcewise(
        self, params_with_sources: Iterable[str], num_sources: int
    ) -> list[Path]:
        """This function processes parameters that have sources and generates new csv files for each source.

        Parameters
        ----------
        params_with_sources : list[:obj:`str`]
            A list of parameters that have associated source data.
        num_sources : :obj:`int`
            The number of sources in model

        Returns
        -------
        list[Path]
            The list of newly created sourcewise files.
        """
        new_files = []
        for param_name in params_with_sources:
            related_files = self._get_files_related_to_parameters([param_name])
            if not related_files:
                continue
            # Natural order (`*_2` before `*_10`): the columns of the sourcewise files follow
            # this order and are labelled by position in the legends, so a plain
            # lexicographic sort would mislabel them as soon as there are more than 10 files.
            related_files.sort(key=self._natural_sort_key)
            new_files.extend(
                self._concatenate_and_save_sourcewise_parameters(
                    param_name, related_files, source_idx
                )
                for source_idx in range(num_sources)
            )
        return new_files

    def _concatenate_and_save_sourcewise_parameters(
        self, parameter_name: str, files: Iterable[Path], source_idx: int
    ) -> Path:
        """Concatenates data for the specific source and saves it as a csv file.

        The file is named ``sourcewise_{parameter_name}_{source_idx + 1}.csv`` and is written
        in `path_save_model_parameters_convergence`.

        Parameters
        ----------
        parameter_name : :obj:`str`
            A sourcewise parameter
        files : Iterable[Path]
            A list of file paths containing sourcewise parameter
        source_idx : :obj:`int`
            The index of the source

        Returns
        -------
        Path
            The path to the saved csv file containing data for the specified source
        """
        df = self._concatenate_parameters(files, source_idx)
        file_name = f"sourcewise_{parameter_name}_{source_idx + 1}.csv"
        file_path = self.path_save_model_parameters_convergence / file_name
        df.to_csv(file_path, header=False)
        return file_path

    def _concatenate_parameters(
        self, files: Iterable[Path], source_idx: int
    ) -> pd.DataFrame:
        """Extracts and concatenates a specific column from multiple csv files into a single DataFrame

        Parameters
        ----------
        files : Iterable[Path]
            A list of file paths
        source_idx : :obj:`int`
            The index of the source (position of the column to extract, the first
            column of each file being used as index)

        Returns
        -------
        :obj:`pd.DataFrame`
            A DataFrame with each column representing the selected source's data from multiple files
            (only the index values shared by all the files are kept)
        """
        return pd.concat(
            [
                pd.read_csv(file_path, index_col=0, header=None).iloc[:, source_idx]
                for file_path in files
            ],
            axis=1,
            join="inner",
        )