
Operators

Mathematically principled operators that optimize convolutional spiking neural network (SNN) processing. These operators exploit temporal structure to reduce computation while preserving gradient flow.

Temporal Aggregated Convolution (TAC)

Exploits the linearity of convolution to aggregate K timesteps into a single conv call, achieving near-linear speedup in the chunk size K.

TemporalAggregatedConv

Bases: Module

Temporal Aggregated Convolution layer.

Reduces the number of convolution calls by aggregating K consecutive input timesteps with exponential weighting, then applying a single spatial conv per chunk. Output has T/K timesteps.

When K=1, this is equivalent to standard Conv+LIF. When K=T, this collapses the entire sequence into a single conv call (exact for output layers with reset_mechanism='none').
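The key identity TAC relies on is that convolution is linear, so K per-timestep conv calls on exponentially weighted inputs equal one conv call on the weighted sum. A minimal 1-D numpy sketch (the layer itself uses spatial Conv2d; names here are illustrative):

```python
import numpy as np

# Linearity of convolution: conv(sum_k w_k * x_k) == sum_k w_k * conv(x_k).
# TAC exploits this to replace K conv calls with a single call on the
# exponentially weighted aggregate.
rng = np.random.default_rng(0)
kernel = rng.normal(size=5)
K, beta = 4, 0.9
xs = [rng.normal(size=32) for _ in range(K)]       # K timesteps of input
weights = [beta ** (K - 1 - k) for k in range(K)]  # exponential weighting

# Baseline: K separate conv calls, weighted and summed afterwards.
per_step = sum(w * np.convolve(x, kernel, mode="same")
               for w, x in zip(weights, xs))

# TAC: aggregate first, then one conv call.
aggregated = np.convolve(sum(w * x for w, x in zip(weights, xs)),
                         kernel, mode="same")

assert np.allclose(per_step, aggregated)
```

The two paths agree to floating-point precision, which is why the K=T case is exact for output layers that never reset the membrane.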

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call. If None, defaults to 1 (standard per-timestep conv).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = TemporalAggregatedConv(2, 16, 3, beta=0.9, chunk_size=4, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # (T/K, B, H', W', C_out)
    [4, 2, 8, 8, 16]

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state.

Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns: State dict with 'mem' of shape (B, H', W', C_out).

__call__

__call__(x_seq: mx.array, state: dict) -> tuple[mx.array, dict]

Forward pass over full temporal sequence.

Aggregates every K timesteps and applies a single conv per chunk.

Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict from previous call.

Returns: Tuple of (spk_seq, new_state) where spk_seq has shape (T/K, B, H', W', C_out). If T is not divisible by K, the remaining timesteps form a smaller final chunk.

theoretical_error_bound

theoretical_error_bound(firing_rate: float) -> float

Compute theoretical per-neuron error bound.

Args: firing_rate: Average firing rate rho in [0, 1].

Returns: Upper bound on |U_exact - U_TAC| per neuron.

TAC with Temporal Preservation (TAC-TP)

TAC variant that preserves the full temporal dimension in the output.

TACTemporalPreserve

Bases: Module

TAC with Temporal Preservation.

Aggregates K input frames for a single conv call but preserves all T output timesteps by running LIF dynamics K times per chunk with the shared conv output.

When K=1, this is exactly standard Conv+LIF.
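The idea of sharing one conv output across K LIF updates can be sketched in numpy. The LIF update shown (leaky integration, threshold, subtract-reset) is an assumption matching the layer's defaults; `lif_chunk` and its arguments are illustrative, not the library's API:

```python
import numpy as np

# TAC-TP sketch: one conv output I_chunk is shared across the K LIF
# updates of a chunk, so all T output timesteps are preserved while
# paying for only one conv call per chunk.
def lif_chunk(I_chunk, mem, beta=0.9, threshold=1.0, K=4):
    spikes = []
    for _ in range(K):                        # K updates per single conv call
        mem = beta * mem + I_chunk            # leaky integration of shared input
        spk = (mem >= threshold).astype(I_chunk.dtype)
        mem = mem - spk * threshold           # reset_mechanism='subtract'
        spikes.append(spk)
    return np.stack(spikes), mem

I = np.full((2, 8, 8, 16), 0.3)               # shared conv output for one chunk
spk, mem = lif_chunk(I, np.zeros_like(I))
print(spk.shape)                              # (4, 2, 8, 8, 16): K timesteps out
```

With this constant drive, the membrane crosses threshold only on the fourth update, so the per-timestep spike pattern is recovered even though the conv ran once.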

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = TACTemporalPreserve(2, 16, 3, beta=0.9, chunk_size=4, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # (T, B, H', W', C_out), T preserved
    [16, 2, 8, 8, 16]

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state.

__call__

__call__(x_seq: mx.array, state: dict) -> tuple[mx.array, dict]

Forward pass preserving full temporal dimension.

Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict.

Returns: Tuple of (spk_seq, new_state) where spk_seq has shape (T, B, H', W', C_out) — same T as input.

conv_calls

conv_calls(T: int) -> int

Number of conv calls for a sequence of length T.
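A presumed implementation of this count, consistent with one conv call per chunk of K timesteps and a smaller final chunk when T is not divisible by K:

```python
import math

# Presumed behavior of conv_calls: one call per chunk of chunk_size K
# timesteps; a partial final chunk still costs one call.
def conv_calls(T, K):
    return math.ceil(T / K)

print(conv_calls(16, 4))  # 4
print(conv_calls(10, 4))  # 3 (two full chunks + one partial chunk)
```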

Learnable TAC (L-TAC)

TAC with learnable aggregation weights instead of uniform averaging.

LearnableTAC

Bases: Module

Learnable Temporal Aggregated Convolution.

Like standard TAC but with learnable aggregation weights per layer, initialized to exponential decay beta^{K-1-k}.
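Since agg_weights is the softmax of learnable logits, the exponential-decay initialization can be realized by setting the logits to the log of the normalized decay profile. The log-then-softmax construction below is an assumption about how the init works; the names are illustrative:

```python
import numpy as np

# Sketch of the L-TAC weight init: logits chosen so that
# softmax(logits) reproduces the normalized profile beta**(K-1-k).
K, beta = 4, 0.9
target = np.array([beta ** (K - 1 - k) for k in range(K)])
target /= target.sum()                 # normalized exponential-decay profile
logits = np.log(target)                # softmax(log(p)) == p for p > 0

softmax = np.exp(logits) / np.exp(logits).sum()
assert np.allclose(softmax, target)    # weights start at exponential decay
```

Training then moves the logits away from this profile, so the layer can learn a different temporal weighting than the fixed TAC one.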

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor (0 < beta < 1).
    chunk_size: Number of timesteps K to aggregate per conv call.
    preserve_temporal: If True, use TAC-TP mode (output T timesteps). If False, use standard TAC mode (output T/K timesteps).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = LearnableTAC(2, 16, 3, beta=0.9, chunk_size=4,
    ...                      preserve_temporal=True, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape  # preserve_temporal=True → T preserved
    [16, 2, 8, 8, 16]

agg_weights property

agg_weights: mx.array

Current aggregation weights (softmax of logits).

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state.

__call__

__call__(x_seq: mx.array, state: dict) -> tuple[mx.array, dict]

Forward pass.

If preserve_temporal=True (TAC-TP mode): output has same T as input. If preserve_temporal=False (standard TAC mode): output has T/K timesteps.

conv_calls

conv_calls(T: int) -> int

Number of conv calls for a sequence of length T.

Fourier Temporal Convolution (FTC)

Learnable biquad IIR filters per channel for temporal processing.

FourierTemporalConv

Bases: Module

Fourier Temporal Convolution with learnable biquad IIR filters.

Applies spatial Conv2d followed by per-channel second-order IIR temporal filtering and spiking nonlinearity. Each output channel has its own learnable temporal transfer function (4 parameters: r, theta, b0, b1).

When filter_order=1, reduces to standard Conv+LIF with learnable beta.
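Given the parameters (r, theta, b0, b1), a natural reading is a biquad with complex-conjugate poles at r·exp(±jθ). The recurrence below is an assumption consistent with that parameterization, not the library's exact update:

```python
import numpy as np

# Assumed per-channel biquad recurrence with poles at r*exp(+-j*theta):
#   y[t] = 2*r*cos(theta)*y[t-1] - r**2*y[t-2] + b0*x[t] + b1*x[t-1]
# With theta=0 and b1=0 this degenerates to y[t] = r*y[t-1] + b0*x[t],
# i.e. the filter_order=1 LIF case with learnable decay.
def biquad(x, r=0.9, theta=0.5, b0=1.0, b1=0.0):
    y = np.zeros_like(x)
    a1, a2 = 2 * r * np.cos(theta), -r ** 2
    for t in range(len(x)):
        y[t] = (a1 * (y[t - 1] if t >= 1 else 0.0)
                + a2 * (y[t - 2] if t >= 2 else 0.0)
                + b0 * x[t]
                + b1 * (x[t - 1] if t >= 1 else 0.0))
    return y

impulse = np.zeros(10)
impulse[0] = 1.0
y = biquad(impulse)
print(y[:3])  # impulse response: a decaying oscillation at frequency theta
```

Compared with the first-order LIF (a pure low-pass), the second pole lets each channel emphasize a tunable temporal frequency band.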

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    filter_order: 1 for standard LIF, 2 for biquad IIR.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = FourierTemporalConv(2, 16, 3, filter_order=2, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((10, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape
    [10, 2, 8, 8, 16]

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize FTC state.

Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns: State dict with 'mem' and 'mem_prev' (for second-order), and 'input_prev' for feedforward delay.

__call__

__call__(x_seq: mx.array, state: dict) -> tuple[mx.array, dict]

Forward pass over full temporal sequence.

Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: State dict from previous call.

Returns: Tuple of (spk_seq, new_state) both with T timesteps.

get_frequency_response

get_frequency_response(num_points: int = 256) -> tuple

Compute the magnitude frequency response for each channel.

Args: num_points: Number of frequency points.

Returns: (freqs, magnitudes) where freqs is (num_points,) in [0, pi] and magnitudes is (C_out, num_points).
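How such a response can be computed for one channel, assuming the transfer function H(z) = (b0 + b1·z⁻¹) / (1 − 2·r·cos(theta)·z⁻¹ + r²·z⁻²) that matches the (r, theta, b0, b1) parameterization; this is a hypothetical sketch, not the method's source:

```python
import numpy as np

# Magnitude response of an assumed biquad, evaluated on the unit circle
# z = exp(j*w) for w in [0, pi], mirroring get_frequency_response's output.
def freq_response(r=0.9, theta=0.5, b0=1.0, b1=0.0, num_points=256):
    w = np.linspace(0, np.pi, num_points)
    z = np.exp(1j * w)
    H = (b0 + b1 / z) / (1 - 2 * r * np.cos(theta) / z + r ** 2 / z ** 2)
    return w, np.abs(H)

w, mag = freq_response()
print(w.shape, mag.shape)    # (256,) (256,)
print(w[np.argmax(mag)])     # resonance peak near theta when r is close to 1
```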

Information-Max Spike Convolution (IMC)

Information-theoretic channel gating that selectively activates channels based on input information content.

InfoMaxSpikeConv

Bases: Module

Information-Maximizing Spike Convolution layer.

Standard Conv2d + spiking neuron with per-channel learnable gates and information-theoretic regularization. The gates alpha_c modulate output spikes per channel, enabling differentiable channel pruning during training.
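The gating itself is a per-channel elementwise scale. A minimal numpy sketch, assuming the gates alpha_c are sigmoids of learnable logits (the parameterization and names here are illustrative):

```python
import numpy as np

# Per-channel gating sketch: spikes are scaled by gates in (0, 1), so a
# near-zero gate softly prunes its channel while staying differentiable.
def gate_spikes(spk, gate_logits):
    alpha = 1.0 / (1.0 + np.exp(-gate_logits))  # per-channel gates in (0, 1)
    return spk * alpha                           # broadcast over (B, H, W, C)

spk = np.ones((2, 8, 8, 4))
gated = gate_spikes(spk, np.array([10.0, 10.0, -10.0, 0.0]))
print(gated[0, 0, 0])  # channels ~[1, 1, 0, 0.5]
```

Channels whose gates collapse toward zero can be removed after training, which is what effective_channels counts.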

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor for the LIF neuron.
    info_reg_weight: Lambda for information bottleneck regularization.
    gate_reg_weight: L1 regularization weight on channel gates.
    ema_decay: Decay factor for running firing rate estimation.
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = InfoMaxSpikeConv(2, 16, 3, info_reg_weight=0.01, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x = mx.ones((2, 8, 8, 2))
    >>> spk, state = layer(x, state)
    >>> loss = layer.info_loss()

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state.

Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns: State dict with 'mem' of shape (B, H', W', C_out).

__call__

__call__(x: mx.array, state: dict) -> tuple[mx.array, dict]

Forward one timestep with channel gating.

Args:
    x: Input tensor (B, H, W, C_in).
    state: Neuron state dict.

Returns: Tuple of (gated_spikes, new_state).

info_loss

info_loss() -> mx.array

Information bottleneck regularization loss.

Penalizes the layer if output information capacity is below the input information content. Also applies L1 regularization on channel gates to encourage sparsity.

Returns: Scalar loss value.
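The description implies two terms: a penalty that activates only when output capacity falls below input information, plus an L1 term on the gates. The exact form below (hinge on the capacity deficit, weighted sum) is an assumption based on that description; all names are illustrative:

```python
import numpy as np

# Assumed structure of the regularizer: a hinge on the capacity deficit
# (penalize only when bits_out < bits_in) plus L1 sparsity on the gates.
def info_loss(gates, bits_in, bits_out, info_w=0.01, gate_w=0.001):
    capacity_penalty = max(0.0, bits_in - bits_out)  # deficit only, never a bonus
    sparsity_penalty = np.abs(gates).sum()           # L1 pushes gates toward zero
    return info_w * capacity_penalty + gate_w * sparsity_penalty
```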

effective_channels

effective_channels(threshold: float = 0.5) -> int

Count channels with gate > threshold.

Args: threshold: Gate threshold for considering a channel active.

Returns: Number of active channels.

channel_utilization

channel_utilization() -> dict

Get channel utilization statistics.

Returns: Dict with gate stats, firing rates, and entropy values.

minimum_channels staticmethod

minimum_channels(c_in: int, rho_in: float, rho_out: float, spatial_reduction: float = 1.0) -> int

Compute minimum channels for information preservation.

Args:
    c_in: Input channels.
    rho_in: Input firing rate.
    rho_out: Target output firing rate.
    spatial_reduction: H_in*W_in / (H_out*W_out).

Returns: Minimum number of output channels.
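One way to arrive at such a count is to match per-position information, using binary entropy H(rho) as the bits carried by a neuron firing at rate rho. The exact formula used by the static method is not given here, so this is an assumed reconstruction with the same signature:

```python
import math

# Assumed arithmetic: input carries c_in * H(rho_in) bits per output
# position (scaled by the spatial reduction), and each output channel
# supplies H(rho_out) bits, giving a minimum channel count by division.
def binary_entropy(p):
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -p * math.log2(p) - (1 - p) * math.log2(1 - p)

def minimum_channels(c_in, rho_in, rho_out, spatial_reduction=1.0):
    bits_in = c_in * binary_entropy(rho_in) * spatial_reduction
    return math.ceil(bits_in / binary_entropy(rho_out))

print(minimum_channels(16, 0.1, 0.2, spatial_reduction=4.0))
```

With equal input and output rates and no spatial reduction, the count reduces to c_in, as expected for lossless pass-through.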

Temporal Collapse Convolution (TCC)

Sparsity-aware operator that collapses consecutive silent timesteps.

TemporalCollapseConv

Bases: Module

Temporal Collapse Convolution layer.

Processes a full temporal sequence, collapsing consecutive no-spike (or low-spike) timesteps into single conv calls. Uses cached spike density from the previous forward pass to plan the collapse.
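The planning step amounts to run-length grouping of low-density timesteps. A pure-Python sketch of the presumed call-counting logic (the function name and segmentation details are illustrative):

```python
# Collapse-planning sketch: each timestep above the density threshold is
# processed individually; a maximal run of sub-threshold ("silent")
# timesteps is merged and costs a single conv call.
def plan_conv_calls(densities, collapse_threshold=0.02):
    calls, in_silent_run = 0, False
    for d in densities:
        if d < collapse_threshold:
            if not in_silent_run:   # first step of a silent run: one call
                calls += 1
                in_silent_run = True
        else:
            calls += 1              # active step: its own conv call
            in_silent_run = False
    return calls

densities = [0.5, 0.0, 0.0, 0.0, 0.3, 0.0, 0.4]
print(plan_conv_calls(densities))  # 5 instead of 7 per-timestep calls
```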

Args:
    in_channels: Number of input channels.
    out_channels: Number of output channels.
    kernel_size: Spatial convolution kernel size.
    beta: Membrane decay factor.
    collapse_threshold: Maximum spike density to consider a timestep as "no-spike" for collapse. Default 0.02 (2% of neurons).
    stride: Convolution stride.
    padding: Convolution padding.
    bias: If True, add learnable bias to convolution.
    bn: If True, apply batch normalization.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name.
    surrogate_scale: Scale for surrogate gradient.

Examples:
    >>> import mlx.core as mx
    >>> layer = TemporalCollapseConv(2, 16, 3, beta=0.9, padding=1)
    >>> state = layer.init_state(batch_size=2, spatial_shape=(8, 8))
    >>> x_seq = mx.ones((16, 2, 8, 8, 2))  # (T, B, H, W, C)
    >>> spk_seq, state = layer(x_seq, state)
    >>> spk_seq.shape
    [16, 2, 8, 8, 16]

conv_calls property

conv_calls: int

Number of conv calls in the last forward pass.

init_state

init_state(batch_size: int, spatial_shape: tuple) -> dict

Initialize neuron state.

Args:
    batch_size: Batch size.
    spatial_shape: (H', W') output spatial dims after conv.

Returns: State dict with 'mem' of shape (B, H', W', C_out).

__call__

__call__(x_seq: mx.array, state: dict) -> tuple[mx.array, dict]

Forward pass with temporal collapse.

Collapses consecutive low-activity timesteps into single conv calls. Always produces T output timesteps (collapsed timesteps produce a single spike output replicated across merged steps).

Args:
    x_seq: Input sequence (T, B, H, W, C_in).
    state: Neuron state dict.

Returns: Tuple of (spk_seq, new_state) with shape (T, B, H', W', C_out).

collapse_stats

collapse_stats() -> dict

Get statistics about the last forward pass collapse.

Returns: Dict with conv_calls, theoretical_calls (T), and speedup.
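The presumed contents of that dict: theoretical_calls is T (one conv per timestep), and speedup is its ratio to the calls actually made:

```python
# Presumed shape of the collapse_stats payload, built from the last
# forward pass; names mirror the documented dict keys.
def collapse_stats(conv_calls, T):
    return {
        "conv_calls": conv_calls,         # calls actually made after collapse
        "theoretical_calls": T,           # one conv per timestep, no collapse
        "speedup": T / conv_calls,        # e.g. 16 timesteps in 5 calls
    }

print(collapse_stats(5, 16))
# {'conv_calls': 5, 'theoretical_calls': 16, 'speedup': 3.2}
```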