NIR Interoperability

NIR (Neuromorphic Intermediate Representation) enables cross-framework SNN model exchange between simulators and neuromorphic hardware platforms.

Optional dependency

```bash
pip install "mlx-snn[nir]"
```

(The quotes prevent shells like zsh from interpreting the square brackets as a glob pattern.)

Export

export_to_nir

export_to_nir(layers: List[Tuple[str, Module]], input_shape: Optional[tuple] = None, dt: float = DEFAULT_DT) -> 'nir.NIRGraph'

Export a list of mlx-snn layers to a NIR graph.

Since mlx-snn models are typically custom nn.Module compositions (not a standard sequential container), this function takes an explicit list of (name, module) pairs defining the forward-pass order.

Args:

- layers: Ordered list of (name, module) tuples. Supported module types: nn.Linear, Leaky, IF, Synaptic.
- input_shape: Shape of the input tensor (excluding batch dimension). If None, inferred from the first Linear layer's input features.
- dt: Simulation timestep for continuous-time conversion.

Returns: A nir.NIRGraph representing the model.

Raises:

- TypeError: If an unsupported module type is encountered.
- ValueError: If neuron feature count cannot be inferred.

Examples:

```python
>>> import mlx.nn as nn, mlxsnn
>>> layers = [
...     ('fc1', nn.Linear(784, 128)),
...     ('lif1', mlxsnn.Leaky(beta=0.9)),
...     ('fc2', nn.Linear(128, 10)),
...     ('lif2', mlxsnn.Leaky(beta=0.9)),
... ]
>>> graph = mlxsnn.export_to_nir(layers)
```

Import

import_from_nir

import_from_nir(graph: 'nir.NIRGraph', dt: float = DEFAULT_DT) -> NIRSequential

Import a NIR graph into an mlx-snn sequential model.

Args:

- graph: A nir.NIRGraph to convert.
- dt: Simulation timestep for continuous-time to discrete-time conversion.

Returns: A NIRSequential model with layers matching the NIR graph.

Raises:

- ValueError: If the graph contains a cycle.

Examples:

```python
>>> import nir, mlxsnn
>>> graph = nir.read('model.nir')
>>> model = mlxsnn.import_from_nir(graph)
```

NIR Sequential Model

NIRSequential

Bases: Module

Sequential model built from an imported NIR graph.

Stores layers as named attributes and runs them in topological order. Neuron layers receive and return state dicts; linear layers are stateless.

Args:

- layer_names: Ordered list of layer names.
- layers: Dict mapping names to nn.Module instances.

Examples:

```python
>>> model = import_from_nir(graph)
>>> state = model.init_states(batch_size=32)
>>> out, state = model(x, state)
```

__call__

__call__(x: mx.array, states: dict) -> Tuple[mx.array, dict]

Forward pass through all layers.

Args:

- x: Input tensor [batch, features].
- states: Dict mapping neuron layer names to their state dicts.

Returns: Tuple of (output, new_states) where new_states has the same keys as the input states.

init_states

init_states(batch_size: int) -> dict

Initialize states for all neuron layers.

Args:

- batch_size: Batch size for state tensors.

Returns: Dict mapping neuron layer names to initialized state dicts.
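The per-timestep protocol (each call consumes a state dict and returns an updated one) can be illustrated with a toy pure-Python leaky neuron. `ToyLeaky` below is a hypothetical stand-in written for illustration only; it is not mlx-snn's `Leaky` implementation, and mlx-snn's reset rule and threshold default may differ.

```python
# Toy single-neuron sketch of the state-in / state-out stepping protocol.
# Hypothetical stand-in -- NOT mlx-snn's Leaky; assumes soft reset by
# subtraction and a threshold of 1.0 purely for demonstration.
class ToyLeaky:
    def __init__(self, beta=0.9, threshold=1.0):
        self.beta = beta
        self.threshold = threshold

    def init_state(self):
        return {"mem": 0.0}

    def __call__(self, x, state):
        mem = self.beta * state["mem"] + x       # leaky integration
        spike = 1.0 if mem >= self.threshold else 0.0
        mem -= spike * self.threshold            # soft reset on spike
        return spike, {"mem": mem}

neuron = ToyLeaky()
state = neuron.init_state()
spikes = []
for t in range(10):
    out, state = neuron(0.3, state)              # constant input current
    spikes.append(out)
# With beta=0.9 and input 0.3, the membrane crosses threshold at t=3 and t=7.
```

`NIRSequential` follows the same pattern at the model level: `init_states` builds one such dict per neuron layer, and each `__call__` threads all of them through one timestep.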

Supported Conversions

| mlx-snn   | NIR                     | Direction       |
|-----------|-------------------------|-----------------|
| nn.Linear | nir.Affine / nir.Linear | Export & Import |
| Leaky     | nir.LIF                 | Export & Import |
| IF        | nir.IF                  | Export & Import |
| Synaptic  | nir.CubaLIF             | Export & Import |
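The table above maps discrete-time mlx-snn neurons onto NIR's continuous-time primitives, which is where the `dt` argument comes in. A common convention relates the `Leaky` decay factor `beta` to a LIF membrane time constant `tau` via `beta = exp(-dt / tau)`; this is an assumption for illustration, and the exact mapping used by `export_to_nir`/`import_from_nir` may differ (e.g. a forward-Euler variant, `beta = 1 - dt / tau`).

```python
import math

dt = 1e-3     # simulation timestep in seconds
beta = 0.9    # discrete-time Leaky decay factor

# Exponential-decay mapping (assumed convention): beta = exp(-dt / tau)
tau = -dt / math.log(beta)          # continuous-time membrane constant

# Round trip back to the discrete decay factor recovers beta exactly
beta_back = math.exp(-dt / tau)
```

Whichever convention the converters use, the round trip export-then-import should reproduce the original decay behavior at the same `dt`.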

Example

Export

```python
import mlx.nn as nn
import mlxsnn, nir

layers = [
    ('fc1', nn.Linear(784, 128)),
    ('lif1', mlxsnn.Leaky(beta=0.9)),
    ('fc2', nn.Linear(128, 10)),
    ('lif2', mlxsnn.Leaky(beta=0.9)),
]
graph = mlxsnn.export_to_nir(layers)
nir.write('model.nir', graph)
```

Import

```python
import mlx.core as mx
import mlxsnn, nir

graph = nir.read('model.nir')
model = mlxsnn.import_from_nir(graph)
state = model.init_states(batch_size=32)
x = mx.zeros((32, 784))  # one timestep of input, shape [batch, features]
out, state = model(x, state)
```
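Because the model emits spikes one timestep at a time, inference typically loops over T steps and decodes by spike counts (rate coding). A framework-agnostic sketch of the decoding step, using made-up spike rows in place of successive `out` values from the model:

```python
# Accumulate output spikes over T timesteps, then pick the most active unit.
# Each row stands in for one `out` from the timestep loop (3 output units).
spike_history = [
    [0.0, 1.0, 0.0],   # t=0: unit 1 fired
    [0.0, 1.0, 0.0],   # t=1
    [1.0, 0.0, 0.0],   # t=2: unit 0 fired
    [0.0, 1.0, 0.0],   # t=3
]
counts = [sum(step[c] for step in spike_history) for c in range(3)]
pred = max(range(3), key=counts.__getitem__)   # most active unit wins
```

In the real loop you would call `out, state = model(x_t, state)` once per timestep, carrying `state` forward, and append each `out` to the history before decoding.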