Neurons
Spiking neuron models are the core building blocks of SNNs. All neurons inherit from SpikingNeuron and follow the same interface: (x, state) → (spk, new_state).
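Each call advances a neuron by exactly one timestep, so the caller owns the time loop and threads state from one call to the next. The plain-Python sketch below illustrates that contract. ToyIF and run are hypothetical names, not part of mlxsnn. Nested lists stand in for MLX arrays so the sketch runs without the library.

```python
class ToyIF:
    """Hypothetical toy neuron following the (x, state) -> (spk, new_state) contract."""

    def __init__(self, threshold=1.0):
        self.threshold = threshold

    def init_state(self, batch_size, features):
        # One membrane value per (sample, feature), stored as nested lists.
        return {"mem": [[0.0] * features for _ in range(batch_size)]}

    def __call__(self, x, state):
        new_mem, spk = [], []
        for x_row, m_row in zip(x, state["mem"]):
            mem_row = [m + xi for m, xi in zip(m_row, x_row)]  # integrate, no leak
            spk_row = [1.0 if m >= self.threshold else 0.0 for m in mem_row]
            # Subtract reset: remove the threshold wherever a spike occurred.
            mem_row = [m - s * self.threshold for m, s in zip(mem_row, spk_row)]
            new_mem.append(mem_row)
            spk.append(spk_row)
        return spk, {"mem": new_mem}


def run(neuron, inputs):
    """Unroll a neuron over a list of per-timestep inputs, collecting spikes."""
    batch_size, features = len(inputs[0]), len(inputs[0][0])
    state = neuron.init_state(batch_size, features)
    spikes = []
    for x in inputs:  # the caller owns the loop and passes state forward
        spk, state = neuron(x, state)
        spikes.append(spk)
    return spikes, state


inputs = [[[0.5, 0.3]]] * 4  # T=4 timesteps, batch=1, features=2
spikes, state = run(ToyIF(), inputs)
```

The same loop works for any neuron on this page, because only the contents of the state dictionary differ between models.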
Base Class
SpikingNeuron
Bases: Module
Abstract base class for spiking neuron models.
Args:
threshold: Membrane potential threshold for spike generation.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: How to reset the membrane after a spike: 'subtract' subtracts the threshold from the membrane potential; 'zero' resets the membrane to zero; 'none' applies no reset (useful for output layers).
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for the surrogate gradient.
Examples:
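Subclasses must implement init_state and __call__:

    import mlx.core as mx
    from mlxsnn.neurons import SpikingNeuron  # import path assumed

    class MyNeuron(SpikingNeuron):
        def init_state(self, batch_size, features):
            # One membrane value per sample and feature.
            return {"mem": mx.zeros((batch_size, features))}

        def __call__(self, x, state):
            mem = state["mem"] + x      # integrate input (no leak)
            spk = self.fire(mem)        # Heaviside forward, surrogate backward
            mem = self.reset(mem, spk)  # apply the configured reset_mechanism
            return spk, {"mem": mem}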
__init__
__init__(threshold: float = 1.0, learn_threshold: bool = False, reset_mechanism: str = 'subtract', surrogate_fn: str = 'arctan', surrogate_scale: float = 2.0)
init_state
Initialize neuron hidden state.
Args:
batch_size: Number of samples in the batch.
*args: Additional shape dimensions (e.g., feature size).
Returns: A dictionary of state tensors initialized to zeros.
__call__
Forward pass: compute spikes from input and previous state.
Args:
x: Input current of shape [batch, features].
state: Dictionary of state tensors from the previous timestep.
Returns:
A tuple of (spikes, new_state), where spikes is a binary array of the same shape as x.
fire
Generate spikes using surrogate gradient.
In the forward pass this applies the Heaviside step function. In the backward pass the surrogate gradient is used.
Args: mem: Membrane potential array.
Returns: Binary spike array (1 where mem >= threshold, 0 otherwise).
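A minimal sketch of this forward/backward split, assuming an snnTorch-style arctan surrogate in which surrogate_scale acts as the slope parameter. The exact formula mlxsnn uses for surrogate_fn='arctan' is not given on this page, so heaviside, smooth_step, and atan_surrogate_grad are illustrative names, not library functions. Plain Python floats stand in for MLX arrays.

```python
import math


def heaviside(u):
    """Forward pass: 1.0 where u >= 0 (mem at or above threshold), else 0.0."""
    return 1.0 if u >= 0.0 else 0.0


def smooth_step(u, scale=2.0):
    """Smooth arctan proxy for the step: (1/pi) * atan(pi/2 * scale * u) + 1/2."""
    return math.atan(math.pi / 2.0 * scale * u) / math.pi + 0.5


def atan_surrogate_grad(u, scale=2.0):
    """Backward pass: derivative of smooth_step, used in place of the true
    Heaviside derivative (which is zero almost everywhere)."""
    return (scale / 2.0) / (1.0 + (math.pi / 2.0 * scale * u) ** 2)


threshold = 1.0
mem = 0.8
u = mem - threshold            # distance to threshold drives both passes
spk = heaviside(u)             # no spike in the forward pass
grad = atan_surrogate_grad(u)  # still nonzero, so a learning signal flows back
```

Under this formulation the surrogate peaks at u = 0 with height surrogate_scale / 2, so a larger scale gives a taller, narrower gradient around the threshold.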
reset
Apply reset mechanism after spike generation.
Args:
mem: Membrane potential before reset.
spk: Binary spike array from fire().
Returns: Membrane potential after reset.
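The three reset_mechanism options can be sketched for a single scalar membrane value as follows. reset_sketch is a hypothetical plain-Python function, not the method's actual implementation; in particular, expressing 'zero' as multiplication by (1 - spk) is an assumption.

```python
def reset_sketch(mem, spk, threshold=1.0, mechanism="subtract"):
    """Scalar sketch of the three documented reset mechanisms."""
    if mechanism == "subtract":
        return mem - spk * threshold  # keeps any overshoot above threshold
    if mechanism == "zero":
        return mem * (1.0 - spk)      # clamps to 0 wherever a spike occurred
    if mechanism == "none":
        return mem                    # membrane left untouched
    raise ValueError(f"unknown reset mechanism: {mechanism!r}")


mem, spk = 1.25, 1.0
print(reset_sketch(mem, spk, mechanism="subtract"))  # 0.25
print(reset_sketch(mem, spk, mechanism="zero"))      # 0.0
print(reset_sketch(mem, spk, mechanism="none"))      # 1.25
```

Subtract reset preserves the amount by which the membrane overshot the threshold, whereas zero reset discards it.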
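Leaky Integrate-and-Fire (LIF)
The most commonly used spiking neuron. Membrane dynamics:
\(U[t] = \beta\,U[t-1] + X[t]\), and a spike \(S[t] = \Theta(U[t] - U_{\text{thr}})\) is emitted when the membrane potential reaches the threshold \(U_{\text{thr}}\); the membrane is then reset according to reset_mechanism.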
Leaky
Bases: SpikingNeuron
Leaky Integrate-and-Fire neuron.
Args:
beta: Membrane potential decay rate. Values closer to 1 give longer memory; values closer to 0 give faster decay.
learn_beta: If True, beta becomes a learnable parameter.
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import Leaky
>>> lif = Leaky(beta=0.9)
>>> state = lif.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = lif(x, state)
init_state
Initialize LIF neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions. Can be a single int for FC layers
(e.g., init_state(B, 128)) or spatial dims for conv
layers (e.g., init_state(B, 64, 64, 128)).
Returns: State dict with 'mem' initialized to zeros.
__call__
Forward one timestep.
Args:
x: Input current [batch, features].
state: Dict with 'mem' from previous timestep.
Returns: Tuple of (spikes, new_state).
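A plain-Python scalar sketch of the Leaky neuron, assuming the membrane update mem = beta * mem + x followed by the fire and subtract-reset steps from the SpikingNeuron example. simulate_lif is a hypothetical helper, not an mlxsnn function. With beta = 0.9, a constant input of 0.3 produces a spike every fourth step, while an input of 0.05 never fires because the membrane settles at 0.05 / (1 - 0.9) = 0.5, below the threshold.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Scalar LIF with 'subtract' reset; returns spikes and post-reset membrane."""
    mem, spikes, trace = 0.0, [], []
    for x in inputs:
        mem = beta * mem + x                    # leaky integration
        spk = 1.0 if mem >= threshold else 0.0  # Heaviside on mem - threshold
        mem -= spk * threshold                  # 'subtract' reset
        spikes.append(spk)
        trace.append(mem)
    return spikes, trace


spikes, trace = simulate_lif([0.3] * 8)  # fires every fourth step
silent, _ = simulate_lif([0.05] * 200)   # settles at 0.5, never fires
```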
Integrate-and-Fire (IF)
Non-leaky variant of LIF: the membrane potential does not decay (\(\beta = 1\)).
IF
Bases: SpikingNeuron
Integrate-and-Fire neuron (no leak).
This is the simplest spiking neuron model. The membrane potential integrates input current without decay and fires when threshold is reached.
Args:
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import IF
>>> neuron = IF(threshold=1.0)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128)) * 0.5
>>> spk, state = neuron(x, state)  # No spike yet (mem=0.5)
>>> spk, state = neuron(x, state)  # Spike! (mem=1.0 >= threshold)
init_state
Initialize IF neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'mem' initialized to zeros.
__call__
Forward one timestep.
Args:
x: Input current [batch, features].
state: Dict with 'mem' from previous timestep.
Returns: Tuple of (spikes, new_state).
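The IF doctest above can be reproduced with a plain-Python scalar sketch, assuming subtract reset. simulate_if is a hypothetical helper. Because nothing leaks and subtract reset carries any overshoot into the next step, the long-run firing rate for a constant input below threshold approaches input / threshold.

```python
def simulate_if(inputs, threshold=1.0):
    """Scalar IF neuron (beta = 1) with 'subtract' reset; returns spikes."""
    mem, spikes = 0.0, []
    for x in inputs:
        mem += x                                # no leak
        spk = 1.0 if mem >= threshold else 0.0
        mem -= spk * threshold                  # overshoot carries over
        spikes.append(spk)
    return spikes


print(simulate_if([0.5] * 4))          # [0.0, 1.0, 0.0, 1.0]
print(sum(simulate_if([0.25] * 100)))  # 25.0
```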
Izhikevich
2D dynamical system with biologically realistic spiking patterns.
Supports presets: Regular Spiking (RS), Intrinsically Bursting (IB), Chattering (CH), Fast Spiking (FS).
Izhikevich
Bases: SpikingNeuron
Izhikevich spiking neuron with two-dimensional dynamics.
This model uses two coupled variables (v, u) to reproduce a rich repertoire of neural spiking behaviours while remaining computationally efficient.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| a | float | Time scale of the recovery variable u. Smaller values result in slower recovery. | 0.02 |
| b | float | Sensitivity of the recovery variable u to sub-threshold fluctuations of the membrane potential v. | 0.2 |
| c | float | After-spike reset value of v (in mV). | -65.0 |
| d | float | After-spike increment applied to u. | 8.0 |
| dt | float | Integration timestep. The default of 0.5 ms provides a good balance between accuracy and speed; use smaller values for higher fidelity. | 0.5 |
| preset | str or None | If given, must be a key in | None |
| surrogate_fn | str | Surrogate gradient function name. | 'arctan' |
| surrogate_scale | float | Scale parameter for the surrogate gradient. | 0.1 |
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons.izhikevich import Izhikevich
>>> neuron = Izhikevich(preset="RS")
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128)) * 10.0
>>> spk, state = neuron(x, state)
init_state
Initialize Izhikevich neuron state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of samples in the batch. | required |
| *shape | | Feature dimensions (single int or spatial dims). | () |

Returns:
| Type | Description |
|---|---|
| dict | State dictionary with keys |
__call__
Forward one timestep of the Izhikevich model.
The membrane potential is integrated using a forward-Euler step with step size dt. When v >= 30, a spike is emitted, v is reset to c, and u is incremented by d.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | array | Input current of shape | required |
| state | dict | Dictionary with | required |

Returns:
| Type | Description |
|---|---|
| tuple[array, dict] | A tuple |
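The forward-Euler update described for __call__ can be sketched in plain Python. The right-hand sides used below (0.04 v^2 + 5 v + 140 - u + I for v, and a (b v - u) for u) are the standard Izhikevich (2003) equations and are an assumption here, as are the initial condition v = c, u = b * c and the helper names izhikevich_step and simulate_izhikevich. The default a, b, c, d in the table above match that model's Regular Spiking (RS) parameters.

```python
def izhikevich_step(v, u, i_in, a=0.02, b=0.2, c=-65.0, d=8.0, dt=0.5):
    """One forward-Euler step of the standard Izhikevich (2003) equations."""
    dv = 0.04 * v * v + 5.0 * v + 140.0 - u + i_in  # membrane potential (mV)
    du = a * (b * v - u)                             # recovery variable
    v = v + dt * dv
    u = u + dt * du
    if v >= 30.0:             # spike peak reached
        return c, u + d, 1.0  # reset v to c and increment u by d
    return v, u, 0.0


def simulate_izhikevich(i_in, steps, a=0.02, b=0.2, c=-65.0, d=8.0, dt=0.5):
    """Drive one neuron with a constant current; return the spike train."""
    v, u = c, b * c  # assumed initial condition (not stated on this page)
    spikes = []
    for _ in range(steps):
        v, u, spk = izhikevich_step(v, u, i_in, a, b, c, d, dt)
        spikes.append(spk)
    return spikes


regular = simulate_izhikevich(10.0, 1000)  # 500 ms of constant input, RS defaults
silent = simulate_izhikevich(0.0, 1000)    # no input: settles at rest, never fires
```

Forward Euler is only first-order accurate, so exact spike times shift with dt; this is why smaller dt values give higher fidelity.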
Adaptive LIF (ALIF)
LIF with adaptive threshold that increases after each spike.
ALIF
Bases: SpikingNeuron
Adaptive Leaky Integrate-and-Fire neuron.
Extends the standard LIF neuron with an adaptive threshold that increases after each spike and decays exponentially between spikes. This mechanism implements spike-frequency adaptation.
Args:
beta: Membrane potential decay rate. Values closer to 1 give longer memory; values closer to 0 give faster decay.
rho: Adaptation variable decay rate. Controls how quickly the threshold relaxes back to its base value after a spike.
b: Adaptation strength. Scales the contribution of the adaptation variable to the effective threshold.
learn_beta: If True, beta becomes a learnable parameter.
learn_rho: If True, rho becomes a learnable parameter.
threshold: Base spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import ALIF
>>> neuron = ALIF(beta=0.9, rho=0.95, b=0.1)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = neuron(x, state)
init_state
Initialize ALIF neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'mem' and 'adapt' initialized to zeros.
__call__
Forward one timestep.
Args:
x: Input current [batch, features].
state: Dict with 'mem' and 'adapt' from previous timestep.
Returns: Tuple of (spikes, new_state).
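A plain-Python scalar sketch of spike-frequency adaptation, assuming a common ALIF formulation: the effective threshold is threshold + b * adapt, the adaptation variable is updated as adapt = rho * adapt + spk after each step, and 'subtract' reset removes the effective threshold. mlxsnn's exact update order may differ, and simulate_alif is a hypothetical helper.

```python
def simulate_alif(inputs, beta=0.9, rho=0.95, b=0.1, threshold=1.0):
    """Scalar ALIF with 'subtract' reset; returns the indices of spiking steps."""
    mem, adapt, spike_times = 0.0, 0.0, []
    for t, x in enumerate(inputs):
        eff_threshold = threshold + b * adapt  # raised by recent spikes
        mem = beta * mem + x                   # leaky integration
        spk = 1.0 if mem >= eff_threshold else 0.0
        mem -= spk * eff_threshold             # 'subtract' reset
        adapt = rho * adapt + spk              # jumps on a spike, decays otherwise
        if spk:
            spike_times.append(t)
    return spike_times


plain = simulate_alif([0.3] * 10, b=0.0)     # no adaptation: plain LIF
adaptive = simulate_alif([0.3] * 10, b=0.5)  # strong adaptation for visibility
```

With b = 0 the sketch reduces to plain LIF and fires at steps 3 and 7; with b = 0.5 the raised threshold after the first spike delays the second spike to step 9.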
Synaptic
Conductance-based LIF with dual-state dynamics (synaptic current + membrane potential).
Synaptic
Bases: SpikingNeuron
Synaptic (conductance-based) LIF neuron.
Extends the standard LIF model with an explicit synaptic current state variable. Input is first integrated into the synaptic current, which then drives the membrane potential. This two-state model captures more realistic post-synaptic dynamics.
Args:
alpha: Synaptic current decay rate. Values closer to 1 give slower synaptic dynamics; values closer to 0 give faster decay.
beta: Membrane potential decay rate. Values closer to 1 give longer memory; values closer to 0 give faster decay.
learn_alpha: If True, alpha becomes a learnable parameter.
learn_beta: If True, beta becomes a learnable parameter.
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import Synaptic
>>> neuron = Synaptic(alpha=0.8, beta=0.9)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = neuron(x, state)
init_state
Initialize Synaptic neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'syn' and 'mem' initialized to zeros.
__call__
Forward one timestep.
Args:
x: Input current [batch, features].
state: Dict with 'syn' and 'mem' from previous timestep.
Returns: Tuple of (spikes, new_state).
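A plain-Python scalar sketch of the two-state dynamics, assuming the synaptic current is updated first (syn = alpha * syn + x) and then drives the membrane (mem = beta * mem + syn). simulate_synaptic is a hypothetical helper. With alpha = 0.8 and beta = 0.9, a single input impulse produces a membrane response that keeps rising for several steps after the input ends and peaks at step 5, whereas a plain LIF membrane would start decaying immediately. The threshold is set high here only to keep the neuron from spiking during the demonstration.

```python
def simulate_synaptic(inputs, alpha=0.8, beta=0.9, threshold=1.0):
    """Scalar two-state LIF: input -> synaptic current -> membrane potential."""
    syn, mem, trace = 0.0, 0.0, []
    for x in inputs:
        syn = alpha * syn + x                   # synaptic current decays with alpha
        mem = beta * mem + syn                  # membrane integrates it, leaks with beta
        spk = 1.0 if mem >= threshold else 0.0
        mem -= spk * threshold                  # 'subtract' reset
        trace.append(mem)
    return trace


impulse = [1.0] + [0.0] * 19
trace = simulate_synaptic(impulse, threshold=100.0)  # threshold never reached
peak_step = max(range(len(trace)), key=trace.__getitem__)
```

For a unit impulse the membrane follows (beta**(n + 1) - alpha**(n + 1)) / (beta - alpha), a difference of two decaying exponentials.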
Alpha
Dual-exponential synaptic model with alpha-function-shaped postsynaptic potentials (PSPs).
Alpha
Bases: SpikingNeuron
Alpha (dual-exponential) spiking neuron.
Models a dual-exponential post-synaptic current by cascading two first-order synaptic filters (excitatory and inhibitory) before driving the membrane potential. This produces a current with a finite rise time followed by exponential decay.
Args:
alpha: Synaptic current decay rate for both the excitatory and inhibitory filters. Values closer to 1 give slower synaptic dynamics; values closer to 0 give faster decay.
beta: Membrane potential decay rate. Values closer to 1 give longer memory; values closer to 0 give faster decay.
learn_alpha: If True, alpha becomes a learnable parameter.
learn_beta: If True, beta becomes a learnable parameter.
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import Alpha
>>> neuron = Alpha(alpha=0.85, beta=0.9)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = neuron(x, state)
init_state
Initialize Alpha neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'syn_exc', 'syn_inh', and 'mem' initialized to zeros.
__call__
Forward one timestep.
Args:
x: Input current [batch, features].
state: Dict with 'syn_exc', 'syn_inh', and 'mem' from previous timestep.
Returns: Tuple of (spikes, new_state).
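This page does not spell out how the two filters are wired. Reading "cascading" literally, and assuming both filters decay with alpha, the second filter is driven by the output of the first. Its impulse response is then (n + 1) * alpha**n, a discrete alpha function with a finite rise time. The plain-Python sketch below checks that. The variable names mirror the 'syn_exc' and 'syn_inh' state keys, but the wiring is an assumption, alpha_kernel is a hypothetical helper, and the membrane stage is omitted.

```python
def alpha_kernel(steps, alpha=0.85):
    """Impulse response of two identical first-order filters in cascade."""
    syn_exc, syn_inh, out = 0.0, 0.0, []
    for t in range(steps):
        x = 1.0 if t == 0 else 0.0           # unit impulse at t = 0
        syn_exc = alpha * syn_exc + x        # first filter
        syn_inh = alpha * syn_inh + syn_exc  # second filter, driven by the first
        out.append(syn_inh)
    return out


kernel = alpha_kernel(30)
peak_step = max(range(len(kernel)), key=kernel.__getitem__)
```

With alpha = 0.85 the current rises for five steps before decaying, unlike the single-filter Synaptic current, which peaks immediately.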
Recurrent LIF (RLeaky)
LIF with learnable recurrent feedback weight.
RLeaky
Bases: SpikingNeuron
Recurrent Leaky Integrate-and-Fire neuron.
Output spikes are fed back as additional input via a learnable recurrent weight V, giving the neuron an explicit recurrent connection in addition to the implicit memory from membrane decay.
Uses snnTorch-compatible delayed reset: the reset signal is computed from the previous membrane potential and detached from the computation graph.
Args:
beta: Membrane potential decay rate.
V: Recurrent weight. Scales the feedback spike signal.
learn_beta: If True, beta becomes a learnable parameter.
learn_V: If True, V becomes a learnable parameter.
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import RLeaky
>>> neuron = RLeaky(beta=0.9, learn_V=True)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = neuron(x, state)
init_state
Initialize RLeaky neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'mem' and 'spk' initialized to zeros.
__call__
Forward one timestep.
Uses delayed reset matching snnTorch: reset is computed from the previous membrane potential (detached from graph) before the membrane update.
Args:
x: Input current [batch, features].
state: Dict with 'mem' and 'spk' from previous timestep.
Returns: Tuple of (spikes, new_state).
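A plain-Python scalar sketch of one reading of this update, modelled on snnTorch's RLeaky: the reset term comes from the membrane potential stored at the previous step, and V scales the previous step's output spike. simulate_rleaky is a hypothetical helper, and mlxsnn's exact arithmetic may differ. Because the reset is applied one update late, the stored membrane can sit above the threshold for one step (1.14 in this run).

```python
def simulate_rleaky(inputs, beta=0.9, V=0.5, threshold=1.0):
    """Scalar RLeaky with snnTorch-style delayed 'subtract' reset."""
    mem, spk, spikes, trace = 0.0, 0.0, [], []
    for x in inputs:
        # Reset is decided from the previous membrane potential; in an
        # autodiff framework this term would be detached (stop-gradient).
        reset = 1.0 if mem >= threshold else 0.0
        # V * spk feeds the previous output spike back as extra input.
        mem = beta * mem + x + V * spk - reset * threshold
        spk = 1.0 if mem >= threshold else 0.0
        spikes.append(spk)
        trace.append(mem)
    return spikes, trace


spikes, trace = simulate_rleaky([0.6, 0.6, 0.0, 0.0])
```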
Recurrent Synaptic (RSynaptic)
Synaptic neuron with learnable recurrent feedback weight.
RSynaptic
Bases: SpikingNeuron
Recurrent Synaptic (conductance-based) LIF neuron.
Extends the Synaptic neuron model with recurrent spike feedback. The output spike is fed back into the synaptic current via a learnable weight V, providing explicit recurrence on top of the two-state (synaptic current + membrane) dynamics.
Uses snnTorch-compatible delayed reset: the reset signal is computed from the previous membrane potential and detached from the computation graph.
Args:
alpha: Synaptic current decay rate.
beta: Membrane potential decay rate.
V: Recurrent weight. Scales the feedback spike signal.
learn_alpha: If True, alpha becomes a learnable parameter.
learn_beta: If True, beta becomes a learnable parameter.
learn_V: If True, V becomes a learnable parameter.
threshold: Spike threshold voltage.
learn_threshold: If True, threshold becomes a learnable parameter.
reset_mechanism: Reset method after spike ('subtract', 'zero', 'none').
surrogate_fn: Surrogate gradient function name or callable.
surrogate_scale: Scale parameter for surrogate gradient.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.neurons import RSynaptic
>>> neuron = RSynaptic(alpha=0.8, beta=0.9, learn_V=True)
>>> state = neuron.init_state(4, 128)
>>> x = mx.ones((4, 128))
>>> spk, state = neuron(x, state)
init_state
Initialize RSynaptic neuron state.
Args:
batch_size: Number of samples in the batch.
*shape: Feature dimensions (single int or spatial dims).
Returns: State dict with 'syn', 'mem', and 'spk' initialized to zeros.
__call__
Forward one timestep.
Uses delayed reset matching snnTorch: reset is computed from the previous membrane potential (detached from graph) before the membrane update.
Args:
x: Input current [batch, features].
state: Dict with 'syn', 'mem', and 'spk' from previous timestep.
Returns: Tuple of (spikes, new_state).
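A plain-Python scalar sketch of the recurrent two-state update, modelled on snnTorch's RSynaptic: the previous output spike, scaled by V, is added to the synaptic current, and the 'subtract' reset is computed from the previous membrane potential. simulate_rsynaptic is a hypothetical helper, and mlxsnn's exact arithmetic may differ. Comparing V = 0.5 with V = 0 on the same input shows that, after the first spike, the feedback raises the synaptic current by exactly V.

```python
def simulate_rsynaptic(inputs, alpha=0.8, beta=0.9, V=0.5, threshold=1.0):
    """Scalar RSynaptic with snnTorch-style delayed 'subtract' reset."""
    syn, mem, spk = 0.0, 0.0, 0.0
    syn_trace, spikes = [], []
    for x in inputs:
        reset = 1.0 if mem >= threshold else 0.0  # from previous mem (detached)
        syn = alpha * syn + x + V * spk           # feedback enters the synaptic current
        mem = beta * mem + syn - reset * threshold
        spk = 1.0 if mem >= threshold else 0.0
        syn_trace.append(syn)
        spikes.append(spk)
    return spikes, syn_trace


pulse = [1.5, 0.0, 0.0]
spk_fb, syn_fb = simulate_rsynaptic(pulse, V=0.5)
spk_no, syn_no = simulate_rsynaptic(pulse, V=0.0)  # reduces to Synaptic with delayed reset
```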