# Under the Hood: Core Concepts for MI Estimation This document is for the curious user who has completed the tutorials and wants a deeper understanding of what `NeuralMI` is doing under the hood. We won't cover all the advanced features, but we will build a simple neural MI estimator from scratch to demystify the core concepts. The document is based heavily on [this paper](https://arxiv.org/abs/2506.00330). Our goal is to answer three key questions: 1. What *is* a neural MI estimator, really? 2. How is it trained and evaluated? 3. What is the intuition behind the "rigorous" bias correction? Let's dive in with some simple PyTorch code. --- ## Part 1: Anatomy of a Neural MI Estimator At its heart, one way of looking at a neural MI estimator is just a clever way of training a neural network to solve a classification problem. Instead of estimating probability densities directly, we train a **critic** network, `f(x, y)`, to distinguish between "positive" samples (pairs `(x_i, y_i)` that genuinely occurred together) and "negative" samples (pairs `(x_i, y_j)` that did not). Let's build the three essential components from scratch. ### The Components **1. Embedding Networks (`g` and `h`):** These are two simple neural networks that learn to extract meaningful features from `X` and `Y`. ```python import torch import torch.nn as nn # A simple MLP to process an input vector into an embedding def create_embedding_net(input_dim, embedding_dim): return nn.Sequential( nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, embedding_dim) ) ``` **2. The Critic `f(x, y)`:** In a `SeparableCritic`, the "critic" is just the dot product between the embeddings. It computes a similarity score for every possible pairing of samples in a batch. ```python def separable_critic(x_embedded, y_embedded): # x_embedded has shape (batch_size, embedding_dim) # y_embedded has shape (batch_size, embedding_dim) # The result is a (batch_size, batch_size) matrix of scores return torch.matmul(x_embedded, y_embedded.t()) ``` **3. The Estimator (Loss Function):** This is the mathematical formula that turns the score matrix from the critic into an MI estimate. Let's implement the most common one, **InfoNCE**. The formula is: $$ I(X;Y) \ge \mathbb{E}\left[ \frac{1}{N}\sum_{i=1}^N \left( f(x_i,y_i) - \log\left(\sum_{j=1}^N e^{f(x_i,y_j)}\right) \right) \right] + \log(N) $$ This looks complex, but it's just a form of cross-entropy loss. For each `x_i` in the batch (each row of the score matrix), we're trying to maximize the score of its true partner `y_i` (the diagonal element) relative to all other `y_j`'s in the batch (the off-diagonal elements). ```python def infonce_estimator(scores): # scores is the (batch_size, batch_size) matrix from the critic batch_size = scores.shape[0] # The f(x_i, y_i) term is the diagonal of the score matrix positive_scores = torch.diag(scores) # The log-sum-exp term is calculated for each row log_sum_exp = torch.logsumexp(scores, dim=1) # The MI is the mean difference, plus log(batch_size) mi_estimate_nats = torch.mean(positive_scores - log_sum_exp) + torch.log(torch.tensor(batch_size)) return mi_estimate_nats ``` And that's it! A neural MI estimator is just these three pieces working together. --- ## Part 2: The Training Loop Demystified Now, how do we use these components? We train them just like any other neural network: by minimizing a loss function. For MI estimation, the loss is simply the **negative of the MI estimate**. Maximizing the MI is the same as minimizing `-MI`. Here's a simplified training loop: ```python # --- Setup --- dim = 5 embedding_dim = 16 batch_size = 128 n_epochs = 10 # Create simple correlated data x_data, y_data = torch.randn(1000, dim), torch.randn(1000, dim) # Create our embedding networks g_net = create_embedding_net(dim, embedding_dim) h_net = create_embedding_net(dim, embedding_dim) # Group parameters and create an optimizer params = list(g_net.parameters()) + list(h_net.parameters()) optimizer = torch.optim.Adam(params, lr=1e-3) # --- Training Loop --- for epoch in range(n_epochs): # In a real scenario, we would use a DataLoader to get batches x_batch = x_data[:batch_size] y_batch = y_data[:batch_size] # 1. Get embeddings x_embedded = g_net(x_batch) y_embedded = h_net(y_batch) # 2. Get scores from the critic scores = separable_critic(x_embedded, y_embedded) # 3. Calculate the MI estimate mi_estimate = infonce_estimator(scores) # 4. The loss is the negative MI loss = -mi_estimate # 5. Backpropagate and optimize optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 2 == 0: print(f"Epoch {epoch}, MI Estimate (nats): {mi_estimate.item():.3f}") ``` ### Evaluation and Early Stopping During a real training run, we would split our data into training and validation sets. After each epoch, we'd calculate the MI on the validation set. This produces a `test_mi_history` curve. A heuristic introduced in the paper is to stop training and use the model that achieved the highest MI on the validation set. However, this curve can be very noisy. `NeuralMI` follows the same procedure as the paper and improves on this by applying a median filter followed by a Gaussian filter to get a smoothed curve. It then stops training when this smoothed curve has stopped improving for a set number of epochs (`patience`), which is a more robust strategy. ### Data Splitting Strategy A critical step in the training process is splitting the data into training and validation sets. `NeuralMI` uses a robust, hierarchical strategy to ensure the split is appropriate for the data type, preventing common pitfalls like data leakage in time-series. The splitting logic inside the `Trainer` follows this priority order: 1. **User-Provided Indices**: If `train_indices` and `test_indices` are passed to `nmi.run()`, they are used directly. This provides maximum user control and overrides all other settings. 2. **`split_mode` Parameter**: If custom indices are not provided, the trainer looks at the `split_mode` argument: * `split_mode='random'`: This mode is for **Independent and Identically Distributed (IID) data**. It performs a standard random shuffle of all data points before creating the train/validation split. This is the correct choice when there is no temporal relationship between samples. * `split_mode='blocked'` (Default): This mode is for **temporal or sequential data**. Instead of shuffling, it samples several non-overlapping, contiguous blocks of data to form the validation set. This ensures that the validation data is truly "out of sample" in a temporal sense, preventing the model from being tested on points immediately adjacent to ones it was trained on. 3. **Default Behavior**: If no splitting options are specified, the system defaults to `split_mode='blocked'`, as it is the safer and more robust option for the types of physics and neuroscience data the library often handles. --- ## Part 3: The Intuition Behind `mode='rigorous'` Even with a perfectly trained model, any MI estimate from a finite dataset will be biased. The model will inevitably find spurious correlations in the noise, leading to a **systematic overestimation** of the true MI. As explained in the literature, this bias has a predictable relationship with the number of samples, `N`: $$ I_{\text{estimated}}(N) \approx I_{\text{true}} + \frac{a}{N} $$ This means the estimated MI is approximately linear in `1/N`. The `rigorous` mode exploits this relationship to correct for the bias. ### The Extrapolation Procedure 1. **Subsample:** The library runs the MI estimation multiple times on different fractions of the data. For example, it might split the data into `γ=2` halves, then `γ=3` thirds, and so on. This gives us MI estimates for different effective sample sizes `N/γ`. 2. **Plot vs. 1/N:** The library plots the mean MI estimate for each `γ` against `1/(N/γ)`, which is proportional to `γ`. Because of the formula above, this plot should be a straight line. 3. **Extrapolate:** `NeuralMI` performs a weighted linear regression on this line and finds the y-intercept. This intercept corresponds to the point where `1/N = 0`, which represents an infinite dataset. This extrapolated value is the final, bias-corrected MI estimate. 4. **Linearity Check:** If the fit is not *linear* (judged by fitting a second-order weighted least squares first and checking if the ratio of the quadratic to linear contribution `δ` is greater than 10%), we reject this fitting point. We drop the estimates corresponding to this `γ` value and recalculate. If `δ ≥ threshold`, we keep dropping points until we reach `γ < min_gamma_points` (usually 5). If the fit is still not linear, we deem it *not reliable* and warn that it shouldn't be trusted as we don't have enough data. The plot generated by `results.plot()` in rigorous mode is a direct visualization of this procedure. The **Corrected MI** is simply the y-intercept of the extrapolation line, giving you an estimate of what the MI would be if you could collect an infinite amount of data.