Training¶
Training utilities for SNN models, including BPTT forward pass helpers.
BPTT Forward Pass¶
bptt_forward ¶
bptt_forward(model, spikes: array, state: dict, num_steps: int | None = None) -> tuple[mx.array, mx.array, dict]
Run BPTT forward pass over time, collecting all outputs.
Iterates the model over the time dimension, collecting output spikes and membrane potentials at each step.
Args:
model: An SNN model (callable with signature
(x, state) -> (spk, new_state)).
spikes: Input spike tensor [T, batch, ...] (time-first).
state: Initial hidden state dict.
num_steps: Override time steps. If None, uses spikes.shape[0].
Returns:
all_spikes: Output spikes [T, batch, ...].
all_mems: Membrane potentials [T, batch, ...].
final_state: Final hidden state dict.
Examples:

```python
>>> from mlxsnn.training import bptt_forward
>>> all_spk, all_mem, final_state = bptt_forward(
...     model, input_spikes, init_state
... )
```
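To make the time-first iteration concrete, here is a minimal sketch of the loop `bptt_forward` describes, written with NumPy as a stand-in for `mx.array` and with a toy leaky integrate-and-fire step (`bptt_forward_sketch` and `lif_step` are illustrative names, not part of the mlx-snn API):

```python
import numpy as np

def bptt_forward_sketch(step_fn, spikes, state, num_steps=None):
    # Iterate over the leading (time) axis, collecting outputs per step.
    T = num_steps if num_steps is not None else spikes.shape[0]
    all_spk, all_mem = [], []
    for t in range(T):
        spk, state = step_fn(spikes[t], state)
        all_spk.append(spk)
        all_mem.append(state["mem"])
    # Stack along a new leading time axis: [T, batch, ...]
    return np.stack(all_spk), np.stack(all_mem), state

# A toy leaky integrate-and-fire step for demonstration only.
def lif_step(x, state, beta=0.9, threshold=1.0):
    mem = beta * state["mem"] + x
    spk = (mem >= threshold).astype(x.dtype)
    mem = mem - spk * threshold  # soft reset after a spike
    return spk, {"mem": mem}

T, batch, feat = 5, 2, 3
spikes = np.ones((T, batch, feat), dtype=np.float32)
state = {"mem": np.zeros((batch, feat), dtype=np.float32)}
all_spk, all_mem, final_state = bptt_forward_sketch(lif_step, spikes, state)
print(all_spk.shape, all_mem.shape)  # (5, 2, 3) (5, 2, 3)
```

Both outputs gain a leading time axis of length `T`, matching the `[T, batch, ...]` shapes documented above.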
Compiled Forward Pass¶
Per-timestep compilation via mx.compile for faster inference.
compiled_step ¶
compiled_step(model: Module) -> Callable

Create a compiled single-timestep forward function.
Wraps the model's __call__ with mx.compile so that each
timestep's forward pass is optimized as a fused computation graph.
Args:
model: An SNN model or neuron with signature
(x, state) -> (spk, new_state).
Returns:
A compiled callable with the same signature as the model.
Examples:

```python
>>> lif = mlxsnn.Leaky(beta=0.9)
>>> step_fn = compiled_step(lif)
>>> state = lif.init_state(8, 128)
>>> x = mx.random.normal((8, 128))
>>> spk, state = step_fn(x, state)
```
compiled_forward ¶
compiled_forward(model: Module, spikes: array, state: dict, num_steps: int | None = None) -> tuple[mx.array, mx.array, dict]
BPTT forward pass with per-timestep compilation.
Like `bptt_forward` but uses
mx.compile to optimize each timestep's computation.
Args:
model: An SNN model with (x, state) -> (spk, new_state).
spikes: Input spike tensor [T, batch, ...] (time-first).
state: Initial hidden state dict.
num_steps: Override time steps. If None, uses spikes.shape[0].
Returns:
all_spikes: Output spikes [T, batch, ...].
all_mems: Membrane potentials [T, batch, ...].
final_state: Final hidden state dict.
Note: Compilation is most beneficial for models with expensive per-timestep computation (e.g., conv layers). For simple FC networks the overhead may not be worth it.
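The key design point is that compilation happens once for the single-timestep function and the resulting graph is then reused at every timestep. A rough illustration of that trace-once, reuse-many pattern, with a hypothetical `compile_once` standing in for mx.compile (NumPy used in place of mx.array; none of these names are mlx-snn API):

```python
import numpy as np

def compile_once(step_fn):
    """Illustrative stand-in for mx.compile: 'build' the graph on the
    first call for a given input signature, then reuse it."""
    cache = {}
    def wrapper(x, state):
        key = (x.shape, str(x.dtype))
        if key not in cache:
            cache[key] = step_fn  # real mx.compile traces and fuses here
        return cache[key](x, state)
    wrapper._cache = cache
    return wrapper

# Toy leaky integrate-and-fire step for demonstration only.
def lif_step(x, state, beta=0.9, threshold=1.0):
    mem = beta * state["mem"] + x
    spk = (mem >= threshold).astype(x.dtype)
    return spk, {"mem": mem - spk * threshold}

step = compile_once(lif_step)
state = {"mem": np.zeros((2, 3), dtype=np.float32)}
x = np.ones((2, 3), dtype=np.float32)
for _ in range(5):  # the "graph" is built once, then reused 5 times
    spk, state = step(x, state)
print(len(step._cache))  # 1
```

Because the per-step graph is built only once, the compile cost is amortized over all `T` timesteps, which is why the benefit grows with per-step compute.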
Training Pattern¶
The standard training pattern with mlx-snn:
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlxsnn

model = MySNN()
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, x_seq, y):
    mem_out = model(x_seq)
    return mlxsnn.mse_membrane_loss(mem_out, y)

loss_and_grad = nn.value_and_grad(model, loss_fn)

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        x_seq = mx.array(x_batch).transpose(1, 0, 2)  # to time-first [T, batch, ...]
        y = mx.array(y_batch)
        loss, grads = loss_and_grad(model, x_seq, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
```
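The transpose in the loop above converts batch-first dataloader output into the time-first layout the forward-pass helpers expect. A quick shape check, using NumPy as a stand-in for mx.array and assumed dimensions of batch 32, 25 timesteps, 784 features:

```python
import numpy as np

# Batch-first data as a typical dataloader yields it: [batch, T, features]
x_batch = np.zeros((32, 25, 784), dtype=np.float32)

# Swap the first two axes to get the time-first layout: [T, batch, features]
x_seq = x_batch.transpose(1, 0, 2)
print(x_seq.shape)  # (25, 32, 784)
```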