Source code for neural_mi.results

# neural_mi/results.py
"""Defines the `Results` class for storing and interacting with analysis outcomes.

This module provides a standardized data structure for holding the results of
different analysis modes from the `run` function. The `Results` class acts as
a container for MI estimates, dataframes, and detailed metadata, and also
provides a convenient `.plot()` method for visualizing the results.
"""
from dataclasses import dataclass, field
from typing import Optional, Any, Dict
import pandas as pd
import matplotlib.pyplot as plt
from neural_mi.logger import logger

[docs] @dataclass class Results: """A data class to store and interact with analysis results. This class provides a structured way to access the outputs of the `run` function. Depending on the analysis `mode`, different attributes will be populated. Attributes ---------- mode : str The analysis mode that was run (e.g., 'estimate', 'sweep'). params : Dict[str, Any] A dictionary of the parameters used for the analysis run. mi_estimate : float, optional The final point estimate of mutual information. Populated in 'estimate' and 'rigorous' modes. dataframe : pd.DataFrame, optional A DataFrame containing detailed results. Populated in 'sweep', 'dimensionality', and 'rigorous' modes. details : Dict[str, Any] A dictionary containing additional metadata or detailed results, such as raw run data or estimated latent dimensions. """ mode: str params: Dict[str, Any] = field(default_factory=dict) mi_estimate: Optional[float] = None dataframe: Optional[pd.DataFrame] = None details: Dict[str, Any] = field(default_factory=dict) def __repr__(self) -> str: """Provides a concise representation of the Results object.""" rep = f"Results(mode='{self.mode}'" if self.mi_estimate is not None: rep += f", mi_estimate={self.mi_estimate:.4f}" if self.dataframe is not None: rep += f", dataframe_shape={self.dataframe.shape}" if self.details: rep += f", details_keys={list(self.details.keys())}" return rep + ")"
[docs] def plot(self, ax: Optional[plt.Axes] = None, **kwargs) -> plt.Axes: """Visualizes the results of the analysis. This method dispatches to the appropriate plotting function based on the analysis `mode`. - For 'sweep' and 'dimensionality' modes, it plots the MI estimate against the swept hyperparameter. - For 'rigorous' mode, it plots the bias correction fit. Parameters ---------- ax : plt.Axes, optional A matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None. **kwargs : dict Additional keyword arguments passed to the underlying plotting function (e.g., `figsize`, `show`, `title`). Returns ------- plt.Axes The matplotlib Axes object containing the plot. Raises ------ ValueError If the Results object does not contain the necessary data (e.g., a DataFrame) to create the plot for the given mode. NotImplementedError If plotting is not supported for the analysis mode. """ from neural_mi.visualize.plot import plot_sweep_curve, plot_bias_correction_fit show = kwargs.pop('show', True) if ax is None: fig, ax = plt.subplots(1, 1, figsize=kwargs.pop('figsize', (10, 6))) if self.mode in ['sweep', 'dimensionality', 'lag']: if self.dataframe is None: raise ValueError("Cannot plot: results do not contain a DataFrame.") sweep_var = self.params.get('sweep_var', 'embedding_dim' if self.mode == 'dimensionality' else None) if not sweep_var: possible = [c for c in self.dataframe.columns if c not in ['mi_mean', 'mi_std']] if len(possible) == 1: sweep_var = possible[0] logger.warning(f"Inferring sweep_var='{sweep_var}' from DataFrame.") else: raise ValueError(f"Cannot determine sweep variable from {possible}.") plot_sweep_curve(self.dataframe, param_col=sweep_var, ax=ax, **kwargs) elif self.mode == 'rigorous': if self.dataframe is None or not self.details: raise ValueError("Rigorous results are incomplete and cannot be plotted.") plot_bias_correction_fit(self.dataframe, self.details, ax=ax, **kwargs) else: raise NotImplementedError(f"Plotting is not implemented for mode: '{self.mode}'") if show: plt.tight_layout() plt.show() return ax