Layers

Composite spiking layers that combine standard operations (convolution, pooling) with spiking neuron dynamics.

SpikingConv2d

Convolutional layer with integrated LIF neuron — performs Conv2d → LIF in a single module.

SpikingConv2d

Bases: Module

Conv2d + optional BatchNorm + spiking neuron.

Wraps mx.nn.Conv2d with a spiking neuron so that each forward call processes a single timestep. The neuron state must be carried across timesteps by the caller.

MLX Conv2d uses NHWC layout: input (B, H, W, C_in) → output (B, H', W', C_out).

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels (filters).
    kernel_size: Convolution kernel size (int or tuple).
    stride: Convolution stride (default 1).
    padding: Convolution padding (default 0).
    bias: If True, add a learnable bias to the convolution.
    bn: If True, apply batch normalization before the neuron.
    neuron: Neuron type string ('leaky', 'if', 'synaptic') or a pre-constructed SpikingNeuron instance.
    neuron_params: Dict of keyword arguments forwarded to the neuron constructor (e.g. {'beta': 0.95, 'threshold': 0.5}).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.layers import SpikingConv2d
    >>> layer = SpikingConv2d(3, 16, kernel_size=3, padding=1)
    >>> state = layer.init_state(batch_size=4, spatial_shape=(8, 8))
    >>> x = mx.ones((4, 8, 8, 3))
    >>> spk, state = layer(x, state)
    >>> spk.shape
    [4, 8, 8, 16]

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state for spatial feature maps.

Args:
    batch_size: Number of samples in the batch.
    spatial_shape: (H', W') spatial dimensions after the convolution (accounting for stride and padding).

Returns:
    State dict whose tensors have shape (B, H', W', C_out) in NHWC format.
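The (H', W') values passed to init_state follow the standard convolution output-size formula. A small sketch of the arithmetic (conv_out is an illustrative helper, not part of the mlxsnn API):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Standard convolution output-size formula: floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

# Matches the SpikingConv2d example: 8x8 input, kernel 3, padding 1, stride 1
print(conv_out(8, 3, stride=1, padding=1))   # 8 -- spatial size preserved
print(conv_out(8, 3, stride=2, padding=1))   # 4
```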

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input tensor (B, H, W, C_in) in NHWC format.
    state: Neuron state dict from the previous timestep.

Returns:
    Tuple of (spikes, new_state) where spikes has shape (B, H', W', C_out).
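Since each call processes one timestep, the caller loops over time and threads the state through. A NumPy sketch of that pattern, with an illustrative LIF update standing in for the layer (the decay/threshold values and reset-by-subtraction rule are assumptions about the 'leaky' neuron, not the mlxsnn implementation):

```python
import numpy as np

def lif_step(current, mem, beta=0.95, threshold=1.0):
    # One LIF timestep: leaky integration, threshold, reset by subtraction.
    # (Illustrative sketch; parameter values here are arbitrary assumptions.)
    mem = beta * mem + current
    spk = (mem >= threshold).astype(current.dtype)  # binary spike where threshold crossed
    mem = mem - spk * threshold
    return spk, mem

# Carry the membrane state across timesteps, as the caller must with SpikingConv2d.
rng = np.random.default_rng(0)
mem = np.zeros((4, 8, 8, 16))            # state shaped like the NHWC output
for t in range(5):
    current = rng.random((4, 8, 8, 16))  # stand-in for the Conv2d output at step t
    spk, mem = lif_step(current, mem)

print(spk.shape)                          # (4, 8, 8, 16)
```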

SpikingMaxPool2d

Max pooling that preserves spike semantics.

SpikingMaxPool2d

Bases: Module

Max pooling that operates on spike tensors.

Wraps mx.nn.MaxPool2d. Because spike tensors are binary (0 or 1), max pooling preserves a spike if any input in the pooling window fired.

Input and output use NHWC layout: (B, H, W, C).

Args:
    kernel_size: Size of the pooling window (int or tuple).
    stride: Stride of the pooling window. Defaults to kernel_size when None.
    padding: Zero-padding added to both sides (default 0).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.layers import SpikingMaxPool2d
    >>> pool = SpikingMaxPool2d(kernel_size=2, stride=2)
    >>> x = mx.ones((4, 8, 8, 16))
    >>> out = pool(x)
    >>> out.shape
    [4, 4, 4, 16]
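The "spike if any input fired" semantics can be checked with a small NumPy stand-in for the pooling op (maxpool2x2 is an illustrative sketch, not mx.nn.MaxPool2d):

```python
import numpy as np

def maxpool2x2(x):
    # 2x2, stride-2 max pooling over an NHWC tensor (numpy sketch)
    B, H, W, C = x.shape
    return x.reshape(B, H // 2, 2, W // 2, 2, C).max(axis=(2, 4))

spikes = np.zeros((1, 4, 4, 1))
spikes[0, 0, 1, 0] = 1.0            # single spike inside the top-left 2x2 window
out = maxpool2x2(spikes)
print(out[0, 0, 0, 0])              # 1.0 -- the window contained a spike
print(out[0, 1, 1, 0])              # 0.0 -- no spike anywhere in that window
```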

__call__

__call__(x: array) -> mx.array

Apply max pooling.

Args:
    x: Input tensor (B, H, W, C) in NHWC format.

Returns:
    Pooled tensor (B, H', W', C).

SpikingAvgPool2d

Average pooling for spike feature maps.

SpikingAvgPool2d

Bases: Module

Average pooling for spiking networks.

Wraps mx.nn.AvgPool2d. When applied to binary spike tensors the output represents the local firing density within each pooling window.

Input and output use NHWC layout: (B, H, W, C).

Args:
    kernel_size: Size of the pooling window (int or tuple).
    stride: Stride of the pooling window. Defaults to kernel_size when None.
    padding: Zero-padding added to both sides (default 0).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.layers import SpikingAvgPool2d
    >>> pool = SpikingAvgPool2d(kernel_size=2, stride=2)
    >>> x = mx.ones((4, 8, 8, 16))
    >>> out = pool(x)
    >>> out.shape
    [4, 4, 4, 16]
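The firing-density interpretation is easy to see with a NumPy stand-in (avgpool2x2 is an illustrative sketch, not mx.nn.AvgPool2d): for a binary window, the mean equals the fraction of positions that fired.

```python
import numpy as np

def avgpool2x2(x):
    # 2x2, stride-2 average pooling over an NHWC tensor (numpy sketch)
    B, H, W, C = x.shape
    return x.reshape(B, H // 2, 2, W // 2, 2, C).mean(axis=(2, 4))

spikes = np.zeros((1, 2, 2, 1))
spikes[0, 0, 0, 0] = 1.0
spikes[0, 1, 1, 0] = 1.0            # 2 of the 4 positions fired
out = avgpool2x2(spikes)
print(out[0, 0, 0, 0])              # 0.5 -- half the window fired
```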

__call__

__call__(x: array) -> mx.array

Apply average pooling.

Args:
    x: Input tensor (B, H, W, C) in NHWC format.

Returns:
    Pooled tensor (B, H', W', C).

SpikingFlatten

Flatten spatial dimensions for transition from conv to FC layers.

SpikingFlatten

Bases: Module

Flatten spatial dimensions of a spiking feature map.

Reshapes (B, H, W, C) to (B, H*W*C). This is a stateless layer that carries no learnable parameters.

Args:
    start_dim: First dimension to flatten (default 1, preserving the batch dimension).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.layers import SpikingFlatten
    >>> flat = SpikingFlatten()
    >>> x = mx.ones((4, 8, 8, 16))
    >>> out = flat(x)
    >>> out.shape
    [4, 1024]

__call__

__call__(x: array) -> mx.array

Flatten the tensor from start_dim onward.

Args:
    x: Input tensor of arbitrary shape.

Returns:
    Flattened tensor with all dims from start_dim merged.
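With the default start_dim=1, the operation is equivalent to a plain reshape that merges everything after the batch dimension, shown here in NumPy:

```python
import numpy as np

# Flatten (B, H, W, C) -> (B, H*W*C), as SpikingFlatten does with start_dim=1
x = np.ones((4, 8, 8, 16))
out = x.reshape(x.shape[0], -1)
print(out.shape)   # (4, 1024) since 8 * 8 * 16 = 1024
```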

SpikeDropout

Dropout specialized for binary spike trains. Drops spikes with probability p during training (no rescaling, since spikes are binary).

SpikeDropout

Bases: Module

Drops spikes with probability p during training.

During evaluation (model.eval()), all spikes pass through unchanged. Unlike nn.Dropout, no rescaling is applied: spikes are binary, so rescaling would produce non-binary values.

Args:
    p: Probability of dropping each spike. Must be in [0, 1).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.layers import SpikeDropout
    >>> drop = SpikeDropout(p=0.3)
    >>> drop.train()
    >>> x = mx.ones((4, 10))
    >>> out = drop(x)
    >>> out.shape
    [4, 10]
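The drop-without-rescaling behaviour can be sketched in NumPy (spike_dropout here is an illustration of the described semantics, not the mlxsnn code). Standard dropout would divide surviving values by 1 - p, which on a binary input yields non-binary values like 1/0.7; dropping without rescaling keeps the output in {0, 1}.

```python
import numpy as np

def spike_dropout(x, p, rng, training=True):
    # Drop each spike with probability p; no 1/(1-p) rescaling, so the
    # output stays binary. In eval mode, pass spikes through unchanged.
    if not training:
        return x
    keep = (rng.random(x.shape) >= p).astype(x.dtype)
    return x * keep

rng = np.random.default_rng(0)
x = np.ones((4, 10))
out = spike_dropout(x, p=0.3, rng=rng)
print(set(np.unique(out)) <= {0.0, 1.0})   # True -- still binary
```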

__call__

__call__(x: array) -> mx.array

Apply spike dropout.

Args:
    x: Spike tensor of any shape (typically binary {0, 1}).

Returns:
    Tensor with spikes randomly zeroed during training.

Neuron Factory

create_neuron

create_neuron(neuron_type: str | SpikingNeuron, **kwargs) -> SpikingNeuron

Create a spiking neuron from a type string or return an existing instance.

Args:
    neuron_type: A string identifier ('leaky', 'if', 'synaptic', 'alpha', 'alif', 'izhikevich') or an already-constructed SpikingNeuron.
    **kwargs: Keyword arguments forwarded to the neuron constructor when neuron_type is a string.

Returns:
    A SpikingNeuron instance.

Raises:
    ValueError: If neuron_type is an unrecognised string.
    TypeError: If neuron_type is neither a string nor a SpikingNeuron.
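The pass-through-or-dispatch behaviour described above can be sketched as a registry lookup. The classes below are placeholder stand-ins, not the actual mlxsnn neuron classes, and the registry shows only one entry:

```python
# Placeholder classes for illustration only -- not the real mlxsnn neurons.
class SpikingNeuron:
    pass

class LeakyNeuron(SpikingNeuron):
    def __init__(self, beta=0.9, threshold=1.0):
        self.beta, self.threshold = beta, threshold

_REGISTRY = {'leaky': LeakyNeuron}   # the real factory also maps 'if', 'synaptic', ...

def create_neuron(neuron_type, **kwargs):
    if isinstance(neuron_type, SpikingNeuron):
        return neuron_type                         # pass existing instances through
    if isinstance(neuron_type, str):
        try:
            return _REGISTRY[neuron_type](**kwargs)  # construct from the registry
        except KeyError:
            raise ValueError(f"Unknown neuron type: {neuron_type!r}")
    raise TypeError(f"Expected str or SpikingNeuron, got {type(neuron_type)}")

n = create_neuron('leaky', beta=0.95)
print(isinstance(n, SpikingNeuron), n.beta)   # True 0.95
```

Passing an instance returns it unchanged, which is what lets SpikingConv2d accept either a type string or a pre-built neuron.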