Layers¶
Composite spiking layers that combine standard operations (convolution, pooling) with spiking neuron dynamics.
SpikingConv2d¶
Convolutional layer with an integrated spiking neuron (LIF, for example): performs Conv2d → optional BatchNorm → neuron in a single module.
SpikingConv2d ¶
Bases: Module
Conv2d + optional BatchNorm + spiking neuron.
Wraps mx.nn.Conv2d with a spiking neuron so that each forward
call processes a single timestep. The neuron state must be carried
across timesteps by the caller.
MLX Conv2d uses NHWC layout:
input (B, H, W, C_in) -> output (B, H', W', C_out)
Args:
in_channels: Number of input channels.
out_channels: Number of output channels (filters).
kernel_size: Convolution kernel size (int or tuple).
stride: Convolution stride (default 1).
padding: Convolution padding (default 0).
bias: If True, add a learnable bias to the convolution.
bn: If True, apply batch normalization before the neuron.
neuron: Neuron type string ('leaky', 'if', 'synaptic')
or a pre-constructed SpikingNeuron instance.
neuron_params: Dict of keyword arguments forwarded to the neuron
constructor (e.g. {'beta': 0.95, 'threshold': 0.5}).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.layers import SpikingConv2d
>>> layer = SpikingConv2d(3, 16, kernel_size=3, padding=1)
>>> state = layer.init_state(batch_size=4, spatial_shape=(8, 8))
>>> x = mx.ones((4, 8, 8, 3))
>>> spk, state = layer(x, state)
>>> spk.shape
[4, 8, 8, 16]
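The caller-managed state pattern (one forward call per timestep, with the state passed back in on the next call) can be illustrated without MLX. The sketch below is not the library's implementation: `lif_step` and the `"mem"` state key are hypothetical, and the leak-integrate-threshold update with subtractive reset is one assumed LIF variant. Only the `beta` and `threshold` names come from the `neuron_params` example.

```python
def lif_step(x, state, beta=0.95, threshold=0.5):
    """One timestep of an assumed scalar LIF update (illustrative only)."""
    mem = beta * state["mem"] + x            # leak the old potential, add input
    spk = 1.0 if mem > threshold else 0.0    # fire when the threshold is exceeded
    mem -= spk * threshold                   # assumed subtractive reset after a spike
    return spk, {"mem": mem}                 # mirrors the (spikes, new_state) return


# The caller owns the state and threads it through the time loop.
state = {"mem": 0.0}
spikes = []
for x in [0.3, 0.3, 0.3, 0.0]:
    spk, state = lif_step(x, state)
    spikes.append(spk)

print(spikes)  # [0.0, 1.0, 0.0, 0.0]
```

Forgetting to reassign `state` inside the loop would restart the membrane potential from zero on every step, so no input history would accumulate.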
init_state ¶
Initialize neuron state for spatial feature maps.
Args:
batch_size: Number of samples in the batch.
spatial_shape: (H', W') spatial dimensions after
the convolution (accounting for stride and padding).
Returns:
State dict whose tensors have shape
(B, H', W', C_out) in NHWC format.
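Because `spatial_shape` must already account for stride and padding, it can help to compute it up front. The helper below is a hypothetical convenience, not part of the documented API; it applies the standard convolution output-size formula and assumes a dilation of 1 and integer `kernel_size`, `stride`, and `padding`.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial axis (assumes dilation = 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def conv_spatial_shape(hw, kernel_size, stride=1, padding=0):
    """Return (H', W') for an input of spatial size (H, W)."""
    h, w = hw
    return (
        conv_output_size(h, kernel_size, stride, padding),
        conv_output_size(w, kernel_size, stride, padding),
    )


# 3x3 kernel with padding=1 keeps an 8x8 map at 8x8.
same = conv_spatial_shape((8, 8), kernel_size=3, padding=1)
# Adding stride=2 halves it to 4x4.
halved = conv_spatial_shape((8, 8), kernel_size=3, stride=2, padding=1)
```

The resulting tuple is what would be passed as `spatial_shape` to `init_state`.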
__call__ ¶
Forward one timestep.
Args:
x: Input tensor (B, H, W, C_in) in NHWC format.
state: Neuron state dict from the previous timestep.
Returns:
Tuple of (spikes, new_state) where spikes has shape
(B, H', W', C_out).
SpikingMaxPool2d¶
Max pooling that preserves spike semantics.
SpikingMaxPool2d ¶
Bases: Module
Max pooling that operates on spike tensors.
Wraps mx.nn.MaxPool2d. Because spike tensors are binary
(0 or 1), max pooling preserves a spike if any input in the
pooling window fired.
Input and output use NHWC layout: (B, H, W, C).
Args:
kernel_size: Size of the pooling window (int or tuple).
stride: Stride of the pooling window. Defaults to
kernel_size when None.
padding: Zero-padding added to both sides of each spatial dimension (default 0).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.layers import SpikingMaxPool2d
>>> pool = SpikingMaxPool2d(kernel_size=2, stride=2)
>>> x = mx.ones((4, 8, 8, 16))
>>> out = pool(x)
>>> out.shape
[4, 4, 4, 16]
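The claim that max pooling preserves a spike if any input in the window fired amounts to a logical OR over each window. The pure-Python sketch below checks this on a single 4x4 channel; `max_pool_2x2` and `or_pool_2x2` are hypothetical helpers written for illustration and do not use MLX.

```python
def windows_2x2(grid):
    """Yield rows of non-overlapping 2x2 windows from a 2-D list."""
    for r in range(0, len(grid) - 1, 2):
        yield [
            (grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
            for c in range(0, len(grid[0]) - 1, 2)
        ]


def max_pool_2x2(grid):
    """2x2 max pooling with stride 2."""
    return [[max(w) for w in row] for row in windows_2x2(grid)]


def or_pool_2x2(grid):
    """Logical OR over each 2x2 window, returned as 0/1."""
    return [[int(any(w)) for w in row] for row in windows_2x2(grid)]


spikes = [
    [0, 0, 1, 0],
    [0, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
]
pooled = max_pool_2x2(spikes)
print(pooled)  # [[0, 1], [1, 0]]
```

For binary inputs the two functions agree, which is why the pooled output is itself a valid spike tensor.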
__call__ ¶
Apply max pooling.
Args:
x: Input tensor (B, H, W, C) in NHWC format.
Returns:
Pooled tensor (B, H', W', C).
SpikingAvgPool2d¶
Average pooling for spike feature maps.
SpikingAvgPool2d ¶
Bases: Module
Average pooling for spiking networks.
Wraps mx.nn.AvgPool2d. When applied to binary spike tensors
the output represents the local firing density within each pooling
window.
Input and output use NHWC layout: (B, H, W, C).
Args:
kernel_size: Size of the pooling window (int or tuple).
stride: Stride of the pooling window. Defaults to
kernel_size when None.
padding: Zero-padding added to both sides of each spatial dimension (default 0).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.layers import SpikingAvgPool2d
>>> pool = SpikingAvgPool2d(kernel_size=2, stride=2)
>>> x = mx.ones((4, 8, 8, 16))
>>> out = pool(x)
>>> out.shape
[4, 4, 4, 16]
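To make "local firing density" concrete, the following pure-Python sketch averages non-overlapping 2x2 windows of a single binary channel. `avg_pool_2x2` is a hypothetical helper for illustration only and does not call MLX.

```python
def avg_pool_2x2(grid):
    """2x2 average pooling with stride 2 over a 2-D list."""
    out = []
    for r in range(0, len(grid) - 1, 2):
        row = []
        for c in range(0, len(grid[0]) - 1, 2):
            window = (grid[r][c], grid[r][c + 1],
                      grid[r + 1][c], grid[r + 1][c + 1])
            row.append(sum(window) / 4)   # fraction of positions that fired
        out.append(row)
    return out


spikes = [
    [0, 0, 1, 0],
    [0, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
]
density = avg_pool_2x2(spikes)
print(density)  # [[0.0, 0.25], [0.5, 0.0]]
```

Each output value is the fraction of the window that spiked, so the result is generally not binary and is no longer a spike tensor in the strict sense.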
__call__ ¶
Apply average pooling.
Args:
x: Input tensor (B, H, W, C) in NHWC format.
Returns:
Pooled tensor (B, H', W', C).
SpikingFlatten¶
Flattens spatial dimensions for the transition from convolutional to fully connected layers.
SpikingFlatten ¶
Bases: Module
Flatten spatial dimensions of a spiking feature map.
Reshapes (B, H, W, C) to (B, H*W*C). This is a stateless
layer that carries no learnable parameters.
Args:
start_dim: First dimension to flatten (default 1,
preserving the batch dimension).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.layers import SpikingFlatten
>>> flat = SpikingFlatten()
>>> x = mx.ones((4, 8, 8, 16))
>>> out = flat(x)
>>> out.shape
[4, 1024]
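The flattened width is simply the product of the merged dimensions, which is also the input size the next fully connected layer needs. A small shape-only sketch, where `flattened_shape` is a hypothetical helper and not part of the library:

```python
from math import prod


def flattened_shape(shape, start_dim=1):
    """Shape after merging every dimension from start_dim onward."""
    return tuple(shape[:start_dim]) + (prod(shape[start_dim:]),)


# (B, H, W, C) = (4, 8, 8, 16) -> (4, 8 * 8 * 16)
default = flattened_shape((4, 8, 8, 16))
# Keeping the first two dimensions merges only W and C.
partial = flattened_shape((4, 8, 8, 16), start_dim=2)
```

Here `default` is `(4, 1024)`, matching the documented example, and `partial` is `(4, 8, 128)`.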
__call__ ¶
Flatten the tensor from start_dim onward.
Args:
x: Input tensor of arbitrary shape.
Returns:
Flattened tensor with all dims from start_dim merged.
SpikeDropout¶
Dropout specialized for binary spike trains. Drops spikes with probability p during training (no rescaling, since spikes are binary).
SpikeDropout ¶
Bases: Module
Drops spikes with probability p during training.
During evaluation (model.eval()), all spikes pass through
unchanged. Unlike nn.Dropout, no rescaling is applied:
spikes are binary, so rescaling would produce non-binary values.
Args:
p: Probability of dropping each spike. Must be in [0, 1).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.layers import SpikeDropout
>>> drop = SpikeDropout(p=0.3)
>>> drop.train()
>>> x = mx.ones((4, 10))
>>> out = drop(x)
>>> out.shape
[4, 10]
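The difference from conventional (inverted) dropout can be seen in a few lines of plain Python. This is a sketch of the idea under stated assumptions, not the library's code: `spike_dropout` and `scaled_dropout` are hypothetical names, and the random source is Python's `random` module rather than MLX.

```python
import random


def spike_dropout(spikes, p, training=True, rng=random):
    """Zero each spike with probability p; survivors are left unscaled."""
    if not training:
        return list(spikes)              # evaluation: pass through unchanged
    return [s if rng.random() >= p else 0 for s in spikes]


def scaled_dropout(values, p, rng=random):
    """Conventional inverted dropout: survivors are scaled by 1 / (1 - p)."""
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]


spikes = [1] * 1000
dropped = spike_dropout(spikes, p=0.3, rng=random.Random(0))
scaled = scaled_dropout(spikes, p=0.3, rng=random.Random(0))

# Unscaled output stays in {0, 1}; the scaled variant yields values near 1.43.
still_binary = set(dropped) <= {0, 1}
```

The trade-off is that, without rescaling, the expected spike count during training is lower than during evaluation by a factor of `1 - p`.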
__call__ ¶
Apply spike dropout.
Args:
x: Spike tensor of any shape (typically binary {0, 1}).
Returns:
Tensor with spikes randomly zeroed during training.
Neuron Factory¶
create_neuron ¶
Create a spiking neuron from a type string or return an existing instance.
Args:
neuron_type: A string identifier ('leaky', 'if',
'synaptic', 'alpha', 'alif', 'izhikevich')
or an already-constructed SpikingNeuron.
**kwargs: Keyword arguments forwarded to the neuron constructor
when neuron_type is a string.
Returns:
A SpikingNeuron instance.
Raises:
ValueError: If neuron_type is an unrecognised string.
TypeError: If neuron_type is neither a string nor a
SpikingNeuron.
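The documented contract (strings are looked up and constructed, instances are returned as-is, anything else is rejected) can be sketched as a small registry-based factory. Everything below is a stand-in written for illustration: the `SpikingNeuron`, `Leaky`, and `IF` classes, the `_REGISTRY` dict, and `create_neuron_sketch` are hypothetical and do not reflect how the library is actually implemented.

```python
class SpikingNeuron:
    """Stand-in base class for this sketch only."""

    def __init__(self, **params):
        self.params = params


class Leaky(SpikingNeuron):
    pass


class IF(SpikingNeuron):
    pass


# Hypothetical mapping from type strings to constructors.
_REGISTRY = {"leaky": Leaky, "if": IF}


def create_neuron_sketch(neuron_type, **kwargs):
    """Mimic the documented behaviour of a string-or-instance factory."""
    if isinstance(neuron_type, SpikingNeuron):
        return neuron_type                      # already constructed: return as-is
    if isinstance(neuron_type, str):
        if neuron_type not in _REGISTRY:
            raise ValueError(f"Unknown neuron type: {neuron_type!r}")
        return _REGISTRY[neuron_type](**kwargs)  # kwargs go to the constructor
    raise TypeError("neuron_type must be a str or a SpikingNeuron instance")


neuron = create_neuron_sketch("leaky", beta=0.95, threshold=0.5)
same = create_neuron_sketch(neuron)
```

Checking for an existing instance first lets layers accept either form through a single `neuron` argument.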