Utilities

State management and visualization utilities.

State Management

init_states

init_states(model, batch_size: int) -> dict

Initialize hidden states for all spiking neuron layers.

Walks through the model's named attributes and calls init_state on every SpikingNeuron found, inferring the feature count from a preceding Linear layer when possible.

Args:
    model: An nn.Module containing spiking neuron layers.
    batch_size: Batch size for state tensors.

Returns: A dictionary mapping attribute names to their initial state dicts (e.g., {"lif1": {"mem": ...}, ...}).

Examples:
    >>> import mlx.nn as nn
    >>> import mlxsnn
    >>> class SNN(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc1 = nn.Linear(784, 128)
    ...         self.lif1 = mlxsnn.Leaky(beta=0.9)
    ...     def __call__(self, x, states):
    ...         x = self.fc1(x)
    ...         spk, states["lif1"] = self.lif1(x, states["lif1"])
    ...         return spk, states
    >>> model = SNN()
    >>> states = mlxsnn.init_states(model, batch_size=32)
    >>> print(states["lif1"]["mem"].shape)
    [32, 128]
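The feature-count inference described above can be sketched without MLX. LinearStub, LeakyStub, and sketch_init_states below are hypothetical stand-ins, not mlx-snn classes. The state layout ({"mem": [batch, features]}, modeled as nested lists) and the choice to raise an error when no Linear precedes a neuron are assumptions of this sketch, not documented library behavior.

```python
class LinearStub:
    """Hypothetical stand-in for nn.Linear: only records layer widths."""

    def __init__(self, in_features: int, out_features: int):
        self.in_features = in_features
        self.out_features = out_features


class LeakyStub:
    """Hypothetical stand-in for a SpikingNeuron with an init_state method."""

    def init_state(self, batch_size: int, features: int) -> dict:
        # Assumed layout: one membrane "tensor" of shape [batch, features],
        # modeled here as nested lists of zeros.
        return {"mem": [[0.0] * features for _ in range(batch_size)]}


def sketch_init_states(model, batch_size: int) -> dict:
    """Walk attributes in definition order, tracking the last Linear's output width."""
    states = {}
    features = None
    # Instance __dict__ preserves insertion order, so fc1 is seen before lif1.
    for name, attr in vars(model).items():
        if isinstance(attr, LinearStub):
            features = attr.out_features
        elif isinstance(attr, LeakyStub):
            if features is None:
                # This sketch's choice; the real function may fall back differently.
                raise ValueError(f"cannot infer feature count for {name!r}")
            states[name] = attr.init_state(batch_size, features)
    return states


class TwoLayerModel:
    def __init__(self):
        self.fc1 = LinearStub(784, 128)
        self.lif1 = LeakyStub()
        self.fc2 = LinearStub(128, 10)
        self.lif2 = LeakyStub()


states = sketch_init_states(TwoLayerModel(), batch_size=32)
mem = states["lif1"]["mem"]
print(len(mem), len(mem[0]))  # 32 128
```

Each neuron layer picks up the width of the Linear layer defined just before it, which is why attribute definition order matters in this sketch.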

reset_states

reset_states(model) -> None

Reset all stateful neuron layers in a model.

Walks through the model's attributes and calls init_state on each SpikingNeuron instance. It is provided mainly as a convenience reminder: in mlx-snn, state is always explicit (passed in and returned), so resetting simply means re-creating initial states with init_states.

Args:
    model: An nn.Module whose spiking neuron layers should be identified.

Note: Because mlx-snn keeps state in explicit dicts rather than hidden instance variables, this function does not change any state you hold; it is effectively a no-op convenience. Use init_states to create fresh state dicts instead.
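A minimal pure-Python sketch makes the explicit-state point concrete. leaky_step and fresh_state are hypothetical helpers. The dynamics (membrane decay by beta, spike when the membrane exceeds threshold, reset by subtraction) are an assumed snnTorch-style leaky model, not a statement of how mlxsnn.Leaky is implemented. The only point is that "resetting" means rebinding to a fresh initial state, which is what init_states provides.

```python
def fresh_state(n: int) -> dict:
    """Initial state for n neurons: membrane potential starts at zero."""
    return {"mem": [0.0] * n}


def leaky_step(x: list, state: dict, beta: float = 0.9, threshold: float = 1.0):
    """One leaky integrate-and-fire update; state goes in and comes back out."""
    mem = [beta * m + xi for m, xi in zip(state["mem"], x)]
    spk = [1.0 if m > threshold else 0.0 for m in mem]
    # Reset by subtraction for neurons that fired (an assumed reset rule).
    mem = [m - threshold * s for m, s in zip(mem, spk)]
    return spk, {"mem": mem}


state = fresh_state(3)
history = []
for _ in range(5):
    spk, state = leaky_step([0.5, 0.1, 0.0], state)
    history.append(spk)

final_state = state
# "Resetting" is just rebinding to a new initial state; no object is mutated.
state = fresh_state(3)
print([s[0] for s in history])  # [0.0, 0.0, 1.0, 0.0, 1.0]
```

Since the step function never stores state on an object, there is nothing inside a model for a reset function to clear.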

Visualization

Optional dependency

Visualization requires matplotlib:

pip install "mlx-snn[viz]"

Spike Raster Plot

plot_raster

plot_raster(spikes, neuron_ids: Optional[list] = None, ax=None, title: str = 'Spike Raster', xlabel: str = 'Timestep', ylabel: str = 'Neuron', marker: str = '|', markersize: float = 2.0, color: str = 'black', show: bool = True)

Plot a spike raster diagram.

Args:
    spikes: Spike tensor of shape [T, neurons] or [T, batch, neurons]. If 3-D, only the first batch element is plotted.
    neuron_ids: Optional list of neuron indices to plot.
    ax: Existing matplotlib Axes to draw on. If None, a new figure is created.
    title: Plot title.
    xlabel: X-axis label.
    ylabel: Y-axis label.
    marker: Marker style for spikes.
    markersize: Size of spike markers.
    color: Color of spike markers.
    show: Whether to call plt.show() after plotting.

Returns: matplotlib Axes object.

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.utils.visualization import plot_raster
    >>> spikes = (mx.random.uniform(shape=(50, 20)) > 0.8).astype(mx.float32)
    >>> ax = plot_raster(spikes, show=False)
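A raster plot marks one point per (timestep, neuron) pair where a spike occurred. The hypothetical raster_events helper below sketches that data preparation in plain Python, including the documented 3-D behavior (keep batch element 0) and the optional neuron_ids filter. It is an illustration, not plot_raster's source; in particular, whether plot_raster draws selected neurons at their original index or on consecutive rows is not stated here, and this sketch keeps the original index.

```python
def raster_events(spikes, neuron_ids=None):
    """Return (timestep, neuron) pairs where a spike occurred.

    spikes: nested lists shaped [T, neurons] or [T, batch, neurons].
    neuron_ids: optional iterable of neuron indices to keep.
    """
    if not spikes:
        return []
    # Documented 3-D behavior: keep only the first batch element.
    if isinstance(spikes[0][0], list):
        spikes = [step[0] for step in spikes]
    ids = range(len(spikes[0])) if neuron_ids is None else list(neuron_ids)
    return [(t, i) for t, step in enumerate(spikes) for i in ids if step[i] > 0]


spikes = [
    [0, 1, 0],  # t = 0
    [1, 0, 0],  # t = 1
    [0, 1, 1],  # t = 2
]
print(raster_events(spikes))                  # [(0, 1), (1, 0), (2, 1), (2, 2)]
print(raster_events(spikes, neuron_ids=[1]))  # [(0, 1), (2, 1)]
```

Pairs like these are the kind of x/y data one could pass to matplotlib's Axes.scatter with marker='|'.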

Membrane Potential Traces

plot_membrane

plot_membrane(mem, neuron_ids: Optional[list] = None, threshold: Optional[float] = 1.0, ax=None, title: str = 'Membrane Potential', xlabel: str = 'Timestep', ylabel: str = 'Membrane Potential', show: bool = True)

Plot membrane potential traces over time.

Args:
    mem: Membrane potential of shape [T, neurons] or [T, batch, neurons]. If 3-D, only the first batch element is plotted.
    neuron_ids: Optional list of neuron indices to plot. If None, all neurons are plotted, up to a maximum of 10.
    threshold: If not None, a horizontal dashed line is drawn at the threshold value.
    ax: Existing matplotlib Axes to draw on.
    title: Plot title.
    xlabel: X-axis label.
    ylabel: Y-axis label.
    show: Whether to call plt.show().

Returns: matplotlib Axes object.

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.utils.visualization import plot_membrane
    >>> mem = mx.random.uniform(shape=(50, 5))
    >>> ax = plot_membrane(mem, show=False)
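The trace-selection rules can be sketched the same way. select_traces and threshold_crossings are hypothetical helpers, not part of mlx-snn. The sketch assumes the "up to 10" default means the first ten neuron indices; threshold_crossings simply finds the timesteps where a trace rises to the dashed threshold line, which is one way to read that line on the plot.

```python
def select_traces(mem, neuron_ids=None, max_default=10):
    """Return {neuron_id: trace} for the neurons a membrane plot would draw.

    mem: nested lists shaped [T, neurons] or [T, batch, neurons].
    """
    # Documented 3-D behavior: keep only the first batch element.
    if isinstance(mem[0][0], list):
        mem = [step[0] for step in mem]
    n = len(mem[0])
    # Assumed default: the first max_default neuron indices.
    ids = range(min(n, max_default)) if neuron_ids is None else neuron_ids
    return {i: [step[i] for step in mem] for i in ids}


def threshold_crossings(trace, threshold=1.0):
    """Timesteps where the trace rises from below the threshold to at or above it."""
    return [t for t in range(1, len(trace)) if trace[t - 1] < threshold <= trace[t]]


mem = [[0.1 * (t + i) for i in range(12)] for t in range(4)]  # T=4, 12 neurons
traces = select_traces(mem)
print(sorted(traces))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(threshold_crossings([0.2, 0.8, 1.1, 0.3, 1.5]))  # [2, 4]
```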

Firing Rate Bar Chart

plot_firing_rate

plot_firing_rate(spk_out, ax=None, title: str = 'Firing Rate', xlabel: str = 'Neuron', ylabel: str = 'Firing Rate', color: str = 'steelblue', show: bool = True)

Plot per-neuron firing rate as a bar chart.

Args:
    spk_out: Spike tensor of shape [T, neurons] or [T, batch, neurons]. If 3-D, rates are averaged over the batch dimension.
    ax: Existing matplotlib Axes to draw on.
    title: Plot title.
    xlabel: X-axis label.
    ylabel: Y-axis label.
    color: Bar color.
    show: Whether to call plt.show().

Returns: matplotlib Axes object.

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.utils.visualization import plot_firing_rate
    >>> spk = (mx.random.uniform(shape=(100, 20)) > 0.7).astype(mx.float32)
    >>> ax = plot_firing_rate(spk, show=False)
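The per-neuron value behind each bar can be sketched as a mean over timesteps, and also over the batch for 3-D input. The firing_rate helper is hypothetical and assumes the rate is expressed as spikes per timestep (a value in [0, 1]) rather than in Hz; the mlx-snn implementation may normalize differently.

```python
def firing_rate(spk_out):
    """Per-neuron rate as spikes per timestep, averaged over batch if 3-D.

    spk_out: nested lists shaped [T, neurons] or [T, batch, neurons].
    """
    steps = len(spk_out)
    if isinstance(spk_out[0][0], list):
        batch = len(spk_out[0])
        n = len(spk_out[0][0])
        # Sum every spike for neuron i across all timesteps and batch elements.
        totals = [sum(step[b][i] for step in spk_out for b in range(batch)) for i in range(n)]
        return [total / (steps * batch) for total in totals]
    n = len(spk_out[0])
    return [sum(step[i] for step in spk_out) / steps for i in range(n)]


spk = [[1, 0], [1, 1], [0, 0], [1, 0]]  # T=4, 2 neurons
print(firing_rate(spk))  # [0.75, 0.25]
```

Dividing by T (and by the batch size for 3-D input) keeps rates comparable across runs with different sequence lengths.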