{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 7: Advanced Customization with Custom Models\n", "\n", "Welcome to our final tutorial! You have now mastered the main workflows of `NeuralMI`, from simple estimates to rigorous, publication-ready analyses. But what happens when your research requires a model architecture that isn't built into the library?\n", "\n", "This tutorial is for the advanced user who wants maximum flexibility. We will show you how to define your own models using PyTorch and seamlessly integrate them into the `nmi.run` pipeline, ensuring that `NeuralMI` can grow with your research needs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. The Requirements for a Custom Model\n", "\n", "To be compatible with the `NeuralMI` trainer, any custom model must meet two simple requirements:\n", "\n", "1. It must inherit from **`nmi.models.BaseCritic`**.\n", "2. Its `forward` method must accept two arguments, `x` and `y`, and return a **tuple** containing:\n", "   - `scores`: A `(batch_size, batch_size)` tensor of similarity scores.\n", "   - `kl_loss`: A scalar tensor holding any KL divergence loss. If your model is not variational, return `torch.tensor(0.0)`, ideally created on the same device as `scores` (e.g. `torch.tensor(0.0, device=scores.device)`).\n", "\n", "Let's explore the two main ways to achieve this." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "import neural_mi as nmi\n", "import seaborn as sns\n", "\n", "sns.set_context(\"talk\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Method 1: Full Control with `custom_critic`\n", "\n", "This method gives you complete control. You define the entire critic architecture from scratch and pass an already-constructed **instance** of your model to `nmi.run`.\n", "\n", "Let's build a simple custom critic that uses a linear embedding layer." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Custom critic class defined successfully!\n" ] } ], "source": [ "# Define a simple embedding model (can be anything that inherits from BaseEmbedding)\n", "class LinearEmbedding(nmi.models.BaseEmbedding):\n", "    def __init__(self, input_dim, embedding_dim):\n", "        super().__init__()\n", "        self.layer = nn.Linear(input_dim, embedding_dim)\n", "\n", "    def forward(self, x):\n", "        # Flatten all non-batch dimensions; reshape (unlike view) also works on non-contiguous tensors\n", "        x_flat = x.reshape(x.shape[0], -1)\n", "        return self.layer(x_flat)\n", "\n", "# Define our custom critic that uses the embedding model\n", "class MyCustomSeparableCritic(nmi.models.BaseCritic):\n", "    def __init__(self, input_dim, embedding_dim):\n", "        super().__init__()\n", "        # One embedding network is shared by x and y, so both must flatten to input_dim features\n", "        self.embedding_net = LinearEmbedding(input_dim, embedding_dim)\n", "\n", "    def forward(self, x, y):\n", "        x_embedded = self.embedding_net(x)\n", "        y_embedded = self.embedding_net(y)\n", "        # scores[i, j] is the dot product of the embeddings of x[i] and y[j]\n", "        scores = torch.matmul(x_embedded, y_embedded.t())\n", "        # Not a variational model, so the KL term is zero (kept on the same device as scores)\n", "        return scores, torch.tensor(0.0, device=scores.device)\n", "\n", "print(\"Custom critic class defined successfully!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using the Custom Critic in `nmi.run`\n", "\n", "Using our new model is simple: we instantiate our critic and pass the **instance** to the `custom_critic` argument. The library then skips its internal model-building logic and uses our model directly. Any model architecture parameters in `base_params` (like `embedding_dim`, `hidden_dim`, etc.) will be ignored." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-10-20 00:10:41 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b94c5d97570f448384ad8db9cc48c7b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sequential Sweep Progress: 0%| | 0/1 [00:00 torch.Tensor:\n", "        return self.output_layer(self.network(x.view(x.shape[0], -1)))\n", "\n", "print(\"CustomMLP class defined successfully!\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-10-20 00:10:45 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b4a013c02bdc42fca57ff8504b1416e7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sequential Sweep Progress: 0%| | 0/1 [00:00