Utilities
State management and visualization utilities.
State Management
init_states
Initialize hidden states for all spiking neuron layers.
Walks through the model's named attributes and calls
init_state on every SpikingNeuron found, inferring
the feature count from a preceding Linear layer when
possible.
Args:
model: An nn.Module containing spiking neuron layers.
batch_size: Batch size for state tensors.
Returns:
A dictionary mapping attribute names to their initial
state dicts (e.g., {"lif1": {"mem": ...}, ...}).
Examples:
>>> import mlx.nn as nn
>>> import mlxsnn
>>> class SNN(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = nn.Linear(784, 128)
...         self.lif1 = mlxsnn.Leaky(beta=0.9)
...     def __call__(self, x, states):
...         x = self.fc1(x)
...         spk, states["lif1"] = self.lif1(x, states["lif1"])
...         return spk, states
>>> model = SNN()
>>> states = mlxsnn.init_states(model, batch_size=32)
>>> print(states["lif1"]["mem"].shape)
[32, 128]
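The attribute walk described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not mlx-snn's implementation: FakeLinear, FakeNeuron, their init_state(batch_size, features) signature, init_states_sketch, and the default_features fallback are all hypothetical, and nested lists stand in for MLX arrays so the sketch runs without MLX. It assumes "preceding" means the most recently assigned Linear attribute, relying on Python preserving attribute insertion order.

```python
# Hypothetical stand-ins; the real nn.Linear / SpikingNeuron APIs may differ.
class FakeLinear:
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class FakeNeuron:
    def init_state(self, batch_size, features):
        # Nested lists stand in for a [batch_size, features] MLX array.
        return {"mem": [[0.0] * features for _ in range(batch_size)]}


def init_states_sketch(model, batch_size, default_features=1):
    """Walk attributes in assignment order and size each neuron's state
    from the output width of the most recent Linear layer (an assumption)."""
    states = {}
    last_features = None
    for name, layer in vars(model).items():  # dicts keep insertion order
        if isinstance(layer, FakeLinear):
            last_features = layer.out_features
        elif isinstance(layer, FakeNeuron):
            # Hypothetical fallback when no Linear layer precedes the neuron.
            features = last_features if last_features is not None else default_features
            states[name] = layer.init_state(batch_size, features)
    return states


class TwoLayerSNN:
    def __init__(self):
        self.fc1 = FakeLinear(784, 128)
        self.lif1 = FakeNeuron()
        self.fc2 = FakeLinear(128, 10)
        self.lif2 = FakeNeuron()


states = init_states_sketch(TwoLayerSNN(), batch_size=32)
print(len(states["lif1"]["mem"]), len(states["lif1"]["mem"][0]))  # 32 128
print(len(states["lif2"]["mem"][0]))  # 10
```

Each neuron picks up the width of the Linear layer assigned just before it, which is why lif2 gets 10 features rather than 128.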
reset_states
Reset all stateful neuron layers in a model.
Walks through the model's attributes and calls init_state
on any SpikingNeuron instance. The function exists primarily
as a convenience reminder: in mlx-snn, state is always
explicit (passed in and returned), so resetting simply means
re-creating initial states via init_states.
Args:
model: An nn.Module whose spiking neuron layers should
be identified.
Note:
Since mlx-snn uses explicit state dicts rather than hidden
instance variables, calling this function does not change any
state the model holds; it is effectively a no-op. Use
init_states to create fresh state dicts instead.
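Because state is explicit, "resetting" amounts to handing the network a freshly created state dict. The toy sketch below makes that concrete without MLX. leaky_step (decay, integrate, fire, then subtract the threshold on a spike) and fresh_state are hypothetical stand-ins for a Leaky layer and init_states, not mlx-snn code, and the subtract-on-spike reset is an assumption chosen for illustration. Two runs from fresh states produce identical spikes, while reusing the final state of a previous run carries membrane potential over.

```python
def leaky_step(x, state, beta=0.9, threshold=1.0):
    """One toy LIF update with explicit state (hypothetical, not mlx-snn).

    Decay the membrane, add the input, spike if above threshold, then
    subtract the threshold from the membrane of a neuron that spiked.
    """
    mem = beta * state["mem"] + x
    spk = 1.0 if mem > threshold else 0.0
    return spk, {"mem": mem - spk * threshold}


def fresh_state():
    """Stand-in for init_states: a brand-new, zeroed state dict."""
    return {"mem": 0.0}


def run_sequence(inputs, state):
    """Feed a sequence through leaky_step, threading the state explicitly."""
    spikes = []
    for x in inputs:
        spk, state = leaky_step(x, state)
        spikes.append(spk)
    return spikes, state


seq = [0.6, 0.6, 0.6]
spikes_a, end_state = run_sequence(seq, fresh_state())  # fresh start
spikes_b, _ = run_sequence(seq, fresh_state())          # "reset": new state
spikes_c, _ = run_sequence(seq, end_state)              # no reset: carry-over
print(spikes_a, spikes_b)  # [0.0, 1.0, 0.0] [0.0, 1.0, 0.0]
print(spikes_c)            # [1.0, 0.0, 1.0]
```

Nothing in the model object changes between runs; only the state dict passed in decides whether a sequence starts clean.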
Visualization
Spike Raster Plot
plot_raster
plot_raster(spikes, neuron_ids: Optional[list] = None, ax=None, title: str = 'Spike Raster', xlabel: str = 'Timestep', ylabel: str = 'Neuron', marker: str = '|', markersize: float = 2.0, color: str = 'black', show: bool = True)
Plot a spike raster diagram.
Args:
spikes: Spike tensor [T, neurons] or [T, batch, neurons].
If 3-D, the first batch element is plotted.
neuron_ids: Optional list of neuron indices to plot.
ax: Existing matplotlib Axes to draw on. If None, a new
figure is created.
title: Plot title.
xlabel: X-axis label.
ylabel: Y-axis label.
marker: Marker style for spikes.
markersize: Size of spike markers.
color: Color of spike markers.
show: Whether to call plt.show() after plotting.
Returns:
matplotlib Axes object.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.utils.visualization import plot_raster
>>> spikes = (mx.random.uniform(shape=(50, 20)) > 0.8).astype(mx.float32)
>>> ax = plot_raster(spikes, show=False)
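One plausible way to turn a spike tensor into the points a raster plot draws is sketched below. raster_events is a hypothetical helper operating on nested Python lists rather than MLX arrays; it mirrors the documented behavior (first batch element for 3-D input, optional neuron_ids) but is not taken from mlx-snn's source. Each returned (timestep, neuron) pair corresponds to one tick on the plot.

```python
def raster_events(spikes, neuron_ids=None):
    """Return (times, neurons): coordinates of every nonzero spike.

    spikes is a [T][neurons] or [T][batch][neurons] nested list; for 3-D
    input only the first batch element is used, as plot_raster documents.
    """
    if isinstance(spikes[0][0], list):
        spikes = [step[0] for step in spikes]  # first batch element only
    ids = range(len(spikes[0])) if neuron_ids is None else neuron_ids
    times, neurons = [], []
    for t, step in enumerate(spikes):
        for n in ids:
            if step[n] != 0:
                times.append(t)
                neurons.append(n)
    return times, neurons


spikes = [
    [0, 1, 0],  # t = 0: neuron 1 fires
    [1, 0, 0],  # t = 1: neuron 0 fires
    [0, 1, 1],  # t = 2: neurons 1 and 2 fire
]
times, neurons = raster_events(spikes)
print(times)    # [0, 1, 2, 2]
print(neurons)  # [1, 0, 1, 2]
```

The two lists could then be passed to a scatter-style call such as matplotlib's Axes.scatter(times, neurons, marker="|"); which matplotlib call plot_raster itself uses is not documented here.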
Membrane Potential Traces
plot_membrane
plot_membrane(mem, neuron_ids: Optional[list] = None, threshold: Optional[float] = 1.0, ax=None, title: str = 'Membrane Potential', xlabel: str = 'Timestep', ylabel: str = 'Membrane Potential', show: bool = True)
Plot membrane potential traces over time.
Args:
mem: Membrane potential [T, neurons] or [T, batch, neurons].
If 3-D, the first batch element is plotted.
neuron_ids: Optional list of neuron indices to plot. If None,
plots all neurons (up to 10).
threshold: If not None, draws a horizontal dashed line at
the threshold value.
ax: Existing matplotlib Axes.
title: Plot title.
xlabel: X-axis label.
ylabel: Y-axis label.
show: Whether to call plt.show().
Returns:
matplotlib Axes object.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.utils.visualization import plot_membrane
>>> mem = mx.random.uniform(shape=(50, 5))
>>> ax = plot_membrane(mem, show=False)
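The trace-selection rules in the Args above can be sketched as follows. select_traces is a hypothetical helper that works on nested lists instead of MLX arrays, and it assumes that "up to 10" means the first 10 neuron indices when neuron_ids is None, which the docstring does not spell out.

```python
def select_traces(mem, neuron_ids=None, max_default=10):
    """Map neuron index -> list of membrane values over time.

    mem is a [T][neurons] or [T][batch][neurons] nested list. For 3-D input
    only the first batch element is kept. With neuron_ids=None, the first
    max_default neurons are chosen (assumed reading of "up to 10").
    """
    if isinstance(mem[0][0], list):
        mem = [step[0] for step in mem]  # first batch element only
    if neuron_ids is None:
        neuron_ids = range(min(len(mem[0]), max_default))
    return {n: [step[n] for step in mem] for n in neuron_ids}


# 5 timesteps x 12 neurons; each value encodes (t, n) as t * 100 + n.
mem = [[t * 100 + n for n in range(12)] for t in range(5)]
traces = select_traces(mem)
print(sorted(traces))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(traces[3])       # [3, 103, 203, 303, 403]
print(select_traces(mem, neuron_ids=[11]))  # {11: [11, 111, 211, 311, 411]}
```

Passing neuron_ids explicitly bypasses the cap, so a specific neuron beyond index 9 can still be inspected.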
Firing Rate Bar Chart
plot_firing_rate
plot_firing_rate(spk_out, ax=None, title: str = 'Firing Rate', xlabel: str = 'Neuron', ylabel: str = 'Firing Rate', color: str = 'steelblue', show: bool = True)
Plot per-neuron firing rate as a bar chart.
Args:
spk_out: Spike tensor [T, neurons] or [T, batch, neurons].
If 3-D, averages over the batch dimension.
ax: Existing matplotlib Axes.
title: Plot title.
xlabel: X-axis label.
ylabel: Y-axis label.
color: Bar color.
show: Whether to call plt.show().
Returns:
matplotlib Axes object.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.utils.visualization import plot_firing_rate
>>> spk = (mx.random.uniform(shape=(100, 20)) > 0.7).astype(mx.float32)
>>> ax = plot_firing_rate(spk, show=False)
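The per-neuron rate behind the bar chart can be sketched in plain Python. firing_rates is a hypothetical helper, and it assumes the rate is the fraction of timesteps on which a neuron spiked (a unitless value between 0 and 1, not Hz); the docstring only confirms that 3-D input is averaged over the batch dimension.

```python
def firing_rates(spk_out):
    """Fraction of timesteps on which each neuron spiked (assumed definition).

    Accepts [T][neurons] or [T][batch][neurons] nested lists; for 3-D input
    the batch dimension is averaged as well, matching the docstring.
    """
    if isinstance(spk_out[0][0], list):
        # Flatten [T][batch][neurons] into (T * batch) rows of [neurons].
        rows = [row for step in spk_out for row in step]
    else:
        rows = spk_out
    n_neurons = len(rows[0])
    return [sum(row[n] for row in rows) / len(rows) for n in range(n_neurons)]


rates_2d = firing_rates([[1, 0], [1, 1], [0, 0], [1, 0]])
rates_3d = firing_rates([[[1, 0], [1, 0]], [[1, 1], [0, 0]]])
print(rates_2d)  # [0.75, 0.25]
print(rates_3d)  # [0.75, 0.25]
```

With MLX arrays, the same reductions would be spk.mean(axis=0) for [T, neurons] input and spk.mean(axis=(0, 1)) for [T, batch, neurons] input.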