
Training

Training utilities for SNN models, including BPTT forward pass helpers.

BPTT Forward Pass

bptt_forward

bptt_forward(model, spikes: array, state: dict, num_steps: int | None = None) -> tuple[mx.array, mx.array, dict]

Run BPTT forward pass over time, collecting all outputs.

Iterates the model over the time dimension, collecting output spikes and membrane potentials at each step.

Args:
    model: An SNN model (callable with signature (x, state) -> (spk, new_state)).
    spikes: Input spike tensor [T, batch, ...] (time-first).
    state: Initial hidden state dict.
    num_steps: Override time steps. If None, uses spikes.shape[0].

Returns:
    all_spikes: Output spikes [T, batch, ...].
    all_mems: Membrane potentials [T, batch, ...].
    final_state: Final hidden state dict.

Examples:
    >>> from mlxsnn.training import bptt_forward
    >>> all_spk, all_mem, final_state = bptt_forward(
    ...     model, input_spikes, init_state
    ... )
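To make the loop concrete, here is a framework-free sketch of what bptt_forward does at each step, written in plain Python so it runs without MLX. It is not the library's source. Assumptions: a toy leaky integrate-and-fire neuron (the hypothetical toy_leaky_step, with reset by subtraction), plain lists in place of mx.array with the batch axis dropped, and a state dict that stores the membrane under a "mem" key, from which each step's membrane is read.

```python
def toy_leaky_step(x, state, beta=0.9, threshold=1.0):
    """Hypothetical LIF step with signature (x, state) -> (spk, new_state).

    x and state["mem"] are lists of floats, one entry per neuron.
    """
    # Leaky integration: decay the old membrane, then add the input.
    mem = [beta * m + xi for m, xi in zip(state["mem"], x)]
    # Emit a spike wherever the membrane reaches the threshold.
    spk = [1.0 if m >= threshold else 0.0 for m in mem]
    # Reset by subtraction: neurons that fired lose `threshold` from their membrane.
    mem = [m - threshold * s for m, s in zip(mem, spk)]
    return spk, {"mem": mem}


def reference_bptt_forward(model, spikes, state, num_steps=None):
    """Plain-Python sketch of the time loop; spikes is indexed time-first."""
    if num_steps is None:
        num_steps = len(spikes)  # analogue of spikes.shape[0]
    all_spikes, all_mems = [], []
    for t in range(num_steps):
        spk, state = model(spikes[t], state)  # thread the state through time
        all_spikes.append(spk)
        all_mems.append(state["mem"])  # assumption: membrane lives under "mem"
    return all_spikes, all_mems, state


spk, mem, final = reference_bptt_forward(
    toy_leaky_step, [[0.6], [0.6], [0.6]], {"mem": [0.0]}
)
```

With three inputs of 0.6 and a starting membrane of 0.0, the neuron charges to 0.6, crosses the threshold at the second step (0.54 + 0.6 = 1.14), fires, and resets to about 0.14, so spk is [[0.0], [1.0], [0.0]].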

Compiled Forward Pass

Per-timestep compilation via mx.compile for faster inference.

compiled_step

compiled_step(model: Module) -> callable

Create a compiled single-timestep forward function.

Wraps the model's __call__ with mx.compile so that each timestep's forward pass is optimized as a fused computation graph.

Args:
    model: An SNN model or neuron with signature (x, state) -> (spk, new_state).

Returns:
    A compiled callable with the same signature as the model.

Examples:
    >>> lif = mlxsnn.Leaky(beta=0.9)
    >>> step_fn = compiled_step(lif)
    >>> state = lif.init_state(8, 128)
    >>> x = mx.random.normal((8, 128))
    >>> spk, state = step_fn(x, state)

compiled_forward

compiled_forward(model: Module, spikes: array, state: dict, num_steps: int | None = None) -> tuple[mx.array, mx.array, dict]

BPTT forward pass with per-timestep compilation.

Like bptt_forward, but uses mx.compile to optimize each timestep's computation.

Args:
    model: An SNN model with signature (x, state) -> (spk, new_state).
    spikes: Input spike tensor [T, batch, ...] (time-first).
    state: Initial hidden state dict.
    num_steps: Override time steps. If None, uses spikes.shape[0].

Returns:
    all_spikes: Output spikes [T, batch, ...].
    all_mems: Membrane potentials [T, batch, ...].
    final_state: Final hidden state dict.

Note:
    Compilation is most beneficial for models with expensive per-timestep computation (e.g., convolutional layers). For simple fully connected networks, the one-time compilation overhead may outweigh the speedup.
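The trade-off in the Note comes down to amortization. MLX builds and compiles the graph on the first call to a compiled function and reuses the cached result on later calls with the same input shapes, so the one-time cost pays off only when each step does enough work. The plain-Python sketch below shows the assumed structure of compiled_forward: the step function is wrapped once, before the time loop, and the wrapped version is reused at every timestep. Both forward_with_step_transform and its transform parameter are hypothetical; in MLX, transform would be mx.compile, and the identity default keeps the sketch runnable without MLX.

```python
from typing import Callable


def forward_with_step_transform(model, spikes, state, num_steps=None,
                                transform: Callable = lambda f: f):
    """Sketch: wrap the step function once, then reuse it at every timestep.

    In MLX, `transform` would be mx.compile; the identity default lets this
    run with plain Python callables.
    """
    step_fn = transform(model)  # wrapped (compiled) once, outside the loop
    if num_steps is None:
        num_steps = len(spikes)
    outputs = []
    for t in range(num_steps):
        spk, state = step_fn(spikes[t], state)  # same wrapped function each step
        outputs.append(spk)
    return outputs, state
```

Hoisting the wrap out of the loop keeps the per-step cost to a single call of an already-compiled function; only the first call pays for tracing.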

Training Pattern

The standard training pattern with mlx-snn:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlxsnn

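# MySNN, num_epochs, and dataloader are placeholders for your own model,
# epoch count, and data source. The model takes a time-first input and
# returns output membrane potentials.
model = MySNN()
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, x_seq, y):
    mem_out = model(x_seq)  # output membrane potentials
    return mlxsnn.mse_membrane_loss(mem_out, y)

# Wraps loss_fn so it returns the loss and its gradients w.r.t. the model's parameters
loss_and_grad = nn.value_and_grad(model, loss_fn)

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        # [batch, T, features] -> [T, batch, features] (time-first)
        x_seq = mx.array(x_batch).transpose(1, 0, 2)
        y = mx.array(y_batch)

        loss, grads = loss_and_grad(model, x_seq, y)
        optimizer.update(model, grads)
        # MLX is lazy: force evaluation of the updated parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)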
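The transpose(1, 0, 2) call in the loop assumes the dataloader yields batch-first arrays of shape [batch, T, features]. Swapping the first two axes makes x_seq[t] the [batch, features] slice for timestep t, which is the time-first layout the forward helpers expect. A plain-Python sketch of the same reordering on nested lists (the helper name to_time_first is hypothetical):

```python
def to_time_first(batch):
    """Reorder a nested list from [batch, T, features] to [T, batch, features].

    Plain-Python equivalent of x.transpose(1, 0, 2) on a 3-D array.
    """
    num_steps = len(batch[0])
    # Output index [t][b] reads input index [b][t].
    return [[sample[t] for sample in batch] for t in range(num_steps)]


# batch=2, T=3, features=2
x = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
x_seq = to_time_first(x)  # 3 timesteps, each holding 2 samples of 2 features
```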