API Reference

The Core run() Function

The run() function is the main entry point for all analyses in the library. It is a unified interface that orchestrates data processing, model training, and results aggregation based on the specified mode.

neural_mi.run(x_data: ndarray | Tensor | List, y_data: ndarray | Tensor | List | None = None, mode: str = 'estimate', processor_type: str | None = None, processor_params: Dict[str, Any] | None = None, processor_type_x: str | None = None, processor_params_x: Dict[str, Any] | None = None, processor_type_y: str | None = None, processor_params_y: Dict[str, Any] | None = None, base_params: Dict[str, Any] | None = None, sweep_grid: Dict[str, list] | None = None, output_units: str = 'bits', estimator: str = 'infonce', estimator_params: Dict[str, Any] | None = None, custom_critic: Module | None = None, custom_embedding_cls: type | None = None, save_best_model_path: str | None = None, random_seed: int | None = None, verbose: bool = True, device: str | None = None, split_mode: str = 'blocked', train_indices: ndarray | None = None, test_indices: ndarray | None = None, delta_threshold: float = 0.1, min_gamma_points: int = 5, confidence_level: float = 0.68, **analysis_kwargs) -> Results

The unified entry point for all analyses in the NeuralMI library.

This function provides a single, consistent interface for various mutual information estimation workflows. It handles data validation, processing, model training, and results aggregation, returning a standardized Results object that can be easily inspected and plotted.

Parameters:
  • x_data (np.ndarray, torch.Tensor, or list) –

    The data for variable X. The required format depends on the processor type (processor_type or processor_type_x):

    • ‘continuous’ or ‘categorical’: A 2D array, typically of shape (n_channels, n_timepoints). Data of shape (n_timepoints, n_channels) is also supported and will be transposed automatically.

    • ‘spike’: A list of 1D NumPy arrays, where each array contains the spike times of a single channel/neuron.

  • y_data (np.ndarray, torch.Tensor, or list, optional) – The data for variable Y. Required for all modes except ‘dimensionality’. Should have the same format as x_data. Defaults to None.

  • mode ({'estimate', 'sweep', 'dimensionality', 'rigorous', 'lag'}, default='estimate') – The analysis mode to run.

  • processor_type_x ({'continuous', 'spike', 'categorical'}, optional) – The type of processing to apply to raw X data. If None, data is assumed to be pre-processed. Defaults to None.

  • processor_params_x (dict, optional) – Parameters for the X data processor, e.g., {'window_size': 10}. Defaults to None.

  • processor_type_y ({'continuous', 'spike', 'categorical'}, optional) – The type of processing to apply to raw Y data. If None, data is assumed to be pre-processed. Defaults to None.

  • processor_params_y (dict, optional) – Parameters for the Y data processor, e.g., {'window_size': 10}. Defaults to None.

  • base_params (dict, optional) – A dictionary of fixed parameters for the MI estimator’s trainer, applied to every run. Common parameters include n_epochs, learning_rate, batch_size, and embedding_dim. Defaults to None (equivalent to an empty dict).

  • sweep_grid (dict, optional) – A dictionary defining the parameter grid for ‘sweep’ and ‘dimensionality’ modes. Keys are parameter names and values are lists of values to test, e.g., {'embedding_dim': [8, 16, 32]}. Defaults to None.

  • output_units ({'bits', 'nats'}, default='bits') – The units for the final MI estimate.

  • estimator ({'infonce', 'nwj', 'tuba', 'smile', 'js'}, default='infonce') – The MI lower bound to use for estimation.

  • estimator_params (dict, optional) – Additional keyword arguments for the selected estimator function. For example, {'clip': 5.0} for the ‘smile’ estimator. Defaults to None.

  • custom_critic (torch.nn.Module, optional) – A pre-initialized custom critic model. If provided, all internal model building is skipped and architecture-related entries in base_params are ignored. Defaults to None.

  • custom_embedding_cls (type, optional) – A user-defined embedding model class (not an instance) to be used within the library’s standard critic structures. Defaults to None.

  • save_best_model_path (str, optional) – If provided, the file path where the state dictionary of the best-performing trained critic model will be saved. Defaults to None.

  • random_seed (int, optional) – A seed for random, numpy, and torch to ensure reproducibility. Note: Full reproducibility is only guaranteed for n_workers=1. Defaults to None.

  • verbose (bool, default=True) – If True, progress bars and informational logs will be displayed.

  • device (str, optional) – The compute device to use (e.g., ‘cpu’, ‘cuda’, ‘mps’). If None, it is auto-detected. Defaults to None.

  • split_mode ({'blocked', 'random'}, default='blocked') – Method for splitting data. ‘blocked’ is for time-series, ‘random’ for IID. Ignored if train/test indices are provided.

  • train_indices (np.ndarray, optional) – Explicit indices for the training set. If provided together with test_indices, split_mode is ignored. Defaults to None.

  • test_indices (np.ndarray, optional) – Explicit indices for the test set. Defaults to None.

  • delta_threshold (float, default=0.1) – For mode='rigorous', the curvature threshold for determining the linear region of the MI vs. 1/gamma plot. Lower values enforce stricter linearity.

  • min_gamma_points (int, default=5) – For mode='rigorous', the minimum number of gamma values required to perform a reliable extrapolation fit after pruning non-linear points.

  • confidence_level (float, default=0.68) – For mode='rigorous', the confidence level (e.g., 0.68 for ~1 std dev) used for the final MI estimate’s error bars.

  • **analysis_kwargs – Additional keyword arguments passed to the specific analysis engine. Common examples include n_workers, n_splits, or gamma_range. For mode='lag', this must include lag_range.
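The difference between the two split_mode options can be illustrated with a small NumPy sketch. The helper split_indices and its train_fraction argument are hypothetical and do not reflect the library's internals. The sketch shows only the simplest form of each strategy: a blocked split keeps contiguous time ranges together, so correlated neighbouring samples rarely end up on both sides of the split, while a random split shuffles samples, which is only appropriate for IID data.

```python
import numpy as np


def split_indices(n_samples, train_fraction=0.8, split_mode="blocked", seed=None):
    """Hypothetical helper: return (train_idx, test_idx) for n_samples points."""
    n_train = int(round(train_fraction * n_samples))
    if split_mode == "blocked":
        # Contiguous blocks: the test set is the final stretch of the series.
        idx = np.arange(n_samples)
    elif split_mode == "random":
        # Shuffled split: only appropriate when samples are IID.
        idx = np.random.default_rng(seed).permutation(n_samples)
    else:
        raise ValueError(f"unknown split_mode: {split_mode!r}")
    return idx[:n_train], idx[n_train:]


train_idx, test_idx = split_indices(10, split_mode="blocked")
print(train_idx, test_idx)  # [0 1 2 3 4 5 6 7] [8 9]
```

As noted for split_mode, explicitly supplied train_indices and test_indices take precedence over either strategy.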

Returns:

A standardized object containing the analysis results, which can be inspected as a dataframe or plotted directly via its .plot() method.

Return type:

neural_mi.results.Results

Examples

Perform a rigorous, bias-corrected MI estimation between two variables.

>>> import neural_mi as nmi
>>> import numpy as np
>>> # Generate synthetic data
>>> x_raw, y_raw = nmi.datasets.generate_nonlinear_from_latent(
...     n_samples=2500, latent_dim=10, observed_dim=100, mi=3.0
... )
>>> # Define model and training parameters
>>> base_params = {
...     'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
...     'embedding_dim': 16, 'hidden_dim': 64
... }
>>> # Run the analysis
>>> results = nmi.run(
...     x_data=x_raw.T, y_data=y_raw.T,
...     mode='rigorous',
...     processor_type_x='continuous',
...     processor_params_x={'window_size': 1},
...     base_params=base_params,
...     n_workers=4,
...     random_seed=42
... )
>>> mi_est = results.mi_estimate
>>> mi_err = results.details.get('mi_error', 0.0)
>>> print(f"Corrected MI: {mi_est:.3f} ± {mi_err:.3f} bits")
Corrected MI: 2.953 ± 0.081 bits

The Results Object

All calls to run() return a Results object. This object acts as a container for all the outputs of an analysis, providing convenient access to the final MI estimate, the raw data, and a built-in plotting method.

class neural_mi.results.Results(mode: str, params: Dict[str, Any] = <factory>, mi_estimate: float | None = None, dataframe: pandas.DataFrame | None = None, details: Dict[str, Any] = <factory>)

Bases: object

A data class to store and interact with analysis results.

This class provides a structured way to access the outputs of the run function. Depending on the analysis mode, different attributes will be populated.

mode

The analysis mode that was run (e.g., ‘estimate’, ‘sweep’).

Type:

str

params

A dictionary of the parameters used for the analysis run.

Type:

Dict[str, Any]

mi_estimate

The final point estimate of mutual information. Populated in ‘estimate’ and ‘rigorous’ modes.

Type:

float, optional

dataframe

A DataFrame containing detailed results. Populated in ‘sweep’, ‘dimensionality’, and ‘rigorous’ modes.

Type:

pd.DataFrame, optional

details

A dictionary containing additional metadata or detailed results, such as raw run data or estimated latent dimensions.

Type:

Dict[str, Any]
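The <factory> defaults in the class signature indicate dataclass fields created with a default factory. These are presumably field(default_factory=dict), so each Results instance gets its own params and details dictionaries. The stand-in below, ResultsSketch, is hypothetical and simplified (it has no plotting). It mirrors the documented fields to show the pattern.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ResultsSketch:
    """Hypothetical, simplified mirror of the documented Results fields."""
    mode: str
    params: Dict[str, Any] = field(default_factory=dict)
    mi_estimate: Optional[float] = None
    dataframe: Optional[Any] = None  # a pandas DataFrame in the real class
    details: Dict[str, Any] = field(default_factory=dict)


a = ResultsSketch(mode="estimate", mi_estimate=1.25)
b = ResultsSketch(mode="sweep")
a.details["mi_error"] = 0.1
print(b.details)  # {} because each instance owns a separate dict
```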

plot(ax: Axes | None = None, **kwargs) -> Axes

Visualizes the results of the analysis.

This method dispatches to the appropriate plotting function based on the analysis mode.

  • For ‘sweep’ and ‘dimensionality’ modes, it plots the MI estimate against the swept hyperparameter.

  • For ‘rigorous’ mode, it plots the bias correction fit.

Parameters:
  • ax (plt.Axes, optional) – A matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None.

  • **kwargs (dict) – Additional keyword arguments passed to the underlying plotting function (e.g., figsize, show, title).

Returns:

The matplotlib Axes object containing the plot.

Return type:

plt.Axes

Raises:
  • ValueError – If the Results object does not contain the necessary data (e.g., a DataFrame) to create the plot for the given mode.

  • NotImplementedError – If plotting is not supported for the analysis mode.

Data Generation (datasets)

This module provides functions to generate synthetic datasets with known properties. These are useful for testing estimators, validating models, and following the tutorials.

neural_mi.datasets.generate_correlated_gaussians(n_samples: int, dim: int, mi: float, use_torch: bool = True) -> Tuple[ndarray, ndarray]

Generates two correlated multivariate Gaussian datasets.

The ground truth mutual information between these two variables can be calculated analytically.

Parameters:
  • n_samples (int) – The number of samples to generate.

  • dim (int) – The number of dimensions for each variable.

  • mi (float) – The ground truth mutual information in bits.

  • use_torch (bool, optional) – If True, returns torch.Tensors; otherwise, returns NumPy arrays. Defaults to True.

Returns:

A tuple containing:

  • x (np.ndarray or torch.Tensor) – The first dataset, of shape (n_samples, dim).

  • y (np.ndarray or torch.Tensor) – The second dataset, of shape (n_samples, dim).

Return type:

Tuple[np.ndarray, np.ndarray], or Tuple[torch.Tensor, torch.Tensor] when use_torch=True
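For jointly Gaussian variables, the analytic MI has a closed form. The sketch below assumes one common construction. Each of the dim coordinate pairs (x_i, y_i) is standard normal with correlation rho, and the pairs are independent of one another. In that case I(X; Y) = -(dim/2) log2(1 - rho^2) bits, and the formula can be inverted to find the rho that gives a target mi. It is an assumption that generate_correlated_gaussians uses exactly this construction, and the helper names are hypothetical.

```python
import numpy as np


def mi_bits(rho, dim):
    """Analytic MI (bits) for dim independent coordinate pairs with correlation rho."""
    return -0.5 * dim * np.log2(1.0 - rho**2)


def rho_for_mi(mi, dim):
    """Invert mi_bits: per-coordinate correlation that yields a target MI in bits."""
    return np.sqrt(1.0 - 2.0 ** (-2.0 * mi / dim))


def sample_correlated_gaussians(n_samples, dim, mi, seed=None):
    """Hypothetical sampler for the construction described above."""
    rng = np.random.default_rng(seed)
    rho = rho_for_mi(mi, dim)
    x = rng.standard_normal((n_samples, dim))
    noise = rng.standard_normal((n_samples, dim))
    # y_i = rho * x_i + sqrt(1 - rho^2) * noise_i has unit variance and corr(x_i, y_i) = rho.
    y = rho * x + np.sqrt(1.0 - rho**2) * noise
    return x, y


x, y = sample_correlated_gaussians(20000, dim=4, mi=2.0, seed=0)
print(x.shape, y.shape)  # (20000, 4) (20000, 4)
print(round(float(mi_bits(rho_for_mi(2.0, 4), 4)), 6))  # 2.0
```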

neural_mi.datasets.generate_nonlinear_from_latent(n_samples: int, latent_dim: int, observed_dim: int, mi: float, hidden_dim: int = 64, use_torch: bool = True) -> Tuple[ndarray, ndarray]

Generates two nonlinearly related datasets from a shared latent variable.

A low-dimensional latent variable z is first generated. Two observed variables, x and y, are then created as nonlinear projections of z with added noise.

Parameters:
  • n_samples (int) – The number of samples to generate.

  • latent_dim (int) – The dimensionality of the shared latent variable z.

  • observed_dim (int) – The dimensionality of the observed variables x and y.

  • mi (float) – The ground truth MI between the latent variables Z_x and Z_y in bits.

  • hidden_dim (int, optional) – The hidden dimension of the transforming MLPs. Defaults to 64.

  • use_torch (bool, optional) – If True, returns torch.Tensors. Defaults to True.

Returns:

A tuple containing:

  • x (np.ndarray or torch.Tensor) – The first dataset, of shape (n_samples, observed_dim).

  • y (np.ndarray or torch.Tensor) – The second dataset, of shape (n_samples, observed_dim).

Return type:

Tuple[np.ndarray, np.ndarray], or Tuple[torch.Tensor, torch.Tensor] when use_torch=True

neural_mi.datasets.generate_temporally_convolved_data(n_samples, lag=30, noise=0.1, use_torch=True)

Generates data where Y is a simple time-delayed version of X.

This creates a clean, unambiguous temporal relationship ideal for testing the windowing functionality of the MI estimator.

Parameters:
  • n_samples (int) – The number of time points to generate.

  • lag (int) – The number of timepoints by which Y is delayed relative to X. Defaults to 30.

  • noise (float) – The amount of Gaussian noise to add to Y. Defaults to 0.1.

  • use_torch (bool) – If True, returns torch.Tensors. Defaults to True.

Returns:

A tuple (x, y) of the generated temporal data, each of shape (1, n_samples) for compatibility with the library’s processors.

Return type:

tuple
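Below is a minimal NumPy sketch of the kind of signal described above. It assumes X is white Gaussian noise and Y is X shifted by lag samples plus Gaussian noise scaled by noise. The helper delayed_copy and its seed argument are hypothetical. The real function's choice of X, and how it fills the first lag samples of Y, may differ.

```python
import numpy as np


def delayed_copy(n_samples, lag=30, noise=0.1, seed=None):
    """Hypothetical sketch: Y(t) = X(t - lag) + noise, each returned with shape (1, n_samples)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)
    y = np.zeros(n_samples)
    # The first `lag` samples of Y have no corresponding past X; here they stay zero.
    y[lag:] = x[: n_samples - lag]
    y += noise * rng.standard_normal(n_samples)
    return x[np.newaxis, :], y[np.newaxis, :]


x, y = delayed_copy(1000, lag=30, seed=0)
print(x.shape, y.shape)  # (1, 1000) (1, 1000)
```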

neural_mi.datasets.generate_xor_data(n_samples: int, noise: float = 0.1, use_torch: bool = True) -> Tuple[ndarray, ndarray]

Generates data for the XOR task, a classic test for synergy.

In the XOR problem, the mutual information between the joint variable (x1, x2) and y is high, while the MI between y and either x1 or x2 alone is zero.

Parameters:
  • n_samples (int) – Number of samples.

  • noise (float, optional) – Noise to add to the continuous Y variable. Defaults to 0.1.

  • use_torch (bool, optional) – If True, returns torch.Tensors. Defaults to True.

Returns:

A tuple (x, y) where x is (n_samples, 2) and y is (n_samples, 1).

Return type:

Tuple[np.ndarray, np.ndarray], or Tuple[torch.Tensor, torch.Tensor] when use_torch=True
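The synergy claim can be checked exactly in the noiseless, discrete case. The sketch below assumes x1 and x2 are independent fair bits and y = x1 XOR x2. It computes plug-in MI over the four equally likely outcomes. The library's Y is continuous and noisy, so MI estimates on its data will differ somewhat from these exact values.

```python
from collections import Counter
from itertools import product
from math import log2


def entropy(samples):
    """Plug-in Shannon entropy (bits) of a list of hashable outcomes."""
    counts = Counter(samples)
    n = len(samples)
    return -sum(c / n * log2(c / n) for c in counts.values())


def mutual_info(a, b):
    """I(A; B) = H(A) + H(B) - H(A, B), in bits."""
    return entropy(a) + entropy(b) - entropy(list(zip(a, b)))


# All four equally likely input pairs and their XOR outputs.
x1, x2 = zip(*product([0, 1], repeat=2))
y = [u ^ v for u, v in zip(x1, x2)]

print(mutual_info(list(x1), y))           # 0.0
print(mutual_info(list(x2), y))           # 0.0
print(mutual_info(list(zip(x1, x2)), y))  # 1.0
```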

Generates data with a sparse event signal (X) and a delayed response (Y).

This dataset is designed to test lag detection. Signal X contains a few sharp spikes (“events”). Signal Y is zero everywhere except for a stereotyped response (a sine wave) that begins lag time steps after each event in X.

This structure prevents the model from using global statistical features and forces it to learn the precise local temporal relationship.

Parameters:
  • n_samples (int) – The number of time points to generate.

  • lag (int) – The number of timepoints to delay Y’s response relative to X.

  • n_events (int) – The number of sparse events in signal X.

  • response_length (int) – The duration of the sine wave response in Y.

  • noise (float) – The amount of Gaussian noise to add to Y.

  • use_torch (bool) – If True, returns torch.Tensors.

Returns:

A tuple (x, y) of the generated temporal data, each of shape (1, n_samples).

Return type:

tuple
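Below is a hedged NumPy sketch of the event/response construction described above. Only the overall structure comes from the description: sparse spikes in X, and a stereotyped sine response in Y starting lag steps after each spike. The following details are all assumptions: the helper name, its default values, random event placement, unit spike height, and a single-period sine response. For inspection, the sketch also returns the event indices.

```python
import numpy as np


def sparse_event_response(n_samples, lag=25, n_events=10, response_length=20,
                          noise=0.05, seed=None):
    """Hypothetical sketch returning x, y of shape (1, n_samples) and the event times."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n_samples)
    y = np.zeros(n_samples)
    # Only place events where the full delayed response still fits in the series.
    last_start = n_samples - lag - response_length
    events = rng.choice(last_start, size=n_events, replace=False)
    x[events] = 1.0
    # One period of a sine wave as the stereotyped response.
    response = np.sin(2 * np.pi * np.arange(response_length) / response_length)
    for t in events:
        y[t + lag : t + lag + response_length] += response
    y += noise * rng.standard_normal(n_samples)
    return x[np.newaxis, :], y[np.newaxis, :], events


x, y, events = sparse_event_response(2000, seed=1)
```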

neural_mi.datasets.generate_linear_data(n_samples=5000, true_lag=50, noise_level=0.5)

Y(t) is a simple, lagged version of X(t).

neural_mi.datasets.generate_nonlinear_data(n_samples=5000, true_lag=50, noise_level=0.2)

Y(t) is a nonlinear function of lagged X(t).

neural_mi.datasets.generate_history_data(n_samples=5000, history_duration=20, noise_level=0.1)

Y(t) is a nonlinear function of the moving average of X over a recent window (no lag).

neural_mi.datasets.generate_full_data(n_samples=5000, true_lag=30, history_duration=20, noise_level=0.3)

Y(t) is a nonlinear function of the moving average of X over a past, lagged window.
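The four one-line descriptions above can be made concrete with a single NumPy sketch. The following choices are assumptions made for illustration:

  • white-noise X
  • tanh as the nonlinearity
  • a causal moving average
  • one shared noise_level for all four variants

The library's generators may use different signals, functions, and per-function defaults. The helper names are hypothetical.

```python
import numpy as np


def lagged(x, lag):
    """Shift x forward by lag steps, filling the start with zeros."""
    out = np.zeros_like(x)
    out[lag:] = x[: len(x) - lag]
    return out


def moving_average(x, window):
    """Causal mean of x over the last `window` samples (including the current one)."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="full")[: len(x)]


def make_variants(n_samples=5000, true_lag=50, history_duration=20, noise_level=0.2, seed=None):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)

    def noise():
        return noise_level * rng.standard_normal(n_samples)

    avg = moving_average(x, history_duration)
    ys = {
        "linear": lagged(x, true_lag) + noise(),                  # Y(t) = X(t - lag)
        "nonlinear": np.tanh(2 * lagged(x, true_lag)) + noise(),  # nonlinear function of lagged X
        "history": np.tanh(2 * avg) + noise(),                    # recent window, no lag
        "full": np.tanh(2 * lagged(avg, true_lag)) + noise(),     # past, lagged window
    }
    return x, ys


x, ys = make_variants(seed=0)
```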

Visualization (visualize)

This module contains helper functions for creating publication-quality plots of analysis results. These functions are typically called automatically by the Results.plot() method but can also be used directly.

neural_mi.visualize.plot_sweep_curve(summary_df: DataFrame, param_col: str, mean_col: str = 'mi_mean', std_col: str = 'mi_std', true_value: float | None = None, estimated_values: Any | None = None, ax: Axes | None = None, units: str = 'bits', **kwargs)

Plots the results of a hyperparameter sweep.

This function creates a curve of the mean MI estimate against the values of the swept hyperparameter, with a shaded region representing the standard deviation. It can also display true and estimated values as vertical lines for comparison.

Parameters:
  • summary_df (pd.DataFrame) – A DataFrame containing the summarized results of the sweep. Must contain columns for the parameter, mean MI, and std dev of MI.

  • param_col (str) – The name of the column in summary_df that contains the swept hyperparameter values.

  • mean_col (str, optional) – The name of the column for the mean MI estimate. Defaults to ‘mi_mean’.

  • std_col (str, optional) – The name of the column for the standard deviation of the MI estimate. Defaults to ‘mi_std’.

  • true_value (float, optional) – If known, the true value of the parameter, to be plotted as a vertical dashed line. Defaults to None.

  • estimated_values (Any, optional) – An estimated value or a dictionary of estimated values to plot as vertical dotted lines. Defaults to None.

  • ax (plt.Axes, optional) – A matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None.

  • units (str, optional) – The units of the MI estimate (e.g., ‘bits’ or ‘nats’) for axis labels. Defaults to ‘bits’.

  • **kwargs (dict) – Additional keyword arguments passed to ax.plot.
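plot_sweep_curve expects one row per swept value, with mean and standard-deviation columns. The sketch below assumes the raw sweep output is a list of per-run records; this record layout is hypothetical, and the MI numbers are made up for illustration rather than real results. Under that assumption, the summary can be built with the standard library. Passing the resulting rows to pandas.DataFrame would give a frame with the default mi_mean and mi_std column names.

```python
from collections import defaultdict
from statistics import mean, stdev

# Hypothetical raw sweep output with illustrative (not real) MI values.
raw_runs = [
    {"embedding_dim": 8, "mi": 1.9}, {"embedding_dim": 8, "mi": 2.1},
    {"embedding_dim": 16, "mi": 2.8}, {"embedding_dim": 16, "mi": 3.0},
    {"embedding_dim": 32, "mi": 2.9}, {"embedding_dim": 32, "mi": 3.1},
]


def summarize(runs, param_col, value_col="mi"):
    """Group runs by the swept parameter and compute the mean and std of MI."""
    groups = defaultdict(list)
    for run in runs:
        groups[run[param_col]].append(run[value_col])
    return [
        {param_col: p, "mi_mean": mean(v), "mi_std": stdev(v) if len(v) > 1 else 0.0}
        for p, v in sorted(groups.items())
    ]


rows = summarize(raw_runs, "embedding_dim")
```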

neural_mi.visualize.plot_bias_correction_fit(raw_results_df: DataFrame, corrected_result: Dict[str, Any], ax: Axes | None = None, units: str = 'bits', **kwargs)

Plots the results of a rigorous, bias-corrected analysis.

This function visualizes the extrapolation fit used for bias correction. It shows the raw MI estimates for each data subset size (gamma), the mean MI at each gamma, and the final linear fit extrapolated to an infinite dataset size (gamma=0).

Parameters:
  • raw_results_df (pd.DataFrame) – A DataFrame containing the raw results from all training runs in the rigorous analysis. Must contain ‘gamma’ and ‘test_mi’ columns.

  • corrected_result (Dict[str, Any]) – A dictionary containing the results of the bias correction, including the ‘slope’, ‘mi_corrected’, ‘mi_error’, and ‘gammas_used’.

  • ax (plt.Axes, optional) – A matplotlib Axes object to plot on. If None, a new figure and axes are created. Defaults to None.

  • units (str, optional) – The units of the MI estimate (e.g., ‘bits’ or ‘nats’) for labels. Defaults to ‘bits’.
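The core of the extrapolation can be sketched as an ordinary least-squares line through the mean MI at each gamma, evaluated at gamma = 0. The sketch attaches a symmetric error bar from the intercept's standard error, scaled by confidence_level, and reuses the key names listed for corrected_result. It is a simplified stand-in, not the library's procedure, and rests on these assumptions:

  • all supplied gammas already lie in the linear regime; the delta_threshold pruning and min_gamma_points check are not reproduced
  • a normal quantile is used rather than a t quantile

The function name is hypothetical.

```python
import numpy as np
from statistics import NormalDist


def extrapolate_to_infinite_data(gammas, mean_mi, confidence_level=0.68):
    """Fit mean_mi ~ intercept + slope * gamma and report the gamma = 0 intercept."""
    g = np.asarray(gammas, dtype=float)
    m = np.asarray(mean_mi, dtype=float)
    design = np.column_stack([np.ones_like(g), g])
    coef, *_ = np.linalg.lstsq(design, m, rcond=None)
    intercept, slope = coef
    # Residual variance with two fitted parameters, then the intercept's standard error.
    resid = m - design @ coef
    dof = max(len(g) - 2, 1)
    sigma2 = resid @ resid / dof
    cov = sigma2 * np.linalg.inv(design.T @ design)
    se_intercept = np.sqrt(cov[0, 0])
    # Two-sided normal quantile, e.g. about 0.99 for confidence_level = 0.68.
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2)
    return {
        "mi_corrected": float(intercept),
        "slope": float(slope),
        "mi_error": float(z * se_intercept),
        "gammas_used": g.tolist(),
    }


# Illustrative, made-up means that fall exactly on a line.
result = extrapolate_to_infinite_data([1, 2, 3, 4, 5], [3.0, 2.9, 2.8, 2.7, 2.6])
```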

neural_mi.visualize.plot_cross_correlation(x, y, true_lag)

Plots the cross-correlation between signals x and y for comparison with the known true_lag.
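plot_cross_correlation is only described briefly, so here is a hedged NumPy sketch of the quantity such a plot typically shows. It computes the Pearson correlation between X(t) and Y(t + k) over non-negative lags k. For a lagged linear relationship like the one produced by generate_linear_data, the peak should sit at the true lag. The helper lagged_correlation and its max_lag parameter are hypothetical.

```python
import numpy as np


def lagged_correlation(x, y, max_lag):
    """Pearson correlation between x[t] and y[t + k] for k = 0..max_lag."""
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    lags = np.arange(max_lag + 1)
    corr = np.array([np.corrcoef(x[: len(x) - k], y[k:])[0, 1] for k in lags])
    return lags, corr


rng = np.random.default_rng(0)
x = rng.standard_normal(5000)
# Y is X delayed by 50 samples plus noise.
y = np.concatenate([np.zeros(50), x[:-50]]) + 0.5 * rng.standard_normal(5000)
lags, corr = lagged_correlation(x, y, max_lag=100)
print(int(lags[np.argmax(corr)]))  # 50
```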

neural_mi.visualize.analyze_mi_heatmap(results_df, absolute_mi_threshold=0.2, contour_rise_fraction=0.1, radius_multiplier=1.2, true_lag=None, history_duration=None)

Performs a topological analysis of a 2D MI heatmap over lag and window size. It:

  • Finds the Causal Contour.

  • Finds the shortest “bridge” between the Causal and Significant MI Contours.

  • Draws a “Parsimonious Circle” centered on this bridge to highlight the optimal region.

Parameters:
  • results_df (pd.DataFrame) – DataFrame with ‘lag’, ‘window_size’, and ‘mi’ columns.

  • absolute_mi_threshold (float) – The absolute MI value to consider “significant”.

  • contour_rise_fraction (float) – Heuristic for finding the Causal Contour rise point.

  • radius_multiplier (float) – Factor to scale the Parsimonious Circle’s radius.

  • true_lag (float, optional) – The true lag value to mark on the plot.

  • history_duration (float, optional) – The true history/window duration to mark on the plot.

Configuration

Use this function to control the library’s logging output level.

neural_mi.logger.set_verbosity(level: int | str)

Sets the global verbosity level for the library’s logger.

This function provides a simple way to control the logging output of the entire library.

Parameters:

level (int or str) – The desired verbosity level. Can be an integer from 0 (CRITICAL) to 4 (DEBUG), or a string (‘CRITICAL’, ‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’).
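The documented correspondence (0 for CRITICAL through 4 for DEBUG) maps naturally onto Python's standard logging levels. The minimal sketch below rests on two assumptions: the intermediate integers follow the order of the listed string names, and the library logger is named 'neural_mi'. The helper set_verbosity_sketch is hypothetical and is not the library's implementation.

```python
import logging

_LEVELS = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]


def set_verbosity_sketch(level, logger_name="neural_mi"):
    """Hypothetical re-implementation: accept an integer 0-4 or a level name."""
    if isinstance(level, int):
        if not 0 <= level < len(_LEVELS):
            raise ValueError(f"level must be in 0..4, got {level}")
        name = _LEVELS[level]
    else:
        name = str(level).upper()
        if name not in _LEVELS:
            raise ValueError(f"unknown level name: {level!r}")
    logging.getLogger(logger_name).setLevel(getattr(logging, name))
    return name


set_verbosity_sketch(2)        # WARNING: hides INFO and DEBUG messages
set_verbosity_sketch("DEBUG")  # show everything
```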

Exceptions

These are the custom exceptions raised by the library to signal specific errors.

Using custom exceptions allows for more specific error handling and clearer error messages, making the library easier to debug and use.

exception neural_mi.exceptions.NeuralMIError

Base class for all custom exceptions in the neural_mi library.

exception neural_mi.exceptions.DataShapeError

Exception raised for errors related to the shape of input data.

This is typically raised when an input tensor or array does not have the expected number of dimensions or when dimensions have an incorrect size.

exception neural_mi.exceptions.InsufficientDataError

Exception raised when not enough data is provided for an operation.

This is a subclass of DataShapeError and is used, for example, when the length of a time series is smaller than the required window size for processing.

exception neural_mi.exceptions.TrainingError

Exception raised for critical errors that occur during model training.

This exception is used to signal that the training process has failed and cannot continue, for example, if no valid model checkpoint could be created.
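The hierarchy described above can be summarized in a few lines. This is a sketch mirroring the documentation, not the library's source. The text only states that NeuralMIError is the base of all custom exceptions. It is therefore an assumption that NeuralMIError derives directly from Exception, and that DataShapeError and TrainingError derive directly from NeuralMIError. The check_window function is hypothetical. The example shows why catching DataShapeError also catches InsufficientDataError.

```python
class NeuralMIError(Exception):
    """Base class for all custom exceptions (mirrors the documented hierarchy)."""


class DataShapeError(NeuralMIError):
    """Input data has the wrong number or size of dimensions."""


class InsufficientDataError(DataShapeError):
    """Too little data, e.g. a series shorter than the window size."""


class TrainingError(NeuralMIError):
    """Training failed and cannot continue."""


def check_window(n_timepoints, window_size):
    # Hypothetical validation step of the kind that raises InsufficientDataError.
    if n_timepoints < window_size:
        raise InsufficientDataError(
            f"series of length {n_timepoints} is shorter than window_size={window_size}"
        )


try:
    check_window(n_timepoints=5, window_size=10)
except DataShapeError as err:  # also catches the InsufficientDataError subclass
    print(type(err).__name__)  # InsufficientDataError
```

Because NeuralMIError is the base class of every custom exception, catching neural_mi.exceptions.NeuralMIError around a call to run() handles any of these library-specific failures.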