# neural_mi/visualize/plot.py
"""Provides plotting functions for visualizing analysis results.
This module contains functions to generate plots for different analysis modes,
such as hyperparameter sweeps and bias correction fits. These are typically
called via the `.plot()` method of the `Results` object.
"""
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.signal import correlate
from scipy.stats import zscore
from matplotlib.ticker import MaxNLocator
from typing import Optional, Dict, Any
from scipy.spatial.distance import cdist
from matplotlib.lines import Line2D
def set_publication_style():
    """Apply a publication-quality style to all subsequent matplotlib plots.

    Switches the global rcParams to a serif (Times New Roman) font family
    with Computer Modern math text, and enlarges title, label, tick, and
    legend fonts so figures remain legible in papers and reports.

    Notes
    -----
    Mutates ``plt.rcParams`` in place; the style persists for the rest of
    the session.
    """
    publication_rc = {
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "mathtext.fontset": "cm",
        'figure.dpi': 100,
        'font.size': 16,
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 14,
    }
    plt.rcParams.update(publication_rc)
def plot_sweep_curve(summary_df: pd.DataFrame, param_col: str, mean_col: str = 'mi_mean',
                     std_col: str = 'mi_std', true_value: Optional[float] = None,
                     estimated_values: Optional[Any] = None, ax: Optional[plt.Axes] = None,
                     units: str = 'bits', **kwargs):
    """Plot mean MI against a swept hyperparameter with a ±1 std-dev band.

    Draws the mean MI curve over the swept parameter values, shades one
    standard deviation around it, and optionally marks a known true value
    (dashed line) and one or more estimated values (dotted lines).

    Parameters
    ----------
    summary_df : pd.DataFrame
        Summarized sweep results; must contain `param_col`, `mean_col`, and
        `std_col` columns.
    param_col : str
        Column holding the swept hyperparameter values (x-axis).
    mean_col : str, optional
        Column holding the mean MI estimate. Defaults to 'mi_mean'.
    std_col : str, optional
        Column holding the MI standard deviation. Defaults to 'mi_std'.
    true_value : float, optional
        Known true parameter value, drawn as a vertical dashed line.
    estimated_values : Any, optional
        A single estimate or a dict of named estimates, drawn as vertical
        dotted lines. Defaults to None.
    ax : plt.Axes, optional
        Axes to draw on; a new figure is created when omitted.
    units : str, optional
        Units for the MI axis label (e.g. 'bits' or 'nats').
    **kwargs : dict
        Extra keyword arguments forwarded to `ax.plot`.

    Returns
    -------
    plt.Axes
        The axes the curve was drawn on.
    """
    own_figure = ax is None
    if own_figure:
        _, ax = plt.subplots(1, 1, figsize=(8, 6))

    xs = summary_df[param_col]
    mi_mean = summary_df[mean_col]
    mi_std = summary_df[std_col]

    ax.plot(xs, mi_mean, 'o-', label='Mean MI', **kwargs)
    ax.fill_between(xs, mi_mean - mi_std, mi_mean + mi_std, alpha=0.2, label='±1 Std Dev')

    if true_value is not None:
        ax.axvline(x=true_value, color='r', linestyle='--', label=f'True Value = {true_value}')

    if isinstance(estimated_values, dict):
        # One distinct viridis shade per named estimate.
        palette = plt.cm.viridis(np.linspace(0.3, 0.9, len(estimated_values)))
        for color, (prefix, val) in zip(palette, estimated_values.items()):
            ax.axvline(x=val, color=color, linestyle=':', linewidth=3,
                       label=f'Est. ({prefix}) = {val}')
    elif estimated_values is not None:
        ax.axvline(x=estimated_values, color='g', linestyle=':', linewidth=3,
                   label=f'Estimated = {estimated_values}')

    pretty_name = param_col.replace('_', ' ').title()
    ax.set_xlabel(pretty_name)
    ax.set_ylabel(f"MI ({units})")
    ax.set_title(f"MI vs. {pretty_name}")
    ax.legend()
    ax.grid(True, linestyle=':')
    sns.despine(ax=ax)

    # Force integer ticks when every swept value is a whole number.
    if pd.api.types.is_numeric_dtype(xs) and all(xs == np.floor(xs)):
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    # Only tighten the layout for figures this function created itself.
    if own_figure:
        plt.tight_layout()
    return ax
def plot_bias_correction_fit(raw_results_df: pd.DataFrame, corrected_result: Dict[str, Any],
                             ax: Optional[plt.Axes] = None, units: str = 'bits', **kwargs):
    """Plot the extrapolation fit of a rigorous, bias-corrected analysis.

    Shows the raw MI estimates per subset count (gamma) as a strip plot, the
    per-gamma mean, and the weighted-least-squares line extrapolated to
    gamma=0 (infinite data), where the corrected MI and its error bar are
    drawn.

    Parameters
    ----------
    raw_results_df : pd.DataFrame
        Raw results from all training runs; must contain 'gamma' and
        'test_mi' columns.
    corrected_result : Dict[str, Any]
        Bias-correction results with keys 'slope', 'mi_corrected',
        'gammas_used', and optionally 'mi_error'.
    ax : plt.Axes, optional
        Axes to draw on; a new figure is created when omitted.
    units : str, optional
        Units for the MI axis labels (e.g. 'bits' or 'nats').
    **kwargs : dict
        Accepted for signature symmetry with the other plot helpers; unused.

    Returns
    -------
    plt.Axes
        The axes the fit was drawn on (added for consistency with
        `plot_sweep_curve`).
    """
    own_figure = ax is None
    if own_figure:
        _, ax = plt.subplots(1, 1, figsize=(8, 6))

    # Raw per-run estimates on a *categorical* x-axis: stripplot places gamma
    # value k at categorical position k-1.
    # NOTE(review): every `- 1` offset below assumes the gammas are the
    # consecutive integers 1, 2, 3, ...; verify against the sweep that
    # produced `raw_results_df` — non-consecutive gammas would misalign the
    # overlays with the strip plot.
    sns.stripplot(x='gamma', y='test_mi', data=raw_results_df, ax=ax, color='gray', alpha=0.5)

    agg = raw_results_df.groupby('gamma')['test_mi'].mean().reset_index()
    ax.plot(agg['gamma'] - 1, agg['test_mi'], 'o-', color='black', label='Mean MI per Gamma')

    slope, intercept = corrected_result['slope'], corrected_result['mi_corrected']
    mi_error = corrected_result.get('mi_error', 0)
    gammas_used = corrected_result['gammas_used']

    # Evaluate the fitted line at gamma=0 (the bias-free extrapolation point)
    # plus every gamma that entered the fit.
    fit_x = np.array([0] + gammas_used)
    ax.plot(fit_x - 1, slope * fit_x + intercept, 'r--', linewidth=2, label='WLS Extrapolation')

    # gamma=0 lands at categorical position -1, one slot left of the data.
    ax.errorbar(x=-1, y=intercept, yerr=mi_error, fmt='r*', markersize=15, capsize=5,
                label=f'Corrected MI = {intercept:.2f} ± {mi_error:.2f} {units}')

    # Re-label the categorical positions with the actual gamma values.
    ax.set_xticks(np.unique(raw_results_df['gamma']) - 1)
    ax.set_xticklabels(np.unique(raw_results_df['gamma']))
    ax.set_xlabel(r"Number of Subsets ($\gamma$)")
    ax.set_ylabel(f"MI Estimate ({units})")
    ax.set_title("Bias Correction via Extrapolation")
    ax.legend()
    ax.grid(True, linestyle=':')
    sns.despine(ax=ax)

    if own_figure:
        plt.tight_layout()
    return ax
def plot_cross_correlation(x, y, true_lag):
    """Plot the normalized linear cross-correlation of the first rows of x and y.

    Z-scores ``x[0]`` and ``y[0]``, cross-correlates them ('same' mode), and
    plots correlation against lag, marking both the supplied true lag and the
    lag of the correlation peak. Displays the figure with ``plt.show()``.
    """
    n = len(x[0])
    lags = np.arange(-n // 2 + 1, n // 2 + 1)
    corr = correlate(zscore(y[0]), zscore(x[0]), mode='same') / n
    found_lag = lags[np.argmax(corr)]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(lags, corr)
    # NOTE(review): the true-lag marker is drawn at true_lag + 1 — presumably
    # compensating for the 'same'-mode alignment of scipy.signal.correlate;
    # confirm this one-sample offset is intended.
    ax.axvline(true_lag + 1, color='r', linestyle='-.', label=f'True Lag ({true_lag})')
    ax.axvline(found_lag, color='g', linestyle=':', label=f'Found Lag ({found_lag})')
    ax.set_xlabel('Lag')
    ax.set_ylabel('Cross-Correlation')
    ax.set_title('Linear Correlation vs Lag')
    ax.set_xlim(-100, 100)
    ax.legend()
    plt.show()
def analyze_mi_heatmap(results_df, absolute_mi_threshold=0.2, contour_rise_fraction=0.1, radius_multiplier=1.2,
                       true_lag=None, history_duration=None):
    """
    Performs the ultimate topological analysis of a 2D MI heatmap.
    - Finds the Causal Contour.
    - Finds the shortest "bridge" between the Causal and Significant MI Contours.
    - Draws a "Parsimonious Circle" centered on this bridge to highlight the optimal region.

    Args:
        results_df (pd.DataFrame): DataFrame with 'lag', 'window_size', and 'mi' columns.
        absolute_mi_threshold (float): The absolute MI value to consider "significant".
        contour_rise_fraction (float): Heuristic for finding the Causal Contour rise point.
        radius_multiplier (float): Factor to scale the Parsimonious Circle's radius.
        true_lag (float, optional): The true lag value to mark on the plot.
        history_duration (float, optional): The true history/window duration to mark on the plot.

    Notes:
        Prints a textual summary, displays the figure via ``plt.show()``, and
        returns ``None`` (also returns early, after showing a bare heatmap, if
        no significant contour is found). Assumes the (lag, window_size) grid
        has at least two distinct lags and two distinct window sizes — the
        edge computation below indexes ``lags[1]``/``windows[1]`` unguarded —
        TODO confirm with callers.
    """
    # --- 1. Data Preparation ---
    # Pivot the long-form results into a 2D grid: rows = window_size, cols = lag.
    heatmap_data = results_df.pivot(index='window_size', columns='lag', values='mi')
    lags = heatmap_data.columns.values
    windows = heatmap_data.index.values
    # --- 2. Causal Contour Analysis ---
    # The Causal Contour is the line lag + window = C: at lag 0, MI should
    # start rising once the window spans the full causal extent C.
    causal_contour_c = None
    if 0 in lags:
        lag0_data = heatmap_data[0]  # MI profile over window_size at lag == 0
        # Median of the three smallest windows serves as the noise-floor estimate
        # (assumes the pivot index is sorted ascending — pandas sorts it by default).
        noise_floor = lag0_data.iloc[:3].median()
        peak_mi = lag0_data.max()
        # Rise point: first window whose MI exceeds `contour_rise_fraction` of
        # the noise-floor-to-peak range.
        rise_threshold = noise_floor + (peak_mi - noise_floor) * contour_rise_fraction
        significant_windows = lag0_data[lag0_data > rise_threshold]
        if not significant_windows.empty:
            causal_contour_c = significant_windows.index[0]
            print(f"--- Causal Contour Analysis ---")
            print(f"MI at lag=0 rises at window_size = {causal_contour_c} (implies lag_true + history_true ≈ {causal_contour_c})\n")
    else:
        print("--- Causal Contour Analysis ---\nLag=0 not found. Skipping Causal Contour estimation.\n")
    # --- 3. Create the main figure for all analysis ---
    fig, ax = plt.subplots(figsize=(11, 8))
    # Use pcolormesh instead of seaborn heatmap for consistent coordinates
    # Need to create mesh grid edges for pcolormesh (cell boundaries half a
    # step beyond each sample; assumes a uniformly spaced grid).
    lag_edges = np.concatenate([lags - (lags[1] - lags[0])/2, [lags[-1] + (lags[1] - lags[0])/2]])
    window_edges = np.concatenate([windows - (windows[1] - windows[0])/2, [windows[-1] + (windows[1] - windows[0])/2]])
    mesh = ax.pcolormesh(lag_edges, window_edges, heatmap_data.values, cmap='viridis', shading='flat')
    cbar = plt.colorbar(mesh, ax=ax, label='Mutual Information')
    # --- 4. Significant Zone & Parsimony Analysis ---
    print(f"--- Parsimony Analysis (Significant MI > {absolute_mi_threshold}) ---")
    # Create contour on the same axes
    cs = ax.contour(lags, windows, heatmap_data.values, levels=[absolute_mi_threshold],
                    colors='red', linewidths=2.5, linestyles='-')
    # NOTE(review): ContourSet.allsegs is deprecated in matplotlib >= 3.8 and
    # slated for removal; a migration to cs.get_paths() should be considered.
    if not cs.allsegs[0]:
        # No cell crosses the threshold — show the bare heatmap and bail out.
        print("Warning: No significant MI contour found. Try a lower threshold.")
        ax.set_title('Parsimony-Informed Topological Analysis (No Significant Contour Found)')
        ax.set_xlabel('Lag (Timepoints)')
        ax.set_ylabel('Window Size (Timepoints)')
        plt.show()
        return
    # Extract the largest continuous contour segment
    significant_contour_points = np.array(max(cs.allsegs[0], key=len))
    midpoint, radius = None, None
    if causal_contour_c is not None:
        # Define the Causal Contour line *only within the plot's window range*
        # (points on lag + window = C whose window coordinate is in range).
        causal_lags = lags[(causal_contour_c - lags >= windows.min()) & (causal_contour_c - lags <= windows.max())]
        causal_contour_line = np.array([[lg, causal_contour_c - lg] for lg in causal_lags])
        if causal_contour_line.size > 0:
            # Draw the Causal Contour line
            ax.plot(causal_contour_line[:, 0], causal_contour_line[:, 1],
                    color='cyan', linestyle='--', linewidth=3, label=f'Causal Contour (C≈{causal_contour_c})')
            # Find the shortest distance between the two contours
            # (brute-force pairwise distances; grids here are small).
            distances = cdist(significant_contour_points, causal_contour_line)
            min_dist_idx = np.unravel_index(np.argmin(distances), distances.shape)
            point_on_mi_contour = significant_contour_points[min_dist_idx[0]]
            point_on_causal_contour = causal_contour_line[min_dist_idx[1]]
            # The Parsimonious Circle is centered on the bridge midpoint with
            # radius = half the bridge length, scaled by `radius_multiplier`.
            midpoint = (point_on_mi_contour + point_on_causal_contour) / 2
            bridge_length = np.linalg.norm(point_on_mi_contour - point_on_causal_contour)
            radius = (bridge_length / 2) * radius_multiplier
            print(f"Shortest bridge is between {point_on_causal_contour} on Causal Contour")
            print(f"and {point_on_mi_contour} on Significant MI Contour.")
            print(f"Bridge length: {bridge_length:.2f}")
            print(f"Parsimonious Center: (lag={midpoint[0]:.1f}, window={midpoint[1]:.1f})")
            print(f"Parsimonious Radius: {radius:.2f}")
            # Draw the bridge line
            ax.plot([point_on_causal_contour[0], point_on_mi_contour[0]],
                    [point_on_causal_contour[1], point_on_mi_contour[1]],
                    'orange', linewidth=2, linestyle='-', alpha=0.7)
            # Draw the Parsimonious Circle
            circle = patches.Circle(midpoint, radius, linewidth=2.5, edgecolor='white',
                                    facecolor='none', linestyle=':', label='Parsimonious Region')
            ax.add_patch(circle)
            # Mark the center
            ax.plot(midpoint[0], midpoint[1], 'w+', markersize=15, mew=3,
                    label='Parsimonious Center')
    # --- 5. Mark True Parameter Box ---
    if true_lag is not None and history_duration is not None:
        # Calculate the box edges (half a step in each direction)
        lag_step = lags[1] - lags[0] if len(lags) > 1 else 1
        window_step = windows[1] - windows[0] if len(windows) > 1 else 1
        # Create rectangle centered on the true values
        true_rect = patches.Rectangle(
            (true_lag - lag_step/2, history_duration - window_step/2),
            lag_step, window_step,
            linewidth=3, edgecolor='lime', facecolor='none', linestyle='-'
        )
        ax.add_patch(true_rect)
        print(f"\n--- True Parameters ---")
        print(f"True lag: {true_lag}, True history: {history_duration}")
    # Manually create legend handles for a clean legend
    # (proxy artists mirror what was actually drawn above).
    legend_elements = []
    if causal_contour_c is not None:
        legend_elements.append(Line2D([0], [0], color='cyan', lw=3, ls='--',
                                      label=f'Causal Contour (C≈{causal_contour_c})'))
    legend_elements.append(Line2D([0], [0], color='red', lw=2.5,
                                  label=f'Significant MI Contour (>{absolute_mi_threshold})'))
    if midpoint is not None:
        legend_elements.append(Line2D([0], [0], color='orange', lw=2,
                                      label='Bridge (shortest distance)', alpha=0.7))
        legend_elements.append(Line2D([0], [0], color='white', lw=2.5, ls=':',
                                      label='Parsimonious Region'))
        legend_elements.append(Line2D([0], [0], marker='+', color='w',
                                      label='Parsimonious Center', ls='none', mew=3, markersize=12))
    if true_lag is not None and history_duration is not None:
        legend_elements.append(patches.Rectangle((0, 0), 1, 1, linewidth=3,
                                                 edgecolor='lime', facecolor='none',
                                                 label='True Parameters'))
    ax.set_title('Parsimony-Informed Topological Analysis')
    ax.set_xlabel('Lag (Timepoints)')
    ax.set_ylabel('Window Size (Timepoints)')
    ax.legend(handles=legend_elements, loc='upper left')
    # Fix axis limits to heatmap data range to prevent rescaling
    ax.set_xlim(lags.min(), lags.max())
    ax.set_ylim(windows.min(), windows.max())
    # Show all tick values if we have fewer than 20 of them
    if len(lags) < 25:
        ax.set_xticks(lags)
    if len(windows) < 25:
        ax.set_yticks(windows)
    plt.tight_layout()
    plt.show()