NIR Interoperability
NIR (Neuromorphic Intermediate Representation) enables exchanging spiking neural network (SNN) models across frameworks, simulators, and neuromorphic hardware platforms.
Export
export_to_nir
export_to_nir(layers: List[Tuple[str, Module]], input_shape: tuple = None, dt: float = DEFAULT_DT) -> 'nir.NIRGraph'
Export a list of mlx-snn layers to a NIR graph.
Because mlx-snn models are usually custom nn.Module compositions rather than a standard sequential container, this function takes an explicit list of (name, module) pairs that defines the forward-pass order.
Args:
layers: Ordered list of (name, module) tuples. Supported module types: nn.Linear, Leaky, IF, Synaptic.
input_shape: Shape of the input tensor, excluding the batch dimension. If None, it is inferred from the first Linear layer's input features.
dt: Simulation timestep used for the discrete-to-continuous-time conversion.
Returns:
A nir.NIRGraph representing the model.
Raises:
TypeError: If an unsupported module type is encountered.
ValueError: If the neuron feature count cannot be inferred.
Examples:
>>> import mlx.nn as nn
>>> import mlxsnn
>>> layers = [
...     ('fc1', nn.Linear(784, 128)),
...     ('lif1', mlxsnn.Leaky(beta=0.9)),
...     ('fc2', nn.Linear(128, 10)),
...     ('lif2', mlxsnn.Leaky(beta=0.9)),
... ]
>>> graph = mlxsnn.export_to_nir(layers)
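Because the wiring comes entirely from list order, export amounts to chaining the layers into nodes and edges. The sketch below illustrates that idea in plain Python: it infers the input width from the first Linear layer, raises TypeError for unsupported modules, and raises ValueError when no feature count can be found. The Linear and Leaky classes are hypothetical stand-ins (not the mlx.nn or mlxsnn classes), only two layer types are handled for brevity, and the result is a plain dict and edge list rather than a nir.NIRGraph. It also assumes that a neuron layer takes its size from the layer before it; mlx-snn's actual implementation may differ.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for nn.Linear and mlxsnn.Leaky (sizes only).
@dataclass
class Linear:
    in_features: int
    out_features: int


@dataclass
class Leaky:
    beta: float


def chain_graph(layers, input_shape=None):
    """Chain an ordered (name, module) list into NIR-style nodes and edges.

    Returns (nodes, edges): nodes maps name -> (node kind, size info),
    edges is a list of (source, target) name pairs.
    """
    if input_shape is None:
        # Infer the input width from the first Linear layer, as documented.
        first_linear = next((m for _, m in layers if isinstance(m, Linear)), None)
        if first_linear is None:
            raise ValueError("cannot infer feature count; pass input_shape")
        input_shape = (first_linear.in_features,)

    width = input_shape[-1]  # feature width flowing into the next layer
    nodes = {"input": ("Input", tuple(input_shape))}
    for name, module in layers:
        if isinstance(module, Linear):
            # Labelled "Affine" here purely for illustration.
            nodes[name] = ("Affine", (module.in_features, module.out_features))
            width = module.out_features
        elif isinstance(module, Leaky):
            # Assumption: a neuron layer takes its size from the layer before it.
            nodes[name] = ("LIF", width)
        else:
            raise TypeError(f"unsupported module type: {type(module).__name__}")
    nodes["output"] = ("Output", width)

    # List order is the forward pass, so edges simply link neighbours.
    order = ["input"] + [name for name, _ in layers] + ["output"]
    edges = list(zip(order, order[1:]))
    return nodes, edges
```

For the four-layer example above, the edges link input, fc1, lif1, fc2, lif2 and output in sequence, and lif1 is sized to fc1's 128 output features.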
Import
import_from_nir
Import a NIR graph into an mlx-snn sequential model.
Args:
graph: A nir.NIRGraph to convert.
dt: Simulation timestep used for the continuous-to-discrete-time conversion.
Returns:
A NIRSequential model with layers matching the NIR graph.
Raises:
ValueError: If the graph contains a cycle.
Examples:
>>> import nir
>>> import mlxsnn
>>> graph = nir.read('model.nir')
>>> model = mlxsnn.import_from_nir(graph)
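Rejecting cyclic graphs and producing a topological layer order can both be done in one pass with Kahn's algorithm. The sketch below is a generic version that works on node names and (source, target) edge pairs, which mirrors how a NIR graph stores its nodes and edges; whether mlx-snn uses Kahn's algorithm specifically is an assumption.

```python
from collections import deque


def topological_order(nodes, edges):
    """Kahn's algorithm: return node names so every edge points forward.

    `nodes` may be a list of names or a dict keyed by name.
    Raises ValueError if the edges contain a cycle.
    """
    indegree = {n: 0 for n in nodes}
    successors = {n: [] for n in nodes}
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1

    # Start from nodes with no incoming edges (e.g. the graph input).
    ready = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)

    # Nodes on a cycle never reach in-degree zero, so they are never emitted.
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order


order = topological_order(
    ["output", "lif", "fc", "input"],
    [("input", "fc"), ("fc", "lif"), ("lif", "output")],
)
# order == ['input', 'fc', 'lif', 'output']
```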
NIR Sequential Model
NIRSequential
Bases: Module
Sequential model built from an imported NIR graph.
Stores layers as named attributes and runs them in topological order. Neuron layers receive and return state dicts; linear layers are stateless.
Args:
layer_names: Ordered list of layer names.
layers: Dict mapping names to nn.Module instances.
Examples:
>>> model = import_from_nir(graph)
>>> state = model.init_states(batch_size=32)
>>> out, state = model(x, state)  # x: input tensor [32, features]
__call__
Forward pass through all layers.
Args:
x: Input tensor [batch, features].
states: Dict mapping neuron layer names to their state dicts.
Returns:
Tuple of (output, new_states), where new_states has the same keys as the input states.
init_states
Initialize states for all neuron layers.
Args:
batch_size: Batch size for state tensors.
Returns:
Dict mapping neuron layer names to initialized state dicts.
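The contract of __call__ and init_states (neuron layers take and return a state dict, linear layers are stateless, and the caller feeds the returned states back in at the next timestep) can be illustrated with a toy model. Scale, ToyLeaky and ToySequential are hypothetical pure-Python stand-ins: inputs are flat lists for a single sample rather than [batch, features] tensors, and ToyLeaky's leak, spike and subtract-reset update is an illustrative choice, not necessarily the one mlxsnn.Leaky implements.

```python
class Scale:
    """Hypothetical stateless layer standing in for nn.Linear."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return [self.factor * v for v in x]


class ToyLeaky:
    """Hypothetical neuron layer: takes and returns a state dict."""

    def __init__(self, size, beta, threshold=1.0):
        self.size, self.beta, self.threshold = size, beta, threshold

    def init_state(self):
        return {"mem": [0.0] * self.size}

    def __call__(self, x, state):
        mem = [self.beta * m + v for m, v in zip(state["mem"], x)]  # leak + integrate
        spk = [1.0 if m >= self.threshold else 0.0 for m in mem]  # fire
        mem = [m - s * self.threshold for m, s in zip(mem, spk)]  # subtract-reset
        return spk, {"mem": mem}


class ToySequential:
    """Runs layers in the given order, threading state only through neurons."""

    def __init__(self, layer_names, layers):
        self.layer_names = layer_names
        self.layers = layers

    def init_states(self):
        return {
            name: self.layers[name].init_state()
            for name in self.layer_names
            if isinstance(self.layers[name], ToyLeaky)
        }

    def __call__(self, x, states):
        new_states = {}
        for name in self.layer_names:
            layer = self.layers[name]
            if name in states:  # neuron layer: consume and return its state
                x, new_states[name] = layer(x, states[name])
            else:  # stateless layer
                x = layer(x)
        return x, new_states


model = ToySequential(["fc", "lif"], {"fc": Scale(2.0), "lif": ToyLeaky(size=2, beta=0.5)})
states = model.init_states()
for x in [[0.25, 0.75]] * 3:  # three timesteps of the same input
    out, states = model(x, states)
```

Keeping state outside the model, as in this pattern, lets one model instance be stepped through time by an external loop and reset simply by calling init_states again.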
Supported Conversions
| mlx-snn | NIR | Direction |
|---|---|---|
| nn.Linear | nir.Affine / nir.Linear | Export & Import |
| Leaky | nir.LIF | Export & Import |
| IF | nir.IF | Export & Import |
| Synaptic | nir.CubaLIF | Export & Import |
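The Leaky / nir.LIF row involves a change of parameterization: mlx-snn neurons use a per-step decay factor beta, while NIR describes neurons in continuous time with a time constant tau, which is why both export_to_nir and import_from_nir take dt. This page does not state which discretization mlx-snn uses, so the sketch shows two common conventions as assumptions: forward Euler (beta = 1 - dt/tau) and exact exponential decay (beta = exp(-dt/tau)). If Synaptic also has a synaptic decay factor (alpha in snnTorch-style APIs, assumed here), the same mapping would apply to nir.CubaLIF's synaptic time constant. NIR's LIF node carries further parameters, such as a threshold, that the sketch leaves out.

```python
import math


def beta_to_tau_euler(beta, dt):
    """Forward-Euler convention: beta = 1 - dt / tau, so tau = dt / (1 - beta)."""
    if not 0.0 <= beta < 1.0:
        raise ValueError("beta must be in [0, 1)")
    return dt / (1.0 - beta)


def tau_to_beta_euler(tau, dt):
    """Inverse of beta_to_tau_euler."""
    return 1.0 - dt / tau


def beta_to_tau_exp(beta, dt):
    """Exponential-decay convention: beta = exp(-dt / tau), so tau = -dt / ln(beta)."""
    if not 0.0 < beta < 1.0:
        raise ValueError("beta must be in (0, 1)")
    return -dt / math.log(beta)


def tau_to_beta_exp(tau, dt):
    """Inverse of beta_to_tau_exp."""
    return math.exp(-dt / tau)


dt = 1e-3
tau_euler = beta_to_tau_euler(0.9, dt)  # about 10 ms
tau_exp = beta_to_tau_exp(0.9, dt)      # about 9.5 ms
```

The two conventions agree closely when beta is near 1 (tau much larger than dt) and diverge as beta decreases, so importing a graph with a different dt or convention from the one used at export changes the resulting neuron dynamics.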
Example
Export
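import mlx.nn as nn
import mlxsnn
import nir

# Layers in forward-pass order; the list order defines the graph's edges.
layers = [
    ('fc1', nn.Linear(784, 128)),
    ('lif1', mlxsnn.Leaky(beta=0.9)),
    ('fc2', nn.Linear(128, 10)),
    ('lif2', mlxsnn.Leaky(beta=0.9)),
]
graph = mlxsnn.export_to_nir(layers)  # input_shape inferred from fc1's 784 input features
nir.write('model.nir', graph)  # save the graph to a NIR file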