
Neurons

Spiking neuron models are the core building blocks of SNNs. All neurons inherit from SpikingNeuron and follow the same interface: (x, state) → (spk, new_state).

Base Class

SpikingNeuron

Bases: Module

Abstract base class for spiking neuron models.

Args:
    threshold: Membrane potential threshold for spike generation.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: How to reset membrane after a spike.
        'subtract' — subtract threshold from membrane potential.
        'zero' — reset membrane to zero.
        'none' — no reset (useful for output layers).
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for the surrogate gradient.

Examples: Subclasses must implement init_state and __call__:

    class MyNeuron(SpikingNeuron):
        def init_state(self, batch_size, features):
            return {"mem": mx.zeros((batch_size, features))}

        def __call__(self, x, state):
            mem = state["mem"] + x
            spk = self.fire(mem)
            mem = self.reset(mem, spk)
            return spk, {"mem": mem}

__init__

__init__(threshold: float = 1.0, learn_threshold: bool = False, reset_mechanism: str = 'subtract', surrogate_fn: str = 'arctan', surrogate_scale: float = 2.0)

init_state

init_state(batch_size: int, *args) -> dict

Initialize neuron hidden state.

Args:
    batch_size: Number of samples in the batch.
    *args: Additional shape dimensions (e.g., feature size).

Returns: A dictionary of state tensors initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward pass: compute spikes from input and previous state.

Args:
    x: Input current of shape [batch, features].
    state: Dictionary of state tensors from the previous timestep.

Returns: A tuple of (spikes, new_state) where spikes is a binary array of the same shape as x.

fire

fire(mem: array) -> mx.array

Generate spikes using surrogate gradient.

In the forward pass this applies the Heaviside step function. In the backward pass the surrogate gradient is used.

Args:
    mem: Membrane potential array.

Returns: Binary spike array (1 where mem >= threshold, 0 otherwise).
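The forward/backward split behind fire() can be sketched in plain Python. This is an illustrative sketch, not the library's implementation, and the arctan derivative below is one common form of that surrogate; the exact expression a given library uses may differ.

```python
import math

# Forward pass: hard Heaviside step at the threshold.
def heaviside(mem, threshold=1.0):
    return 1.0 if mem >= threshold else 0.0

# Backward pass: one common arctan-style surrogate derivative.
# It is smooth, strictly positive, and peaks at the threshold,
# which lets gradients flow through the non-differentiable step.
def arctan_surrogate_grad(mem, threshold=1.0, scale=2.0):
    x = scale * (mem - threshold)
    return scale / (math.pi * (1.0 + x * x))
```

The key property: the surrogate gradient is largest for membrane potentials near the threshold and decays away from it, so neurons close to firing receive the strongest learning signal.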

reset

reset(mem: array, spk: array) -> mx.array

Apply reset mechanism after spike generation.

Args:
    mem: Membrane potential before reset.
    spk: Binary spike array from fire().

Returns: Membrane potential after reset.
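The three reset mechanisms named above can be sketched as a scalar Python function (illustrative, not the library's array implementation):

```python
# Sketch of the three reset mechanisms on a scalar membrane value.
def reset(mem, spk, threshold=1.0, mechanism="subtract"):
    if mechanism == "subtract":
        return mem - spk * threshold   # keeps the overshoot above threshold
    if mechanism == "zero":
        return mem * (1.0 - spk)       # hard reset to zero where a spike fired
    if mechanism == "none":
        return mem                     # leave the membrane untouched
    raise ValueError(f"unknown mechanism: {mechanism}")

# Membrane overshot to 1.3 and spiked (spk = 1.0):
after_subtract = reset(1.3, 1.0, mechanism="subtract")  # keeps the 0.3 overshoot
after_zero = reset(1.3, 1.0, mechanism="zero")          # discards it
```

'subtract' preserves information about how far the membrane overshot, which tends to give more accurate rate coding; 'zero' is the harder, more biologically literal reset.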

Leaky Integrate-and-Fire (LIF)

The most commonly used spiking neuron. Membrane dynamics:

\[U[t+1] = \beta \cdot U[t] + X[t+1] - S[t] \cdot V_{thr}\]
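The update above can be traced with a scalar Python sketch (illustrative values; not the library's vectorized implementation):

```python
# Scalar sketch of the LIF update with subtract reset:
#   U[t+1] = beta * U[t] + X[t+1] - S[t] * V_thr
def lif_step(mem, x, beta=0.9, threshold=1.0):
    mem = beta * mem + x                    # leaky integration
    spk = 1.0 if mem >= threshold else 0.0  # fire on threshold crossing
    mem = mem - spk * threshold             # subtract reset
    return spk, mem

# Constant sub-threshold input accumulates until the neuron fires once,
# then the cycle restarts from the (small) post-reset membrane.
mem, spikes = 0.0, []
for _ in range(5):
    spk, mem = lif_step(mem, x=0.4)
    spikes.append(spk)
```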

Leaky

Bases: SpikingNeuron

Leaky Integrate-and-Fire neuron.

Args:
    beta: Membrane potential decay rate. Values closer to 1 give longer memory; closer to 0 gives faster decay.
    learn_beta: If True, beta becomes a learnable parameter.
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import Leaky
    >>> lif = Leaky(beta=0.9)
    >>> state = lif.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = lif(x, state)

init_state

init_state(batch_size: int, *shape) -> dict

Initialize LIF neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions. Can be a single int for FC layers (e.g., init_state(B, 128)) or spatial dims for conv layers (e.g., init_state(B, 64, 64, 128)).

Returns: State dict with 'mem' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input current [batch, features].
    state: Dict with 'mem' from previous timestep.

Returns: Tuple of (spikes, new_state).

Integrate-and-Fire (IF)

Non-leaky variant — no membrane decay (\(\beta = 1\)).

IF

Bases: SpikingNeuron

Integrate-and-Fire neuron (no leak).

This is the simplest spiking neuron model. The membrane potential integrates input current without decay and fires when threshold is reached.

Args:
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import IF
    >>> neuron = IF(threshold=1.0)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128)) * 0.5
    >>> spk, state = neuron(x, state)  # No spike yet (mem=0.5)
    >>> spk, state = neuron(x, state)  # Spike! (mem=1.0 >= threshold)
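The same two-step behaviour can be reproduced with a scalar sketch (illustrative, not the library code):

```python
# Scalar sketch of the IF update: beta = 1 (no leak), subtract reset.
def if_step(mem, x, threshold=1.0):
    mem = mem + x                           # integrate with no decay
    spk = 1.0 if mem >= threshold else 0.0
    mem = mem - spk * threshold
    return spk, mem

spk1, mem = if_step(0.0, 0.5)   # mem = 0.5: below threshold, no spike
spk2, mem = if_step(mem, 0.5)   # mem = 1.0: at threshold, spike and reset
```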

init_state

init_state(batch_size: int, *shape) -> dict

Initialize IF neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'mem' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input current [batch, features].
    state: Dict with 'mem' from previous timestep.

Returns: Tuple of (spikes, new_state).

Izhikevich

2D dynamical system with biologically realistic spiking patterns.

\[\frac{dv}{dt} = 0.04v^2 + 5v + 140 - u + I\]

\[\frac{du}{dt} = a(bv - u)\]

Supports presets: Regular Spiking (RS), Intrinsically Bursting (IB), Chattering (CH), Fast Spiking (FS).
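The equations above can be integrated with a simple forward-Euler sketch. The parameters below are the standard Regular Spiking values; the update ordering and exact reset handling in the library's implementation may differ from this sketch.

```python
# Forward-Euler sketch of the Izhikevich model (RS-like parameters).
def izhikevich_step(v, u, I, a=0.02, b=0.2, c=-65.0, d=8.0, dt=0.5):
    v = v + dt * (0.04 * v * v + 5.0 * v + 140.0 - u + I)
    u = u + dt * a * (b * v - u)
    spk = v >= 30.0
    if spk:
        v, u = c, u + d        # reset v to c, bump the recovery variable by d
    return spk, v, u

v, u = -65.0, 0.2 * (-65.0)    # rest state; u starts at b * v
fired = False
for _ in range(200):           # 200 steps of dt = 0.5 ms -> 100 ms
    spk, v, u = izhikevich_step(v, u, I=10.0)
    fired = fired or spk
```

With a constant drive of I = 10 the v-nullcline has no fixed point, so the membrane is guaranteed to escape to the 30 mV spike cutoff and reset, producing the regular-spiking pattern.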

Izhikevich

Bases: SpikingNeuron

Izhikevich spiking neuron with two-dimensional dynamics.

This model uses two coupled variables (v, u) to reproduce a rich repertoire of neural spiking behaviours while remaining computationally efficient.

Args:
    a: Time scale of the recovery variable u. Smaller values result in slower recovery. Default: 0.02.
    b: Sensitivity of the recovery variable u to sub-threshold fluctuations of the membrane potential v. Default: 0.2.
    c: After-spike reset value of v (in mV). Default: -65.0.
    d: After-spike increment applied to u. Default: 8.0.
    dt: Integration timestep. The default of 0.5 ms provides a good balance between accuracy and speed; use smaller values for higher fidelity. Default: 0.5.
    preset: If given, must be a key in PRESETS (one of 'RS', 'IB', 'CH', 'FS'). When a preset is specified, the (a, b, c, d) values are loaded from the preset and any explicitly passed values for those parameters are ignored. Default: None.
    surrogate_fn: Surrogate gradient function name. Default: 'arctan'.
    surrogate_scale: Scale parameter for the surrogate gradient. Default: 0.1.

Examples:

>>> import mlx.core as mx
>>> from mlxsnn.neurons.izhikevich import Izhikevich
>>> neuron = Izhikevich(preset="RS")
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128)) * 10.0
>>> spk, state = neuron(x, state)

init_state

init_state(batch_size: int, *shape) -> dict

Initialize Izhikevich neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dictionary with keys 'v' (membrane potential, initialised to -65.0 mV) and 'u' (recovery variable, initialised to b * v).

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep of the Izhikevich model.

The membrane potential is integrated using a forward-Euler step with step size dt. When v >= 30 a spike is emitted, v is reset to c, and u is incremented by d.

Args:
    x: Input current of shape [batch, features].
    state: Dictionary with 'v' and 'u' from the previous timestep.

Returns: A tuple (spk, new_state) where spk is a binary spike array and new_state contains the updated 'v' and 'u'.

Adaptive LIF (ALIF)

LIF with adaptive threshold that increases after each spike.

ALIF

Bases: SpikingNeuron

Adaptive Leaky Integrate-and-Fire neuron.

Extends the standard LIF neuron with an adaptive threshold that increases after each spike and decays exponentially between spikes. This mechanism implements spike-frequency adaptation.

Args:
    beta: Membrane potential decay rate. Values closer to 1 give longer memory; closer to 0 gives faster decay.
    rho: Adaptation variable decay rate. Controls how quickly the threshold relaxes back to its base value after a spike.
    b: Adaptation strength. Scales the contribution of the adaptation variable to the effective threshold.
    learn_beta: If True, beta becomes a learnable parameter.
    learn_rho: If True, rho becomes a learnable parameter.
    threshold: Base spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import ALIF
    >>> neuron = ALIF(beta=0.9, rho=0.95, b=0.1)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = neuron(x, state)
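Spike-frequency adaptation can be seen in a scalar sketch. This is one common ALIF formulation, not necessarily the library's exact update; the parameter values are illustrative.

```python
# Scalar sketch of an adaptive LIF: the effective threshold rises after
# each spike and relaxes back at rate rho between spikes.
def alif_step(mem, adapt, x, beta=0.9, rho=0.95, b=0.1, threshold=1.0):
    mem = beta * mem + x
    eff_threshold = threshold + b * adapt          # adaptive threshold
    spk = 1.0 if mem >= eff_threshold else 0.0
    mem = mem - spk * eff_threshold                # subtract reset
    adapt = rho * adapt + spk                      # jump on spike, decay otherwise
    return spk, mem, adapt

# Strong constant drive: the first step fires, and the raised threshold
# then suppresses the immediately following spike.
mem, adapt, spikes = 0.0, 0.0, []
for _ in range(10):
    spk, mem, adapt = alif_step(mem, adapt, x=1.0)
    spikes.append(spk)
```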

init_state

init_state(batch_size: int, *shape) -> dict

Initialize ALIF neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'mem' and 'adapt' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input current [batch, features].
    state: Dict with 'mem' and 'adapt' from previous timestep.

Returns: Tuple of (spikes, new_state).

Synaptic

Conductance-based LIF with dual-state dynamics (synaptic current + membrane potential).

Synaptic

Bases: SpikingNeuron

Synaptic (conductance-based) LIF neuron.

Extends the standard LIF model with an explicit synaptic current state variable. Input is first integrated into the synaptic current, which then drives the membrane potential. This two-state model captures more realistic post-synaptic dynamics.

Args:
    alpha: Synaptic current decay rate. Values closer to 1 give slower synaptic dynamics; closer to 0 gives faster decay.
    beta: Membrane potential decay rate. Values closer to 1 give longer memory; closer to 0 gives faster decay.
    learn_alpha: If True, alpha becomes a learnable parameter.
    learn_beta: If True, beta becomes a learnable parameter.
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import Synaptic
    >>> neuron = Synaptic(alpha=0.8, beta=0.9)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = neuron(x, state)
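The two-state cascade can be sketched in scalar Python (illustrative parameters, not the library's implementation). Because input first charges the synaptic current, a single input pulse keeps driving the membrane up for several steps after the pulse ends:

```python
# Scalar sketch of the two-state Synaptic update.
def synaptic_step(syn, mem, x, alpha=0.8, beta=0.9, threshold=1.0):
    syn = alpha * syn + x      # input integrates into the synaptic current
    mem = beta * mem + syn     # the current, not x directly, drives the membrane
    spk = 1.0 if mem >= threshold else 0.0
    mem = mem - spk * threshold
    return spk, syn, mem

# One sub-threshold input pulse followed by silence: the membrane keeps
# rising for a few steps, peaks, then decays back toward zero.
syn, mem, trace = 0.0, 0.0, []
for x in [0.3] + [0.0] * 9:
    spk, syn, mem = synaptic_step(syn, mem, x)
    trace.append(mem)
```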

init_state

init_state(batch_size: int, *shape) -> dict

Initialize Synaptic neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'syn' and 'mem' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input current [batch, features].
    state: Dict with 'syn' and 'mem' from previous timestep.

Returns: Tuple of (spikes, new_state).

Alpha

Dual-exponential synaptic model with alpha-function shaped PSPs.

Alpha

Bases: SpikingNeuron

Alpha (dual-exponential) spiking neuron.

Models a dual-exponential post-synaptic current by cascading two first-order synaptic filters (excitatory and inhibitory) before driving the membrane potential. This produces a current with a finite rise time followed by exponential decay.

Args:
    alpha: Synaptic current decay rate for both excitatory and inhibitory filters. Values closer to 1 give slower synaptic dynamics; closer to 0 gives faster decay.
    beta: Membrane potential decay rate. Values closer to 1 give longer memory; closer to 0 gives faster decay.
    learn_alpha: If True, alpha becomes a learnable parameter.
    learn_beta: If True, beta becomes a learnable parameter.
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import Alpha
    >>> neuron = Alpha(alpha=0.85, beta=0.9)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = neuron(x, state)
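The characteristic PSP shape comes from the dual-exponential kernel itself: a difference of two decaying exponentials rises from zero, peaks, and then decays. The time constants below are illustrative, not the library's parameterization:

```python
import math

# Dual-exponential kernel: the difference of a slow and a fast exponential
# gives a finite rise time followed by exponential decay (alpha-like PSP).
def dual_exp(t, tau_rise=2.0, tau_decay=5.0):
    return math.exp(-t / tau_decay) - math.exp(-t / tau_rise)

kernel = [dual_exp(t) for t in range(20)]
# starts at exactly zero, peaks a few steps in, then decays
```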

init_state

init_state(batch_size: int, *shape) -> dict

Initialize Alpha neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'syn_exc', 'syn_inh', and 'mem' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Args:
    x: Input current [batch, features].
    state: Dict with 'syn_exc', 'syn_inh', and 'mem' from previous timestep.

Returns: Tuple of (spikes, new_state).

Recurrent LIF (RLeaky)

LIF with learnable recurrent feedback weight.

RLeaky

Bases: SpikingNeuron

Recurrent Leaky Integrate-and-Fire neuron.

Output spikes are fed back as additional input via a learnable recurrent weight V, giving the neuron an explicit recurrent connection in addition to the implicit memory from membrane decay.

Uses snnTorch-compatible delayed reset: the reset signal is computed from the previous membrane potential and detached from the computation graph.

Args:
    beta: Membrane potential decay rate.
    V: Recurrent weight. Scales the feedback spike signal.
    learn_beta: If True, beta becomes a learnable parameter.
    learn_V: If True, V becomes a learnable parameter.
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import RLeaky
    >>> neuron = RLeaky(beta=0.9, learn_V=True)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = neuron(x, state)
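Recurrence plus the delayed reset can be sketched in scalar Python. This is one plausible formulation for illustration only: the library computes the reset from the previous membrane potential and detaches it from the graph, which plain Python cannot express.

```python
# Scalar sketch: the previous spike both feeds back (scaled by V) and
# drives the delayed reset.
def rleaky_step(mem, prev_spk, x, beta=0.9, V=0.5, threshold=1.0):
    reset = prev_spk * threshold                   # delayed: last step's spike
    mem = beta * mem + x + V * prev_spk - reset    # feedback adds V per spike
    spk = 1.0 if mem >= threshold else 0.0
    return spk, mem

# With V > 0, once firing starts the positive feedback helps sustain it.
mem, spk, spikes = 0.0, 0.0, []
for _ in range(6):
    spk, mem = rleaky_step(mem, spk, x=0.6)
    spikes.append(spk)
```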

init_state

init_state(batch_size: int, *shape) -> dict

Initialize RLeaky neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'mem' and 'spk' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Uses delayed reset matching snnTorch: reset is computed from the previous membrane potential (detached from graph) before the membrane update.

Args:
    x: Input current [batch, features].
    state: Dict with 'mem' and 'spk' from previous timestep.

Returns: Tuple of (spikes, new_state).

Recurrent Synaptic (RSynaptic)

Synaptic neuron with learnable recurrent feedback weight.

RSynaptic

Bases: SpikingNeuron

Recurrent Synaptic (conductance-based) LIF neuron.

Extends the Synaptic neuron model with recurrent spike feedback. The output spike is fed back into the synaptic current via a learnable weight V, providing explicit recurrence on top of the two-state (synaptic current + membrane) dynamics.

Uses snnTorch-compatible delayed reset: the reset signal is computed from the previous membrane potential and detached from the computation graph.

Args:
    alpha: Synaptic current decay rate.
    beta: Membrane potential decay rate.
    V: Recurrent weight. Scales the feedback spike signal.
    learn_alpha: If True, alpha becomes a learnable parameter.
    learn_beta: If True, beta becomes a learnable parameter.
    learn_V: If True, V becomes a learnable parameter.
    threshold: Spike threshold voltage.
    learn_threshold: If True, threshold becomes a learnable parameter.
    reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
    surrogate_fn: Surrogate gradient function name or callable.
    surrogate_scale: Scale parameter for surrogate gradient.

Examples:

    >>> import mlx.core as mx
    >>> from mlxsnn.neurons import RSynaptic
    >>> neuron = RSynaptic(alpha=0.8, beta=0.9, learn_V=True)
    >>> state = neuron.init_state(4, 128)
    >>> x = mx.ones((4, 128))
    >>> spk, state = neuron(x, state)
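The combination of the two-state Synaptic update with recurrent feedback can be sketched as follows. Again, this is one plausible scalar formulation for illustration, not the library's exact code (which detaches the reset signal from the autograd graph):

```python
# Scalar sketch: feedback spikes enter the synaptic current; the reset
# is delayed (driven by the previous step's spike).
def rsynaptic_step(syn, mem, prev_spk, x,
                   alpha=0.8, beta=0.9, V=0.5, threshold=1.0):
    syn = alpha * syn + x + V * prev_spk            # recurrence on the current
    mem = beta * mem + syn - prev_spk * threshold   # delayed reset
    spk = 1.0 if mem >= threshold else 0.0
    return spk, syn, mem

# A modest constant input charges the synaptic current over a couple of
# steps before the membrane first crosses threshold.
syn, mem, spk, spikes = 0.0, 0.0, 0.0, []
for _ in range(6):
    spk, syn, mem = rsynaptic_step(syn, mem, spk, x=0.3)
    spikes.append(spk)
```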

init_state

init_state(batch_size: int, *shape) -> dict

Initialize RSynaptic neuron state.

Args:
    batch_size: Number of samples in the batch.
    *shape: Feature dimensions (single int or spatial dims).

Returns: State dict with 'syn', 'mem', and 'spk' initialized to zeros.

__call__

__call__(x: array, state: dict) -> tuple[mx.array, dict]

Forward one timestep.

Uses delayed reset matching snnTorch: reset is computed from the previous membrane potential (detached from graph) before the membrane update.

Args:
    x: Input current [batch, features].
    state: Dict with 'syn', 'mem', and 'spk' from previous timestep.

Returns: Tuple of (spikes, new_state).