Functional API

Stateless pure functions for neuron dynamics, spike operations, and loss functions. These are useful for custom training loops and advanced research.

Neuron Dynamics

lif_step

lif_step(x: array, mem: array, beta: float = 0.9, threshold: float = 1.0, reset_mechanism: str = 'subtract', surrogate_fn: str = 'fast_sigmoid', surrogate_scale: float = 25.0) -> tuple[mx.array, mx.array]

Single timestep of Leaky Integrate-and-Fire dynamics.

Membrane dynamics: mem[t+1] = beta * mem[t] + x[t+1] - spk[t] * threshold

Args:
    x: Input current [batch, features].
    mem: Previous membrane potential [batch, features].
    beta: Membrane decay factor.
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale for surrogate gradient.

Returns: Tuple of (spikes, new_membrane_potential).
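The update rule above can be sketched directly. NumPy stands in for mlx.core here (the array calls map one-to-one to `mx.array` ops), and a hard Heaviside replaces the surrogate spike function, so this shows the forward pass only; the function name is illustrative, not the library's:

```python
import numpy as np

def lif_step_sketch(x, mem, beta=0.9, threshold=1.0):
    """One LIF timestep with subtract reset (hard Heaviside, forward only)."""
    spk = (mem >= threshold).astype(x.dtype)    # fire where the previous potential crossed threshold
    new_mem = beta * mem + x - spk * threshold  # decay, integrate input, subtract reset
    return spk, new_mem

spk, new_mem = lif_step_sketch(np.array([0.5, 0.5]), np.array([1.2, 0.3]))
# neuron 0 spikes (1.2 >= 1.0) and is reset; neuron 1 just decays and integrates
```

In a custom training loop this step is applied once per timestep, carrying `new_mem` forward as the state.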

if_step

if_step(x: array, mem: array, threshold: float = 1.0, reset_mechanism: str = 'subtract', surrogate_fn: str = 'fast_sigmoid', surrogate_scale: float = 25.0) -> tuple[mx.array, mx.array]

Single timestep of Integrate-and-Fire dynamics (no leak).

Membrane dynamics: mem[t+1] = mem[t] + x[t+1] - spk[t] * threshold

Args:
    x: Input current [batch, features].
    mem: Previous membrane potential [batch, features].
    threshold: Spike threshold.
    reset_mechanism: 'subtract', 'zero', or 'none'.
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale for surrogate gradient.

Returns: Tuple of (spikes, new_membrane_potential).
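Dropping the leak term makes the neuron a perfect integrator: lif_step with beta fixed at 1. A two-step NumPy sketch (forward pass only, illustrative names):

```python
import numpy as np

def if_step_sketch(x, mem, threshold=1.0):
    """One IF timestep: integrate without decay, subtract reset."""
    spk = (mem >= threshold).astype(x.dtype)
    new_mem = mem + x - spk * threshold  # no beta factor, unlike LIF
    return spk, new_mem

x = np.array([0.4])
spk1, m1 = if_step_sketch(x, np.array([0.7]))  # 0.7 < 1.0: no spike, potential grows to 1.1
spk2, m2 = if_step_sketch(x, m1)               # 1.1 >= 1.0: spike, then subtract threshold
```

Because nothing decays, constant sub-threshold input always accumulates to a spike eventually.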

Spike Operations

fire

fire(mem: array, threshold: float, surrogate_fn: Callable[[array], array]) -> mx.array

Generate spikes using surrogate gradient function.

Args:
    mem: Membrane potential.
    threshold: Spike threshold.
    surrogate_fn: Surrogate gradient function (Heaviside forward, smooth backward).

Returns: Binary spike array.
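A forward-only NumPy sketch of the idea: the forward pass is a hard Heaviside, while the backward pass substitutes a smooth derivative. The fast-sigmoid derivative shown here, 1 / (1 + scale * |mem - threshold|)^2, is an assumption based on common SNN practice rather than this library's source:

```python
import numpy as np

def fire_sketch(mem, threshold=1.0, scale=25.0):
    """Heaviside spikes plus the surrogate derivative a backward pass would use."""
    spk = (mem >= threshold).astype(mem.dtype)                    # hard step, forward
    surrogate_grad = 1.0 / (1.0 + scale * np.abs(mem - threshold)) ** 2  # smooth, backward
    return spk, surrogate_grad

spk, grad = fire_sketch(np.array([0.2, 1.0, 3.0]))
# grad peaks at the threshold and falls off on both sides, so gradients
# flow mainly through near-threshold neurons
```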

reset_subtract

reset_subtract(mem: array, spk: array, threshold: float) -> mx.array

Subtract-reset: reduce membrane by threshold where spike occurred.

Args:
    mem: Membrane potential before reset.
    spk: Binary spike array.
    threshold: Value to subtract.

Returns: Membrane potential after reset.

reset_zero

reset_zero(mem: array, spk: array) -> mx.array

Zero-reset: set membrane to zero where spike occurred.

Args:
    mem: Membrane potential before reset.
    spk: Binary spike array.

Returns: Membrane potential after reset.
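Both reset rules are one-liners. A NumPy sketch of the two behaviors side by side (standing in for the mlx.core equivalents; names are illustrative):

```python
import numpy as np

def reset_subtract_sketch(mem, spk, threshold):
    return mem - spk * threshold            # lower spiking neurons by one threshold

def reset_zero_sketch(mem, spk):
    return np.where(spk > 0, 0.0, mem)      # clamp spiking neurons to zero

mem = np.array([1.4, 0.6])
spk = np.array([1.0, 0.0])
m_sub = reset_subtract_sketch(mem, spk, 1.0)
m_zero = reset_zero_sketch(mem, spk)
```

Subtract-reset keeps the residual charge above threshold (1.4 becomes 0.4), while zero-reset discards it; non-spiking neurons are untouched by both.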

Loss Functions

Classification Losses

ce_rate_loss

ce_rate_loss(spk_out: array, targets: array) -> mx.array

Cross-entropy loss on mean spike rate.

Averages spikes over time to get a firing rate, then applies softmax cross-entropy against integer class labels.

Args:
    spk_out: Output spikes [T, batch, num_classes].
    targets: Integer class labels [batch].

Returns: Scalar loss value.
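A NumPy sketch of the computation: average over the time axis, treat the rates as logits, and take the negative log-probability of the true class. The log-sum-exp stabilization is a standard implementation detail this page does not specify:

```python
import numpy as np

def ce_rate_loss_sketch(spk_out, targets):
    """Mean over time -> firing-rate logits -> softmax cross-entropy."""
    rate = spk_out.mean(axis=0)                          # [batch, num_classes]
    logits = rate - rate.max(axis=1, keepdims=True)      # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

spk = np.zeros((4, 2, 3))   # T=4 steps, batch of 2, 3 classes
spk[:, 0, 1] = 1.0          # sample 0 fires only for class 1
spk[:, 1, 2] = 1.0          # sample 1 fires only for class 2
loss = ce_rate_loss_sketch(spk, np.array([1, 2]))
```

Because rates live in [0, 1], the logits are bounded, so the loss never reaches zero even for perfectly rate-coded outputs; this is inherent to rate-based cross-entropy, not a bug.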

ce_count_loss

ce_count_loss(spk_out: array, targets: array) -> mx.array

Cross-entropy loss on total spike count.

Sums spikes over time, then applies softmax cross-entropy against integer class labels.

Args:
    spk_out: Output spikes [T, batch, num_classes].
    targets: Integer class labels [batch].

Returns: Scalar loss value.

mse_count_loss

mse_count_loss(spk_out: array, targets: array) -> mx.array

MSE loss between spike counts and target counts.

Useful for regression or when target firing rates are known.

Args:
    spk_out: Output spikes [T, batch, num_classes].
    targets: Target spike counts [batch, num_classes].

Returns: Scalar loss value.
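Unlike the cross-entropy losses, the targets here are explicit per-class spike counts rather than labels. A NumPy sketch:

```python
import numpy as np

def mse_count_loss_sketch(spk_out, targets):
    counts = spk_out.sum(axis=0)            # total spikes per neuron [batch, num_classes]
    return ((counts - targets) ** 2).mean() # squared error against target counts

spk = np.ones((10, 1, 2))                   # both neurons fire on all 10 timesteps
loss = mse_count_loss_sketch(spk, np.array([[10.0, 5.0]]))
# neuron 0 matches its target (error 0); neuron 1 overshoots by 5 (error 25)
```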

mse_membrane_loss

mse_membrane_loss(mem: array, targets: array, on_target: float = 1.0, off_target: float = 0.0) -> mx.array

MSE loss on membrane potential with one-hot target encoding.

Creates a target tensor where the correct class has value on_target and all other classes have off_target, then computes the mean squared error against the membrane potential.

Args:
    mem: Membrane potential [batch, num_classes] at the last timestep.
    targets: Integer class labels [batch].
    on_target: Target value for the correct class.
    off_target: Target value for incorrect classes.

Returns: Scalar loss value.
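The one-hot-style target construction described above can be sketched in NumPy (illustrative name, not the library's implementation):

```python
import numpy as np

def mse_membrane_loss_sketch(mem, targets, on_target=1.0, off_target=0.0):
    tgt = np.full_like(mem, off_target)                 # every class gets off_target
    tgt[np.arange(len(targets)), targets] = on_target   # correct class gets on_target
    return ((mem - tgt) ** 2).mean()

mem = np.array([[1.0, 0.0, 0.0],    # sample 0: exactly matches its target (class 0)
                [0.0, 0.5, 0.5]])   # sample 1: half-way on class 1, spurious class 2
loss = mse_membrane_loss_sketch(mem, np.array([0, 1]))
```

This drives the correct class's membrane potential toward on_target and suppresses the rest, which is useful when you want a differentiable signal even on timesteps with no spikes.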

membrane_loss

membrane_loss(mem: array, targets: array) -> mx.array

Cross-entropy loss on final-timestep membrane potential.

Uses the membrane potential at the last timestep as logits for classification.

Args:
    mem: Membrane potentials [T, batch, num_classes].
    targets: Integer class labels [batch].

Returns: Scalar loss value.

rate_coding_loss

rate_coding_loss(spk_out: array, targets: array) -> mx.array

Cross-entropy loss on spike counts (rate coding).

Sums spikes across time to get per-class spike counts (the rate-coded class scores), then applies softmax cross-entropy against target labels.

Args:
    spk_out: Output spikes [T, batch, num_classes].
    targets: Integer class labels [batch].

Returns: Scalar loss value.

Regularization Losses

activity_reg_loss

activity_reg_loss(spk_out: array, target_rate: float = 0.1) -> mx.array

MSE between mean firing rate and a target rate.

Encourages the network to maintain a desired average firing rate, preventing dead neurons (rate = 0) or saturation (rate = 1).

Args:
    spk_out: Output spikes [T, batch, ...].
    target_rate: Desired mean firing rate in [0, 1].

Returns: Scalar loss value.

Examples:
    >>> import mlx.core as mx
    >>> spk = mx.ones((10, 4, 5)) * 0.5
    >>> loss = activity_reg_loss(spk, target_rate=0.1)

l1_spike_loss

l1_spike_loss(spk_out: array) -> mx.array

L1 penalty on spike counts (sparsity regularization).

Penalizes the total number of spikes, encouraging sparse spiking activity across the network.

Args:
    spk_out: Output spikes [T, batch, ...].

Returns: Scalar loss value.

Examples:
    >>> import mlx.core as mx
    >>> spk = mx.ones((10, 4, 5))
    >>> loss = l1_spike_loss(spk)

l2_spike_loss

l2_spike_loss(spk_out: array) -> mx.array

L2 penalty on spike counts.

Penalizes the squared spike count per neuron, providing a smoother gradient than L1 for firing rate control.

Args:
    spk_out: Output spikes [T, batch, ...].

Returns: Scalar loss value.

Examples:
    >>> import mlx.core as mx
    >>> spk = mx.ones((10, 4, 5))
    >>> loss = l2_spike_loss(spk)

Metrics

spike_rate

spike_rate(spk_out: array) -> mx.array

Compute mean firing rate over time.

Args:
    spk_out: Output spikes [T, batch, ...].

Returns: Mean firing rate [batch, ...].

spike_count

spike_count(spk_out: array) -> mx.array

Compute total spike count over time.

Args:
    spk_out: Output spikes [T, batch, ...].

Returns: Total spike count [batch, ...].
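Both metrics are reductions over the leading time axis, related by count = rate * T. A NumPy sketch of what each returns:

```python
import numpy as np

spk = np.zeros((4, 1, 2))   # T=4 steps, batch of 1, 2 neurons
spk[0, 0, 0] = 1.0
spk[2, 0, 0] = 1.0          # neuron 0 fires on 2 of 4 steps; neuron 1 never fires

rate = spk.mean(axis=0)     # spike_rate: fraction of timesteps each neuron fired
count = spk.sum(axis=0)     # spike_count: total spikes over the window
```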