{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 6: Choosing the Right Model and Estimator\n", "\n", "So far, we've learned how to prepare our data and optimize processing parameters like `window_size`. Now, we need to look inside the 'black box' and understand the two most important choices that determine the success of our MI estimate:\n", "\n", "1. **The Critic Architecture**: The neural network that *compares* the data from X and Y.\n", "2. **The MI Estimator**: The loss function that *trains* the critic.\n", "\n", "Getting these right is key to capturing the true nature of the relationship in your data. This tutorial will provide the intuition and practical examples to guide your choices." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import neural_mi as nmi\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from torch.nn import Sequential, Linear, Softplus\n", "\n", "sns.set_context(\"talk\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Choosing the Critic Architecture\n", "\n", "Think of the critic's job as being a sophisticated comparison function, `f(x, y)`. The complexity of this function determines the kinds of relationships the model can find. `NeuralMI` provides three main critic types, each with a different balance of computational cost and expressive power." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| Critic Type | How it Works | Power | Cost | Use Case |\n", "| :--- | :--- | :--- | :--- | :--- |\n", "| **`SeparableCritic`** | Compares embeddings with a simple dot product: `g(x) • h(y)`. | Low | Low | **Default choice.** Fast and effective for most relationships. |\n", "| **`BilinearCritic`** | Uses a learnable matrix to compare embeddings: `g(x)ᵀ W h(y)`. | Medium | Medium | Good for more complex relationships, like rotations, without a huge speed penalty. 
|\n", "| **`ConcatCritic`** | Concatenates inputs `[x, y]` into a single powerful network `f(x, y)`. | High | High | The most powerful, but can be very slow. Use when other critics fail. |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Test: The Rotated Manifold Problem\n", "\n", "To see the difference, we'll create a special dataset: a high-dimensional `X` generated from a latent variable `Z`, and a high-dimensional `Y` generated from a **rotated** version of `Z`. A simple `SeparableCritic` (dot product) will struggle to see that these are related, but the more powerful critics should succeed. Note also that the MI here is bounded above by the information content of `Z` (`I(X;Y) ≤ I(Z;Z) = H(Z)`), which is infinite for a continuous `Z`. We are therefore interested in the relative trends across critics, not the exact MI values." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# 1. Create the shared latent variable Z\n", "n_samples = 5000\n", "z = torch.randn(n_samples, 2)\n", "\n", "# 2. Create a 45-degree rotation matrix for Y's latent variable\n", "angle = np.pi / 4\n", "rotation_matrix = torch.tensor([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], dtype=torch.float32)\n", "z_rotated = z @ rotation_matrix\n", "\n", "# 3. 
Create nonlinear mappings from latent to a high-dimensional observed space\n", "mlp = Sequential(Linear(2, 64), Softplus(), Linear(64, 50))\n", "x_raw = mlp(z).T.detach()\n", "y_raw = mlp(z_rotated).T.detach()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-10-20 00:11:13 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42999a57e5f143418872dd5853e003a0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sequential Sweep Progress: 0%| | 0/15 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "sns.barplot(data=critic_results.dataframe, x='critic_type', y='mi_mean', capsize=0.1, order=['separable', 'bilinear', 'concat'])\n", "plt.title('Critic Performance on the Rotated Manifold Task')\n", "plt.ylabel('Estimated MI (bits)')\n", "plt.xlabel('Critic Architecture')\n", "plt.ylim(bottom=0)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Choosing the MI Estimator\n", "\n", "The estimator is the loss function used to train the critic. The choice of estimator involves a crucial trade-off between **bias** and **variance**.\n", "\n", "| Estimator | Bias | Variance | Use Case |\n", "| :--- | :--- | :--- | :--- |\n", "| **`InfoNCE`** | High (Biased Low) | Low | **Default choice.** Very stable. Excellent for most tasks where the true MI isn't extremely high. |\n", "| **`SMILE`** | Low | Medium | Less biased. Use when you suspect the true MI is very high, as it can avoid the artificial ceiling that affects `InfoNCE`. |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Problem: The `InfoNCE` Upper Bound\n", "\n", "The `InfoNCE` estimator is mathematically bounded by `log(batch_size)`. This means it can never report an MI value higher than this limit. 
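\n", "\n", "Why can it never exceed the limit? Because the positive pair's score is itself one of the terms in the normalizing sum. A minimal NumPy sketch (a toy stand-in, not `neural_mi`'s implementation) makes this visible for arbitrary critic scores:\n", "\n", "```python\n", "import numpy as np\n", "\n", "rng = np.random.default_rng(0)\n", "N = 128\n", "scores = rng.normal(size=(N, N)) * 5.0  # stand-in critic scores f(x_i, y_j)\n", "\n", "# Numerically stable log-sum-exp over each row (all candidates, incl. the positive pair)\n", "m = scores.max(axis=1, keepdims=True)\n", "logsumexp = m.squeeze(1) + np.log(np.exp(scores - m).sum(axis=1))\n", "\n", "# InfoNCE in nats: E[ f(x_i, y_i) - log mean_j e^{f(x_i, y_j)} ]\n", "infonce = np.mean(np.diag(scores) - logsumexp) + np.log(N)\n", "print(infonce <= np.log(N))  # True for any scores: diag <= row log-sum-exp\n", "```\n", "\n", "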
For a batch size of 128, the limit is `log(128) ≈ 4.85` nats, which is exactly `log2(128) = 7` bits. If the true MI is higher than this, `InfoNCE` will underestimate it.\n", "\n", "Let's create a dataset with a known **ground truth MI of 8.0 bits** and see how the two estimators perform." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Running with InfoNCE (default) ---\n", "2025-10-20 00:21:07 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "761dc87722db4729a52252b9aa9cbe2b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sequential Sweep Progress: 0%| | 0/1 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)\n", "\n", "sns.barplot(\n", " data=critic_results.dataframe,\n", " x='critic_type',\n", " y='mi_mean',\n", " capsize=0.1,\n", " order=['separable', 'bilinear', 'concat'],\n", " ax=axes[0]\n", ")\n", "axes[0].set_title('InfoNCE')\n", "axes[0].set_ylabel('Estimated MI (bits)')\n", "axes[0].set_xlabel('Critic Architecture')\n", "axes[0].set_ylim(bottom=0)\n", "\n", "sns.barplot(\n", " data=critic_results_smile.dataframe,\n", " x='critic_type',\n", " y='mi_mean',\n", " capsize=0.1,\n", " order=['separable', 'bilinear', 'concat'],\n", " ax=axes[1],\n", " color='red'\n", ")\n", "axes[1].set_title('SMILE')\n", "axes[1].set_xlabel('Critic Architecture')\n", "axes[1].set_ylim(-0.1, 1.2*critic_results_smile.dataframe.mi_mean.max())\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The New Result: A Clear Hierarchy\n", "\n", "The new results show a clear hierarchy in power. The `SeparableCritic` finds the lowest MI. The `BilinearCritic` does better, as its learnable matrix `W` can effectively undo the rotation. 
The `ConcatCritic`, being the most powerful, also succeeds, achieving a significantly higher MI at the cost of being slower." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion and Recommendations\n", "\n", "Choosing the right architecture and estimator is a trade-off. Here's a simple guide to get started:\n", "\n", "> **Recommendation:** Always start with the default: a **`SeparableCritic`** and the **`InfoNCE`** estimator. This combination is fast, stable, and works for a wide variety of problems.\n", "\n", "- If you have reason to suspect a complex, non-linear relationship (like a rotation or other geometric transformation), try the **`BilinearCritic`**. It offers a significant power boost without the high cost of the `ConcatCritic`.\n", "\n", "- If you are estimating the internal dimensionality of a system (`mode='dimensionality'`) or have other reasons to believe the true MI is very high, switch to the **`SMILE`** estimator to get a less biased result.\n", "\n", "With these guidelines, you are now equipped to make informed decisions about the core components of your MI analysis. In the next tutorial, we will tackle the final and most important step for scientific rigor: correcting for finite-sampling bias to get a statistically sound result.\n", "\n", "*A note*: The theoretical `log(batch_size)` bound on `InfoNCE` applies during *training*: while the estimator is trained on mini-batches, the reported MI should never exceed it. When the trained critic is later evaluated on the whole dataset, however, the estimate can exceed the training-time bound, because the effective batch size is then the size of the full dataset."
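, "\n", "How quickly that ceiling grows with the number of jointly scored samples is easy to check directly (a plain NumPy sketch, values in bits):\n", "\n", "```python\n", "import numpy as np\n", "\n", "# The InfoNCE ceiling is log2(n) bits, where n is the number of samples\n", "# scored against each other (the mini-batch, or the whole dataset at evaluation).\n", "for n in (128, 512, 5000):\n", "    print(f'n = {n}: ceiling = log2(n) = {np.log2(n):.2f} bits')\n", "```"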
] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:base] *", "language": "python", "name": "conda-base-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 4 }