{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 5: Uncovering Latent Dimensionality\n", "\n", "This tutorial covers an interesting application of mutual information: discovering the **latent dimensionality** of a system. This technique allows us to ask not just *how much* information is shared, but *how complex* that shared information is.\n", "\n", "We will explore two related but distinct scientific questions:\n", "\n", "1. **Shared Dimensionality:** What is the dimensionality of the information shared *between* two variables, X and Y?\n", "2. **Internal Dimensionality:** What is the intrinsic complexity of a *single*, high-dimensional neural population?\n", "\n", "We will see how `NeuralMI` can answer both questions and learn why the choice of MI estimator (`InfoNCE` vs. `SMILE`) and the use of a `variational` model are critical for getting the right answer." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import neural_mi as nmi\n", "import matplotlib.pyplot as plt\n", "from matplotlib.ticker import MaxNLocator\n", "import seaborn as sns\n", "\n", "sns.set_context(\"talk\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Shared Dimensionality between X and Y\n", "\n", "Let's start with a common question: if X and Y are related, what is the dimensionality of their shared relationship? We can investigate this by performing a simple sweep over the `embedding_dim` of a `SeparableCritic`.\n", "\n", "The `embedding_dim` acts as an **information bottleneck**. We expect the measured MI to rise as this bottleneck widens and then **plateau** once it's large enough to capture the full shared signal. The embedding dimension at which the plateau begins is our estimate of the shared dimensionality.\n", "\n", "For this task, the default **InfoNCE** estimator is often an excellent choice due to its stability and low variance." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# --- Generate Data ---\n", "# We'll create 100D X and Y variables that non-linearly share a 7D latent signal.\n", "true_latent_dim = 7\n", "x_raw, y_raw = nmi.datasets.generate_nonlinear_from_latent(\n", " n_samples=5000, \n", " latent_dim=true_latent_dim,\n", " observed_dim=100,\n", " mi=3.0 # The MI between the latent variables\n", ")\n", "x_raw_transposed = x_raw.T.detach()\n", "y_raw_transposed = y_raw.T.detach()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-10-20 00:06:19 - neural_mi - WARNING - Reproducibility with random_seed is not guaranteed with n_workers > 1.\n", "2025-10-20 00:06:19 - neural_mi - INFO - Starting parameter sweep with 4 workers...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "63fa9f28f0094dcea233b22e30c94b2b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Parameter Sweep Progress: 0%| | 0/39 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax = shared_dim_results.plot(show=False)\n", "ax.axhline(y=3.0, color='black', linestyle='-', label=f'True MI = 3 bits')\n", "ax.axvline(x=true_latent_dim, color='black', linestyle='--', label=f'True Shared Dim ({true_latent_dim})')\n", "ax.set_title('MI vs. Shared Embedding Dimension (InfoNCE)')\n", "ax.legend()\n", "ax.set_ylim(bottom=0)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Internal Dimensionality of a Single Population\n", "\n", "Now for a more advanced question: what is the internal complexity of a *single* variable `X`? For this, we use `mode='dimensionality'`. This mode automatically splits the channels of `X` into two random halves (`X_A` and `X_B`) and measures the \"Internal Information\" `I(X_A; X_B)`.\n", "\n", "As we've discussed, this is a special, high-information scenario. 
It's the perfect place to compare our different estimators." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Analysis with InfoNCE vs. SMILE\n", "\n", "Let's run the analysis twice: once with the default `InfoNCE` estimator, and once with the less-biased `SMILE` estimator." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Running with InfoNCE ---\n", "2025-10-20 00:07:03 - neural_mi - WARNING - Reproducibility with random_seed is not guaranteed with n_workers > 1.\n", "2025-10-20 00:07:03 - neural_mi - WARNING - Using 'infonce' estimator for dimensionality analysis. For this specific mode, consider using the 'smile' estimator, as its lower bias may reveal the saturation point more clearly.\n", "2025-10-20 00:07:03 - neural_mi - INFO - --- Running Split 1/3 ---\n", "2025-10-20 00:07:03 - neural_mi - INFO - Starting parameter sweep with 4 workers...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b152e778f2664179acbc9a1007339d96", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Parameter Sweep Progress: 0%| | 0/13 [00:00 1.\n", "2025-10-20 00:08:17 - neural_mi - INFO - --- Running Split 1/3 ---\n", "2025-10-20 00:08:17 - neural_mi - INFO - Starting parameter sweep with 4 workers...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "671fa4cdf3264b3fb3690a8b0d9e596c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Parameter Sweep Progress: 0%| | 0/13 [00:00 1.\n", "2025-10-20 00:09:27 - neural_mi - INFO - --- Running Split 1/3 ---\n", "2025-10-20 00:09:27 - neural_mi - INFO - Starting parameter sweep with 4 workers...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fe5ec77a8c94a9a898bb913624ecc78", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Parameter Sweep Progress: 0%| | 0/13 [00:00" ] }, "metadata": {}, "output_type": 
"display_data" } ], "source": [ "fig, ax = plt.subplots(1, 1, figsize=(12, 7))\n", "\n", "# Plot InfoNCE results\n", "df_i = dim_results_infonce.dataframe\n", "ax.plot(df_i['embedding_dim'], df_i['mi_mean'], 'o-', label='Internal Info (InfoNCE)', alpha=0.7)\n", "ax.fill_between(df_i['embedding_dim'], df_i['mi_mean'] - df_i['mi_std'], df_i['mi_mean'] + df_i['mi_std'], alpha=0.1)\n", "\n", "# Plot SMILE results\n", "df_s = dim_results_smile.dataframe\n", "ax.plot(df_s['embedding_dim'], df_s['mi_mean'], 'o-', label='Internal Info (SMILE)', alpha=0.8)\n", "ax.fill_between(df_s['embedding_dim'], df_s['mi_mean'] - df_s['mi_std'], df_s['mi_mean'] + df_s['mi_std'], alpha=0.15)\n", "\n", "# Plot Variational SMILE results\n", "df_v = dim_results_smile_var.dataframe\n", "ax.plot(df_v['embedding_dim'], df_v['mi_mean'], 'o-', label='Internal Info (Variational SMILE)')\n", "ax.fill_between(df_v['embedding_dim'], df_v['mi_mean'] - df_v['mi_std'], df_v['mi_mean'] + df_v['mi_std'], alpha=0.2)\n", "\n", "ax.axvline(x=true_latent_dim, color='black', linestyle='--', label=f'True Dim ({true_latent_dim})')\n", "ax.set_title('Estimator Comparison for Internal Dimensionality')\n", "ax.set_xlabel('Embedding Dimension')\n", "ax.set_ylabel('Internal Information (bits)')\n", "ax.legend()\n", "ax.grid(True, linestyle=':')\n", "ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", "ax.set_ylim(bottom=0)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Final Recommendations:\n", "- To find the dimensionality of the **shared signal between X and Y**, a standard `mode='sweep'` with the default **`InfoNCE`** estimator is a great choice.\n", "- To find the **internal dimensionality of a single population `X`**, use `mode='dimensionality'` and start with the **`SMILE`** estimator for a less biased result.\n", "- For particularly high-dimensional or complex data, consider using **`use_variational=True`** with SMILE to get the most stable and reliable curve.\n", 
"\n", "*Note*: The variational estimators often require slightly more training to reach good estimates, so consider increasing the number of epochs and the patience when using them." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:base] *", "language": "python", "name": "conda-base-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 4 }