Tutorial 7: Advanced Customization with Custom Models

Welcome to our final tutorial! You have now mastered the main workflows of NeuralMI, from simple estimates to rigorous, publication-ready analyses. But what happens when your research requires a model architecture that isn’t built into the library?

This tutorial is for the advanced user who wants maximum flexibility. We will show you how to define your own models using PyTorch and seamlessly integrate them into the nmi.run pipeline, ensuring that NeuralMI can grow with your research needs.

1. The Requirements for a Custom Model

To be compatible with the NeuralMI trainer, any custom model must meet two simple requirements:

  1. It must inherit from nmi.models.BaseCritic.

  2. Its forward method must accept two arguments, x and y, and return a tuple containing:

    • scores: A (batch_size, batch_size) tensor of similarity scores.

    • kl_loss: A scalar tensor for any KL divergence loss. If not using a variational model, this should be torch.tensor(0.0).

Let’s explore the two main ways to achieve this.

[1]:
import torch
import torch.nn as nn
import numpy as np
import neural_mi as nmi
import seaborn as sns

sns.set_context("talk")

2. Method 1: Full Control with custom_critic

This method gives you complete control. You define the entire critic architecture from scratch and pass a pre-initialized instance of your model to nmi.run.

Let’s build a simple custom critic that uses a linear embedding layer.

[2]:
# Define a simple embedding model (can be anything that inherits from BaseEmbedding)
class LinearEmbedding(nmi.models.BaseEmbedding):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        x_flat = x.view(x.shape[0], -1)
        return self.layer(x_flat)

# Define our custom critic that uses the embedding model
class MyCustomSeparableCritic(nmi.models.BaseCritic):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.embedding_net = LinearEmbedding(input_dim, embedding_dim)

    def forward(self, x, y):
        x_embedded = self.embedding_net(x)
        y_embedded = self.embedding_net(y)
        scores = torch.matmul(x_embedded, y_embedded.t())
        return scores, torch.tensor(0.0, device=scores.device)

print("Custom critic class defined successfully!")
Custom critic class defined successfully!
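Before handing the critic to nmi.run, it is worth sanity-checking that its forward pass honors the contract from Section 1: a (batch_size, batch_size) score matrix plus a scalar KL term. The quick check below uses plain nn.Module stand-ins for the library's base classes so it runs even without neural_mi installed; in the notebook you can call MyCustomSeparableCritic directly instead.

```python
import torch
import torch.nn as nn

# Stand-ins for nmi.models.BaseEmbedding / BaseCritic so this check is
# self-contained; the logic mirrors the classes defined above.
class _LinearEmbedding(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        return self.layer(x.view(x.shape[0], -1))

class _SeparableCritic(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.embedding_net = _LinearEmbedding(input_dim, embedding_dim)

    def forward(self, x, y):
        x_emb = self.embedding_net(x)
        y_emb = self.embedding_net(y)
        scores = torch.matmul(x_emb, y_emb.t())
        return scores, torch.tensor(0.0, device=scores.device)

critic = _SeparableCritic(input_dim=5, embedding_dim=16)
x, y = torch.randn(128, 5), torch.randn(128, 5)
scores, kl_loss = critic(x, y)

assert scores.shape == (128, 128)  # (batch_size, batch_size) score matrix
assert kl_loss.dim() == 0          # scalar KL term
```

A check like this catches shape bugs (e.g., forgetting the transpose in the matmul) long before they surface as cryptic trainer errors.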

Using the Custom Critic in nmi.run

Using our new model is simple: we instantiate our critic and pass the instance to the custom_critic argument. The library will then skip its internal model-building logic and use our model directly. Any model architecture parameters in base_params (like embedding_dim, hidden_dim, etc.) will be ignored.

[3]:
# --- Generate some simple data ---
x_raw, y_raw = nmi.datasets.generate_correlated_gaussians(n_samples=5000, dim=5, mi=2.0)

# --- Instantiate our model ---
my_critic_instance = MyCustomSeparableCritic(input_dim=5, embedding_dim=16)

# --- Define trainer parameters (no model architecture params needed) ---
base_params = {
    'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
    'patience': 10
}

# --- Run the estimation ---
results = nmi.run(
    x_data=x_raw.T, y_data=y_raw.T,
    mode='estimate',
    processor_type_x='continuous',
    processor_params_x={'window_size': 1},
    base_params=base_params,
    split_mode='random',
    custom_critic=my_critic_instance, # Here is the magic!
    n_workers=1,
    random_seed=42
)

print("\n--- Results with custom_critic ---")
print("Ground Truth MI:  2.000 bits")
print(f"Estimated MI:     {results.mi_estimate:.3f} bits")
2025-10-20 00:10:41 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...
2025-10-20 00:10:45 - neural_mi - INFO - Parameter sweep finished.

--- Results with custom_critic ---
Ground Truth MI:  2.000 bits
Estimated MI:     1.579 bits

3. Method 2: Modular Control with custom_embedding_cls

Sometimes you don’t need to reinvent the wheel. You might like the library’s built-in SeparableCritic, but you just want to swap out the embedding model (e.g., use a Transformer instead of an MLP).

The custom_embedding_cls parameter is perfect for this. Instead of a model instance, you provide the class of your custom embedding model. The library will then handle instantiating it for you, using the architecture parameters from base_params.

Important: For this to work, your custom embedding’s __init__ method must be designed to accept the standard parameters that the library provides: input_dim, hidden_dim, embed_dim, and n_layers.

[4]:
# Define a more complex custom embedding that is compatible with the library's builder
class CustomMLP(nmi.models.BaseEmbedding):
    # This __init__ signature matches the arguments the library's internal builder will provide
    def __init__(self, input_dim: int, hidden_dim: int, embed_dim: int, n_layers: int, activation: str = 'relu'):
        super().__init__()

        # Resolve the activation name to a layer class so the argument is
        # actually used rather than hard-coding ReLU
        act = {'relu': nn.ReLU, 'tanh': nn.Tanh}[activation]

        # You can define any architecture you want inside
        layers = [nn.Linear(input_dim, hidden_dim), act()]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), act()])
        self.network = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output_layer(self.network(x.view(x.shape[0], -1)))

print("CustomMLP class defined successfully!")
CustomMLP class defined successfully!
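As a quick check before running the sweep, you can instantiate CustomMLP with the same architecture values we will pass via base_params and confirm the embedding shape. The snippet below uses a standalone copy with a plain nn.Module base (standing in for nmi.models.BaseEmbedding) so it runs without the library; in the notebook you can instantiate CustomMLP directly.

```python
import torch
import torch.nn as nn

# Standalone copy of CustomMLP (nn.Module replaces nmi.models.BaseEmbedding)
class _CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, embed_dim, n_layers, activation='relu'):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        self.network = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        return self.output_layer(self.network(x.view(x.shape[0], -1)))

# Same architecture values as base_params_cls below:
# hidden_dim=64, embedding dim 16, n_layers=2, and our data has dim=5
net = _CustomMLP(input_dim=5, hidden_dim=64, embed_dim=16, n_layers=2)
embedded = net(torch.randn(32, 5))
assert embedded.shape == (32, 16)  # (batch_size, embed_dim)
```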
[5]:
# --- Define model and trainer parameters ---
# This time, we DO need to provide the architecture params, as the library will use them
# to instantiate our CustomMLP class.
base_params_cls = {
    'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
    'patience': 10, 'embedding_dim': 16, 'hidden_dim': 64, 'n_layers': 2,
    'critic_type': 'separable'
}

# --- Run the estimation ---
results_cls = nmi.run(
    x_data=x_raw.T, y_data=y_raw.T,
    mode='estimate',
    processor_type_x='continuous',
    processor_params_x={'window_size': 1},
    split_mode='random',
    base_params=base_params_cls,
    custom_embedding_cls=CustomMLP, # Pass the CLASS here
    n_workers=1,
    random_seed=42
)

print("\n--- Results with custom_embedding_cls ---")
print("Ground Truth MI:  2.000 bits")
print(f"Estimated MI:     {results_cls.mi_estimate:.3f} bits")
2025-10-20 00:10:45 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...
2025-10-20 00:10:48 - neural_mi - INFO - Parameter sweep finished.

--- Results with custom_embedding_cls ---
Ground Truth MI:  2.000 bits
Estimated MI:     1.944 bits

Success! The estimate (1.944 bits) is noticeably closer to the ground truth than the one from our simple linear critic (1.579 bits). This modular approach allows you to leverage the library’s tested critic architectures while still having the freedom to design novel embedding models for your specific data.

4. Conclusion

Congratulations! You have completed the NeuralMI learning path. You now have the skills to handle complex neural data, choose the right model architecture, perform scientifically rigorous analyses, and even extend the library with your own custom models.

The custom_critic and custom_embedding_cls features provide escape hatches for maximum flexibility, ensuring that NeuralMI can serve as the foundation for your analysis, no matter how specialized your research becomes.