Operators¶
Mathematically principled operators that optimize convolutional SNN processing. These operators exploit temporal structure to reduce computation while preserving gradient flow.
Temporal Aggregated Convolution (TAC)¶
Exploits the linearity of convolution to aggregate K timesteps into a single conv call, achieving near-linear speedup.
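The aggregation is exact at the conv level because convolution is linear: convolving a weighted sum of K frames once equals K separate convolutions followed by the same weighted sum. A minimal pure-Python 1-D sketch of that identity (hypothetical kernel and frame values; conv layers apply the same identity per spatial position and channel):

```python
# Linearity check: one conv on the aggregated input vs. K convs summed afterwards.

def conv1d(x, k):
    """Valid-mode 1-D correlation, as convolution layers compute it."""
    n = len(x) - len(k) + 1
    return [sum(x[i + j] * k[j] for j in range(len(k))) for i in range(n)]

beta, K = 0.9, 4
frames = [[1.0, 2.0, 3.0, 4.0, 5.0],
          [0.0, 1.0, 0.0, 1.0, 0.0],
          [2.0, 0.0, 2.0, 0.0, 2.0],
          [1.0, 1.0, 1.0, 1.0, 1.0]]
kernel = [0.5, -0.25, 0.125]
weights = [beta ** (K - 1 - k) for k in range(K)]   # exponential decay weighting

# K conv calls, weighted sum afterwards
per_step = [conv1d(f, kernel) for f in frames]
slow = [sum(w * y[i] for w, y in zip(weights, per_step))
        for i in range(len(per_step[0]))]

# aggregate first, then a single conv call
agg = [sum(w * f[i] for w, f in zip(weights, frames))
       for i in range(len(frames[0]))]
fast = conv1d(agg, kernel)

assert all(abs(a - b) < 1e-9 for a, b in zip(slow, fast))
```

The approximation in TAC comes from the neuron dynamics (resets inside a chunk), not from the convolution itself.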
TemporalAggregatedConv ¶
Bases: Module
Temporal Aggregated Convolution layer.
Reduces the number of convolution calls by aggregating K consecutive input timesteps with exponential weighting, then applying a single spatial conv per chunk. Output has T/K timesteps.
When K=1, this is equivalent to standard Conv+LIF. When K=T, this collapses the entire sequence into a single conv call (exact for output layers with reset_mechanism='none').
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call. If None, defaults to 1 (standard per-timestep conv).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = TemporalAggregatedConv(2, 16, 3, beta=0.9, chunk_size=4, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # (T/K, B, H', W', C_out)
    [4, 2, 8, 8, 16]
init_state ¶
Initialize neuron state.
Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns:
    State dict with 'mem' of shape (B, H', W', C_out).
__call__ ¶
Forward pass over full temporal sequence.
Aggregates every K timesteps and applies a single conv per chunk.
Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict from previous call.

Returns:
    Tuple of (spk_seq, new_state) where spk_seq has shape (T/K, B, H', W', C_out). If T is not divisible by K, the remaining timesteps form a smaller final chunk.
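The chunk planning for a non-divisible T can be sketched in a few lines (illustrative values; the layer does the equivalent internally):

```python
# Split T timesteps into chunks of K; the trailing timesteps form a smaller
# final chunk, and the output has one timestep per chunk.
T, K = 10, 4
chunks = [list(range(t, min(t + K, T))) for t in range(0, T, K)]
lengths = [len(c) for c in chunks]   # [4, 4, 2]
out_timesteps = len(chunks)          # ceil(T / K) == 3
```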
theoretical_error_bound ¶
Compute theoretical per-neuron error bound.
Args:
    firing_rate: Average firing rate rho in [0, 1].

Returns:
    Upper bound on |U_exact - U_TAC| per neuron.
TAC with Temporal Preservation (TAC-TP)¶
TAC variant that preserves the full temporal dimension in the output.
TACTemporalPreserve ¶
Bases: Module
TAC with Temporal Preservation.
Aggregates K input frames for a single conv call but preserves all T output timesteps by running LIF dynamics K times per chunk with the shared conv output.
When K=1, this is exactly standard Conv+LIF.
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = TACTemporalPreserve(2, 16, 3, beta=0.9, chunk_size=4, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # (T, B, H', W', C_out) — T preserved!
    [16, 2, 8, 8, 16]
__call__ ¶
Forward pass preserving full temporal dimension.
Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict.

Returns:
    Tuple of (spk_seq, new_state) where spk_seq has shape (T, B, H', W', C_out), the same T as the input.
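The per-chunk dynamics can be sketched for a single neuron: a shared conv activation (here a hypothetical scalar `a`) is computed once per chunk, then reused for K LIF updates so every timestep still emits a spike decision. A minimal pure-Python sketch with 'subtract' reset:

```python
# TAC-TP per-chunk sketch: one conv result, K membrane updates.
beta, thr, K = 0.9, 1.0, 4
a = 0.6                          # hypothetical shared conv output for this chunk
mem, spikes = 0.0, []
for _ in range(K):               # K LIF steps per single conv call
    mem = beta * mem + a
    s = 1.0 if mem >= thr else 0.0
    mem -= s * thr               # subtract reset
    spikes.append(s)
assert len(spikes) == K          # temporal dimension preserved
```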
Learnable TAC (L-TAC)¶
TAC with learnable aggregation weights instead of uniform averaging.
LearnableTAC ¶
Bases: Module
Learnable Temporal Aggregated Convolution.
Like standard TAC but with learnable aggregation weights per layer, initialized to exponential decay beta^{K-1-k}.
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call.
    preserve_temporal: If True, use TAC-TP mode (output T timesteps). If False, use standard TAC mode (output T/K timesteps).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = LearnableTAC(2, 16, 3, beta=0.9, chunk_size=4,
    ...                      preserve_temporal=True, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # preserve_temporal=True → T preserved
    [16, 2, 8, 8, 16]
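The stated initialization beta^{K-1-k} is easy to compute directly; a small sketch of the initial weight vector (values are approximate up to float rounding):

```python
# Initial aggregation weights w_k = beta^(K-1-k): the newest timestep in each
# chunk starts at weight 1.0, older timesteps are decayed.
beta, K = 0.9, 4
w_init = [beta ** (K - 1 - k) for k in range(K)]
# w_init is approximately [0.729, 0.81, 0.9, 1.0]
```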
__call__ ¶
Forward pass.
If preserve_temporal=True (TAC-TP mode): output has same T as input. If preserve_temporal=False (standard TAC mode): output has T/K timesteps.
Fourier Temporal Convolution (FTC)¶
Learnable biquad IIR filters per channel for temporal processing.
FourierTemporalConv ¶
Bases: Module
Fourier Temporal Convolution with learnable biquad IIR filters.
Applies spatial Conv2d followed by per-channel second-order IIR temporal filtering and spiking nonlinearity. Each output channel has its own learnable temporal transfer function (4 parameters: r, theta, b0, b1).
When filter_order=1, reduces to standard Conv+LIF with learnable beta.
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    filter_order: 1 for standard LIF, 2 for biquad IIR.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = FourierTemporalConv(2, 16, 3, filter_order=2, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((10, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape
    [10, 2, 8, 8, 16]
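A plausible reading of the (r, theta, b0, b1) parameterization (an assumption about the internals, not confirmed by these docs) is a conjugate pole pair at r·exp(±iθ), giving feedback coefficients a1 = 2r·cos(θ) and a2 = −r². A scalar impulse-response sketch of that recurrence, with the same three state slots the docs mention:

```python
import math

# Hypothetical biquad recurrence: y[t] = a1*y[t-1] + a2*y[t-2] + b0*x[t] + b1*x[t-1]
r, theta, b0, b1 = 0.9, 0.3, 1.0, 0.0
a1, a2 = 2 * r * math.cos(theta), -r * r

x = [1.0] + [0.0] * 9                    # impulse input
y = []
for t in range(len(x)):
    y_1 = y[t - 1] if t >= 1 else 0.0    # 'mem' state
    y_2 = y[t - 2] if t >= 2 else 0.0    # 'mem_prev' state
    x_1 = x[t - 1] if t >= 1 else 0.0    # 'input_prev' state
    y.append(a1 * y_1 + a2 * y_2 + b0 * x[t] + b1 * x_1)
# with r < 1 the poles lie inside the unit circle, so the response decays
```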
init_state ¶
Initialize FTC state.
Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns:
    State dict with 'mem' and 'mem_prev' (for second-order), and 'input_prev' for feedforward delay.
__call__ ¶
Forward pass over full temporal sequence.
Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: State dict from previous call.

Returns:
    Tuple of (spk_seq, new_state), both with T timesteps.
get_frequency_response ¶
Compute the magnitude frequency response for each channel.
Args:
    num_points: Number of frequency points.

Returns:
    (freqs, magnitudes) where freqs is (num_points,) in [0, pi] and magnitudes is (C_out, num_points).
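For one channel, the magnitude response can be sketched by evaluating an assumed transfer function H(z) = (b0 + b1·z⁻¹) / (1 − a1·z⁻¹ − a2·z⁻²) on the unit circle (the coefficient mapping from (r, theta) is the same assumption as above):

```python
import cmath, math

r, theta, b0, b1 = 0.9, 0.3, 1.0, 0.0
a1, a2 = 2 * r * math.cos(theta), -r * r

num_points = 64
freqs = [math.pi * i / (num_points - 1) for i in range(num_points)]
mags = []
for w in freqs:
    z1 = cmath.exp(-1j * w)                          # z^-1 on the unit circle
    h = (b0 + b1 * z1) / (1 - a1 * z1 - a2 * z1 * z1)
    mags.append(abs(h))

peak = freqs[mags.index(max(mags))]                  # resonance near the pole angle
```

For r close to 1, the peak sits near w = theta, which is what makes the parameterization interpretable as a learnable temporal band.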
Information-Max Spike Convolution (IMC)¶
Information-theoretic channel gating that selectively activates channels based on input information content.
InfoMaxSpikeConv ¶
Bases: Module
Information-Maximizing Spike Convolution layer.
Standard Conv2d + spiking neuron with per-channel learnable gates and information-theoretic regularization. The gates alpha_c modulate output spikes per channel, enabling differentiable channel pruning during training.
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor for the LIF neuron.
    info_reg_weight: Lambda for information bottleneck regularization.
    gate_reg_weight: L1 regularization weight on channel gates.
    ema_decay: Decay factor for running firing rate estimation.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = InfoMaxSpikeConv(2, 16, 3, info_reg_weight=0.01, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x = mx.ones((2, 8, 8, 2))
    >>> spk, state = layer(x, state)
    >>> loss = layer.info_loss()
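The gating mechanism itself is simple to sketch per channel. The sigmoid squashing and the specific logits below are illustrative assumptions; the docs only state that gates alpha_c multiplicatively modulate spikes and carry an L1 penalty:

```python
import math

# Hypothetical per-channel gates: logits squashed through a sigmoid into (0, 1)
alpha = [2.0, 0.0, -2.0, -6.0]                   # learnable logits (assumed form)
gates = [1.0 / (1.0 + math.exp(-a)) for a in alpha]

spikes = [1.0, 1.0, 1.0, 1.0]                    # one spike per channel this step
gated = [s * g for s, g in zip(spikes, gates)]   # alpha_c modulates channel c

l1_penalty = sum(gates)                          # pushes unused gates toward zero
effective = sum(1 for g in gates if g > 0.5)     # channels above the threshold
```

Because the gate is applied multiplicatively before the loss, gradients flow through it, which is what makes the channel pruning differentiable during training.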
init_state ¶
Initialize neuron state.
Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns:
    State dict with 'mem' of shape (B, H', W', C_out).
__call__ ¶
Forward one timestep with channel gating.
Args:
    x: Input tensor (B, H, W, C_in).
    state: Neuron state dict.

Returns:
    Tuple of (gated_spikes, new_state).
info_loss ¶
Information bottleneck regularization loss.
Penalizes the layer if output information capacity is below the input information content. Also applies L1 regularization on channel gates to encourage sparsity.
Returns: Scalar loss value.
effective_channels ¶
Count channels with gate > threshold.
Args:
    threshold: Gate threshold for considering a channel active.

Returns:
    Number of active channels.
channel_utilization ¶
Get channel utilization statistics.
Returns: Dict with gate stats, firing rates, and entropy values.
minimum_channels staticmethod ¶
Compute minimum channels for information preservation.
Args:
    c_in: Input channels.
    rho_in: Input firing rate.
    rho_out: Target output firing rate.
    spatial_reduction: (H_in * W_in) / (H_out * W_out).

Returns:
    Minimum number of output channels.
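The docs do not state the formula, but one capacity-style argument consistent with the arguments above goes: if each spiking unit carries at most the binary entropy H(rho) bits per timestep, preserving input information requires C_out · H(rho_out) ≥ C_in · H(rho_in) · spatial_reduction. A sketch of that reasoning (the inputs and the bound itself are assumptions, not the library's actual computation):

```python
import math

def h2(p):
    """Binary entropy in bits: information per timestep of a unit firing at rate p."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

# Hypothetical inputs: 64 input channels at 5% firing, 4x spatial downsampling,
# target 10% output firing rate.
c_in, rho_in, rho_out, spatial_reduction = 64, 0.05, 0.1, 4.0
c_min = math.ceil(c_in * h2(rho_in) * spatial_reduction / h2(rho_out))
```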
Temporal Collapse Convolution (TCC)¶
Sparsity-aware operator that collapses consecutive silent timesteps.
TemporalCollapseConv ¶
Bases: Module
Temporal Collapse Convolution layer.
Processes a full temporal sequence, collapsing consecutive no-spike (or low-spike) timesteps into single conv calls. Uses cached spike density from the previous forward pass to plan the collapse.
Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor.
    collapse_threshold: Maximum spike density to consider a timestep as "no-spike" for collapse. Default 0.02 (2% of neurons).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.
Examples:
    >>> import mlx.core as mx
    >>> layer = TemporalCollapseConv(2, 16, 3, beta=0.9, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape
    [16, 2, 8, 8, 16]
init_state ¶
Initialize neuron state.
Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns:
    State dict with 'mem' of shape (B, H', W', C_out).
__call__ ¶
Forward pass with temporal collapse.
Collapses consecutive low-activity timesteps into single conv calls. Always produces T output timesteps (collapsed timesteps produce a single spike output replicated across merged steps).
Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict.

Returns:
    Tuple of (spk_seq, new_state) with shape (T, B, H', W', C_out).
collapse_stats ¶
Get statistics about the last forward pass collapse.
Returns: Dict with conv_calls, theoretical_calls (T), and speedup.
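The collapse planning and the resulting stats can be sketched with a hypothetical density trace (illustrative values; the layer uses the cached densities from its previous forward pass):

```python
# Merge runs of low-density timesteps into single segments; active timesteps
# still get one conv call each.
collapse_threshold = 0.02
density = [0.0, 0.0, 0.0, 0.05, 0.0, 0.0, 0.10, 0.0]  # per-timestep spike density
active = [d > collapse_threshold for d in density]

segments, start = [], 0
for t in range(1, len(active) + 1):
    if t == len(active) or active[t] != active[start]:
        segments.append((start, t))        # timesteps [start, t) share one status
        start = t

conv_calls = 0
for (s, t) in segments:
    conv_calls += (t - s) if active[s] else 1   # silent runs collapse to one call

speedup = len(density) / conv_calls             # vs. one conv per timestep
```

Here three silent runs plus two active timesteps give 5 conv calls instead of the theoretical T = 8, a 1.6x speedup.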