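Functional API
Stateless pure functions for neuron dynamics, spike operations, loss functions, and metrics. They keep no internal state (the step functions take the previous membrane potential and return the new one), which makes them useful for custom training loops and research experiments.
Neuron Dynamics
lif_step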
lif_step(x: array, mem: array, beta: float = 0.9, threshold: float = 1.0, reset_mechanism: str = 'subtract', surrogate_fn: str = 'fast_sigmoid', surrogate_scale: float = 25.0) -> tuple[mx.array, mx.array]
Single timestep of Leaky Integrate-and-Fire dynamics.
Membrane dynamics (reset_mechanism='subtract'): mem[t+1] = beta * mem[t] + x[t+1] - spk[t] * threshold
Args:
x: Input current [batch, features].
mem: Previous membrane potential [batch, features].
beta: Membrane decay factor.
threshold: Spike threshold.
reset_mechanism: 'subtract', 'zero', or 'none'.
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale for surrogate gradient.
Returns: Tuple of (spikes, new_membrane_potential).
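The update rule above can be traced with a small forward-pass sketch. This is plain NumPy, not the library's MLX code: `lif_step_sketch` is a hypothetical name, it covers only the subtract reset, and it assumes spk[t] is inferred from the incoming membrane potential (a neuron above threshold at step t is reset on step t+1), as the formula's indexing suggests. The surrogate gradient only affects the backward pass, so the forward output is a hard threshold.

```python
import numpy as np


def lif_step_sketch(x, mem, beta=0.9, threshold=1.0):
    """One forward LIF step with subtract reset (hypothetical sketch).

    Implements mem[t+1] = beta * mem[t] + x[t+1] - spk[t] * threshold,
    assuming spk[t] is the spike implied by the incoming potential mem[t].
    """
    spk_prev = (mem > threshold).astype(mem.dtype)   # spikes at step t
    mem_new = beta * mem + x - spk_prev * threshold  # leak, integrate, reset
    spk = (mem_new > threshold).astype(mem.dtype)    # Heaviside forward pass
    return spk, mem_new


# Drive two neurons with a constant current for five timesteps.
x = np.full((1, 2), 0.6)   # [batch, features]
mem = np.zeros((1, 2))
spike_train = []
for _ in range(5):
    spk, mem = lif_step_sketch(x, mem)
    spike_train.append(spk[0, 0])
```

With this ordering the reset lags the spike by one step, matching the spk[t] term; if an implementation resets in the same step it fires, the intermediate membrane values will differ even though the neuron fires at a similar rate.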
if_step
if_step(x: array, mem: array, threshold: float = 1.0, reset_mechanism: str = 'subtract', surrogate_fn: str = 'fast_sigmoid', surrogate_scale: float = 25.0) -> tuple[mx.array, mx.array]
Single timestep of Integrate-and-Fire dynamics (no leak).
Membrane dynamics (reset_mechanism='subtract'): mem[t+1] = mem[t] + x[t+1] - spk[t] * threshold
Args:
x: Input current [batch, features].
mem: Previous membrane potential [batch, features].
threshold: Spike threshold.
reset_mechanism: 'subtract', 'zero', or 'none'.
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale for surrogate gradient.
Returns: Tuple of (spikes, new_membrane_potential).
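Removing the leak changes which inputs can drive a neuron to threshold. The NumPy sketch below (hypothetical `if_step_sketch`, forward pass only, assuming spk[t] is inferred from the incoming membrane potential as the formula's indexing suggests) compares an IF neuron with a leaky one under a weak constant input. With beta = 0.9 the leaky membrane settles toward x / (1 - beta) = 0.5 and never fires, while the IF neuron keeps integrating until it crosses threshold.

```python
import numpy as np


def if_step_sketch(x, mem, threshold=1.0):
    """One forward IF step: the LIF update with beta fixed at 1 (no leak)."""
    spk_prev = (mem > threshold).astype(mem.dtype)  # spikes at step t
    mem_new = mem + x - spk_prev * threshold        # integrate, subtract-reset
    spk = (mem_new > threshold).astype(mem.dtype)   # Heaviside forward pass
    return spk, mem_new


x = np.full((1, 1), 0.05)   # weak constant input current
mem_if = np.zeros((1, 1))
mem_leaky = np.zeros((1, 1))
if_spikes = 0
leaky_spikes = 0
for _ in range(100):
    spk, mem_if = if_step_sketch(x, mem_if)
    if_spikes += int(spk.sum())
    # Leaky membrane with beta = 0.9; it approaches 0.5 from below,
    # so no reset term is needed for this comparison.
    mem_leaky = 0.9 * mem_leaky + x
    leaky_spikes += int((mem_leaky > 1.0).sum())
```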
Spike Operations
fire
Generate spikes by comparing the membrane potential with the threshold, using a surrogate gradient function for the backward pass.
Args:
mem: Membrane potential.
threshold: Spike threshold.
surrogate_fn: Surrogate gradient function (Heaviside step in the forward pass, smooth approximation in the backward pass).
Returns: Binary spike array.
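A surrogate gradient keeps the forward output binary while giving the optimizer a usable derivative near threshold. The NumPy sketch below illustrates this for the default `fast_sigmoid` surrogate. It assumes the fast-sigmoid form used in snnTorch, whose derivative is 1 / (1 + k|u|)^2 with u = mem - threshold and k playing the role of surrogate_scale; the library's exact formula may differ, and both function names are hypothetical.

```python
import numpy as np


def fire_forward(mem, threshold=1.0):
    # Forward pass: Heaviside step, so every output is exactly 0 or 1.
    return (mem - threshold > 0).astype(np.float64)


def fast_sigmoid_surrogate_grad(mem, threshold=1.0, scale=25.0):
    # Backward pass (assumed form): d/du of u / (1 + scale * |u|),
    # which is 1 / (1 + scale * |u|) ** 2 with u = mem - threshold.
    u = mem - threshold
    return 1.0 / (1.0 + scale * np.abs(u)) ** 2


mem = np.array([0.5, 0.99, 1.0, 1.01, 1.5])
spikes = fire_forward(mem)                 # binary: [0, 0, 0, 1, 1]
grads = fast_sigmoid_surrogate_grad(mem)   # peaks at the threshold
```

The true derivative of the step is zero almost everywhere, so without the surrogate no gradient would reach the weights; a larger scale narrows the band of membrane values that receive a meaningful gradient.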
reset_subtract
Subtract-reset: reduce the membrane potential by threshold wherever a spike occurred.
Args:
mem: Membrane potential before reset.
spk: Binary spike array.
threshold: Value to subtract.
Returns: Membrane potential after reset.
reset_zero
Zero-reset: set the membrane potential to zero wherever a spike occurred.
Args:
mem: Membrane potential before reset.
spk: Binary spike array.
Returns: Membrane potential after reset.
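The two reset rules differ only in what happens to charge above threshold. In this NumPy sketch (hypothetical helper names mirroring the documented behavior), subtract-reset keeps the residual above threshold, so a strongly driven neuron can fire again sooner, while zero-reset discards it.

```python
import numpy as np


def reset_subtract_sketch(mem, spk, threshold):
    # Remove one threshold's worth of charge from neurons that spiked.
    return mem - spk * threshold


def reset_zero_sketch(mem, spk):
    # Clear the membrane of neurons that spiked; others are untouched.
    return mem * (1.0 - spk)


mem = np.array([0.4, 1.3, 2.5])
spk = np.array([0.0, 1.0, 1.0])
after_subtract = reset_subtract_sketch(mem, spk, threshold=1.0)  # approx. [0.4, 0.3, 1.5]
after_zero = reset_zero_sketch(mem, spk)                         # [0.4, 0.0, 0.0]
```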
Loss Functions
Classification Losses
ce_rate_loss
Cross-entropy loss on mean spike rate.
Averages spikes over time to get a firing rate, then applies softmax cross-entropy against integer class labels.
Args:
spk_out: Output spikes [T, batch, num_classes].
targets: Integer class labels [batch].
Returns: Scalar loss value.
ce_count_loss
Cross-entropy loss on total spike count.
Sums spikes over time, then applies softmax cross-entropy against integer class labels.
Args:
spk_out: Output spikes [T, batch, num_classes].
targets: Integer class labels [batch].
Returns: Scalar loss value.
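ce_rate_loss and ce_count_loss apply the same softmax cross-entropy to logits that differ only by a factor of T, since counts = T * rates. The NumPy sketch below (hypothetical helper; it assumes the loss is averaged over the batch) shows the practical effect: the count version behaves like the rate version at a lower softmax temperature, so correctly ranked outputs produce a much smaller loss.

```python
import numpy as np


def softmax_cross_entropy(logits, targets):
    # Numerically stable log-softmax, then mean negative log-likelihood
    # of the integer targets over the batch.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


T = 10
spk_out = np.zeros((T, 2, 3))   # [T, batch, num_classes]
spk_out[:, 0, 0] = 1.0          # sample 0: class 0 fires every step
spk_out[:5, 0, 1] = 1.0         # sample 0: class 1 fires on half the steps
spk_out[:8, 1, 2] = 1.0         # sample 1: class 2 fires on 8 of 10 steps
targets = np.array([0, 2])

rate_loss = softmax_cross_entropy(spk_out.mean(axis=0), targets)   # ~0.66
count_loss = softmax_cross_entropy(spk_out.sum(axis=0), targets)   # ~0.004
```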
mse_count_loss
MSE loss between spike counts and target counts.
Useful for regression or when target firing rates are known.
Args:
spk_out: Output spikes [T, batch, num_classes].
targets: Target spike counts [batch, num_classes].
Returns: Scalar loss value.
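A minimal NumPy sketch of the count-based MSE, assuming the squared error is averaged over every batch entry and class (the library's reduction may differ; the function name is hypothetical). Here the target asks for 8 spikes from class 0 and 2 from class 1.

```python
import numpy as np


def mse_count_loss_sketch(spk_out, targets):
    # Total spikes per neuron over time: [T, batch, C] -> [batch, C].
    counts = spk_out.sum(axis=0)
    # Mean squared error over batch and classes (assumed reduction).
    return np.mean((counts - targets) ** 2)


spk_out = np.zeros((10, 1, 2))     # [T, batch, num_classes]
spk_out[:6, 0, 0] = 1.0            # class 0 fires 6 times
spk_out[:3, 0, 1] = 1.0            # class 1 fires 3 times
targets = np.array([[8.0, 2.0]])   # desired counts [batch, num_classes]
loss = mse_count_loss_sketch(spk_out, targets)  # ((6-8)**2 + (3-2)**2) / 2 = 2.5
```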
mse_membrane_loss
mse_membrane_loss(mem: array, targets: array, on_target: float = 1.0, off_target: float = 0.0) -> mx.array
MSE loss on membrane potential with one-hot target encoding.
Creates a target tensor where the correct class has value on_target and all other classes have off_target, then computes the mean squared error against the membrane potential.
Args:
mem: Membrane potential [batch, num_classes] at the last timestep.
targets: Integer class labels [batch].
on_target: Target value for the correct class.
off_target: Target value for incorrect classes.
Returns: Scalar loss value.
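The one-hot target construction described above can be sketched in NumPy. `mse_membrane_loss_sketch` is a hypothetical stand-in, and averaging over all batch and class entries is an assumption.

```python
import numpy as np


def mse_membrane_loss_sketch(mem, targets, on_target=1.0, off_target=0.0):
    num_classes = mem.shape[1]
    one_hot = np.eye(num_classes)[targets]          # [batch, num_classes]
    # on_target where the class is correct, off_target everywhere else.
    target = off_target + (on_target - off_target) * one_hot
    return np.mean((mem - target) ** 2)


mem = np.array([[0.8, 0.1, -0.2]])   # last-timestep membrane [batch, num_classes]
targets = np.array([0])
loss = mse_membrane_loss_sketch(mem, targets)  # (0.04 + 0.01 + 0.04) / 3 = 0.03
```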
membrane_loss
Cross-entropy loss on final-timestep membrane potential.
Uses the membrane potential at the last timestep as logits for classification.
Args:
mem: Membrane potentials [T, batch, num_classes].
targets: Integer class labels [batch].
Returns: Scalar loss value.
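Because only mem[-1] is used as logits, the recorded potentials before the last step do not enter this loss directly. The NumPy sketch below (hypothetical helper, batch-mean reduction assumed) makes that explicit: overwriting every earlier timestep leaves the loss unchanged.

```python
import numpy as np


def membrane_loss_sketch(mem, targets):
    logits = mem[-1]                                  # [batch, num_classes]
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


rng = np.random.default_rng(0)
mem = rng.normal(size=(5, 2, 3))   # [T, batch, num_classes]
targets = np.array([1, 2])

loss = membrane_loss_sketch(mem, targets)
mem_changed = mem.copy()
mem_changed[:-1] = 0.0             # overwrite every step except the last
loss_changed = membrane_loss_sketch(mem_changed, targets)
```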
rate_coding_loss
Cross-entropy loss on spike counts (rate coding).
Sums spikes across time to get per-class spike counts, then applies softmax cross-entropy against integer class labels.
Args:
spk_out: Output spikes [T, batch, num_classes].
targets: Integer class labels [batch].
Returns: Scalar loss value.
Regularization Losses
activity_reg_loss
MSE between mean firing rate and a target rate.
Encourages the network to maintain a desired average firing rate, discouraging both dead neurons (rate = 0) and saturated neurons (rate = 1).
Args:
spk_out: Output spikes [T, batch, ...].
target_rate: Desired mean firing rate in [0, 1].
Returns: Scalar loss value.
Examples:
>>> import mlx.core as mx
>>> spk = mx.ones((10, 4, 5)) * 0.5
>>> loss = activity_reg_loss(spk, target_rate=0.1)
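In the docstring example above every neuron's mean rate is 0.5, so the loss should come out near (0.5 - 0.1)^2 = 0.16. The NumPy sketch below reproduces that number. It assumes the rate is averaged over time per neuron and the squared error is then averaged over neurons, which is one reading of "mean firing rate"; for this uniform example, averaging everything first gives the same value. The function name is hypothetical.

```python
import numpy as np


def activity_reg_loss_sketch(spk_out, target_rate):
    # Per-neuron firing rate: average over the time axis -> [batch, ...].
    rate = spk_out.mean(axis=0)
    # Squared deviation from the target, averaged over all neurons.
    return np.mean((rate - target_rate) ** 2)


spk = np.ones((10, 4, 5)) * 0.5   # same input as the docstring example
loss = activity_reg_loss_sketch(spk, target_rate=0.1)  # approx. 0.16
```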
l1_spike_loss
L1 penalty on spike counts (sparsity regularization).
Penalizes the total number of spikes, encouraging sparse spiking activity across the network.
Args:
spk_out: Output spikes [T, batch, ...].
Returns: Scalar loss value.
Examples:
>>> import mlx.core as mx
>>> spk = mx.ones((10, 4, 5))
>>> loss = l1_spike_loss(spk)
l2_spike_loss
L2 penalty on spike counts.
Penalizes the squared spike count per neuron, providing a smoother gradient than L1 for firing rate control.
Args:
spk_out: Output spikes [T, batch, ...].
Returns: Scalar loss value.
Examples:
>>> import mlx.core as mx
>>> spk = mx.ones((10, 4, 5))
>>> loss = l2_spike_loss(spk)
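The difference between the L1 and L2 sparsity penalties shows up in their gradients. In the NumPy sketch below (hypothetical helpers; both assume an unnormalized sum over neurons, which may differ from the library's scaling), every spike costs the same under L1, while under L2 the gradient with respect to a neuron's spikes grows with that neuron's count, so the busiest neurons are pushed down hardest.

```python
import numpy as np


def l1_spike_loss_sketch(spk_out):
    # Total number of spikes across time, batch, and neurons.
    return spk_out.sum()


def l2_spike_loss_sketch(spk_out):
    # Sum of squared per-neuron spike counts.
    counts = spk_out.sum(axis=0)   # [batch, ...]
    return (counts ** 2).sum()


spk = np.zeros((10, 1, 2))         # [T, batch, neurons]
spk[:2, 0, 0] = 1.0                # quiet neuron: 2 spikes
spk[:8, 0, 1] = 1.0                # busy neuron: 8 spikes
counts = spk.sum(axis=0)[0]        # [2, 8]

# Analytic gradients with respect to any single spike of each neuron:
l1_grad = np.ones_like(counts)     # d(sum)/d(spike) = 1 for every neuron
l2_grad = 2.0 * counts             # d(count**2)/d(spike) = 2 * count
```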
Metrics
spike_rate
Compute mean firing rate over time.
Args:
spk_out: Output spikes [T, batch, ...].
Returns:
Mean firing rate [batch, ...].
spike_count
Compute total spike count over time.
Args:
spk_out: Output spikes [T, batch, ...].
Returns:
Total spike count [batch, ...].
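The two metrics reduce over the time axis in the same way and differ only by a factor of T. A NumPy sketch, assuming axis 0 is time as in the shapes above (function names are hypothetical):

```python
import numpy as np


def spike_rate_sketch(spk_out):
    return spk_out.mean(axis=0)   # [T, batch, ...] -> [batch, ...]


def spike_count_sketch(spk_out):
    return spk_out.sum(axis=0)    # [T, batch, ...] -> [batch, ...]


rng = np.random.default_rng(0)
spk_out = (rng.random((20, 4, 3)) < 0.25).astype(np.float32)  # random binary spikes
rate = spike_rate_sketch(spk_out)
count = spike_count_sketch(spk_out)
```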