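Encoding
Spike encoding methods convert continuous-valued data into spike trains for processing by a spiking neural network (SNN). Most methods emit binary spikes; delta encoding can also emit -1 (OFF) spikes, and direct encoding repeats the analog input unchanged.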
Rate Encoding
Poisson rate coding: higher values produce a higher spike probability at each timestep.
rate_encode
rate_encode(data: array, num_steps: int, gain: float = 1.0, offset: float = 0.0, key: array | None = None) -> mx.array
Encode continuous data into Poisson spike trains.
Each input value is scaled by gain and shifted by offset, and the result is interpreted as a firing probability. At each timestep, a spike is generated with that probability.
Args:
data: Input data with values in [0, 1], shape [batch, ...].
num_steps: Number of time steps to generate.
gain: Multiplicative scaling applied to data before encoding.
offset: Additive offset applied after gain.
key: MLX random key. If None, uses default random state.
Returns:
Spike trains of shape [num_steps, batch, ...] (time-first).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding import rate_encode
>>> data = mx.array([[0.8, 0.2], [0.5, 0.5]])
>>> spikes = rate_encode(data, num_steps=100)
>>> spikes.shape
(100, 2, 2)
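The sampling step can be illustrated in NumPy so it runs without MLX. This is a minimal sketch of per-timestep Bernoulli sampling under stated assumptions, not mlxsnn's implementation: the name `rate_encode_sketch`, the `seed` argument (standing in for the MLX `key`), and the clipping of `gain * data + offset` to [0, 1] are all assumptions.

```python
import numpy as np


def rate_encode_sketch(data, num_steps, gain=1.0, offset=0.0, seed=None):
    """Hypothetical rate-coding sketch with time-first output [num_steps, ...]."""
    rng = np.random.default_rng(seed)  # stands in for an MLX random key
    # Scale, shift, and clip so every value is a valid probability
    # (the clipping is an assumption of this sketch).
    p = np.clip(gain * np.asarray(data, dtype=np.float64) + offset, 0.0, 1.0)
    # One uniform draw per (timestep, element); spike wherever the draw is below p.
    u = rng.random((num_steps,) + p.shape)
    return (u < p).astype(np.float32)
```

Averaging the output over the time axis gives an estimate of each input's firing probability; the estimate tightens as num_steps grows, which is the usual trade-off between window length and coding precision.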
Latency Encoding
Time-to-first-spike encoding: higher values spike earlier.
latency_encode
latency_encode(data: array, num_steps: int, tau: float = 5.0, normalize: bool = True, linear: bool = False) -> mx.array
Encode continuous data using time-to-first-spike latency coding.
Higher input values produce earlier spikes. Each neuron fires exactly once across the time window.
Args:
data: Input data with values in [0, 1], shape [batch, ...].
num_steps: Number of time steps in the encoding window.
tau: Time constant controlling the mapping from value to spike
time (only used when linear=False).
normalize: If True, normalize data to [0, 1] range.
linear: If True, use linear mapping instead of exponential.
Returns:
Spike trains of shape [num_steps, batch, ...] (time-first).
Each spatial position has exactly one spike.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding import latency_encode
>>> data = mx.array([[0.9, 0.1], [0.5, 0.5]])
>>> spikes = latency_encode(data, num_steps=10)
>>> spikes.shape
(10, 2, 2)
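The docstring does not spell out the value-to-time mapping, so the NumPy sketch below assumes two plausible ones: a linear map that sends 1.0 to step 0 and 0.0 to the last step, and an exponential-style map t = tau * ln(1 / x). Both formulas and the name `latency_encode_sketch` are hypothetical; what the sketch does show faithfully is the documented layout of exactly one spike per input position.

```python
import numpy as np


def latency_encode_sketch(data, num_steps, tau=5.0, linear=False, eps=1e-7):
    """Hypothetical time-to-first-spike sketch; output is [num_steps, ...]."""
    x = np.clip(np.asarray(data, dtype=np.float64), 0.0, 1.0)
    if linear:
        # Assumed linear map: x = 1 fires at step 0, x = 0 at the last step.
        t = (1.0 - x) * (num_steps - 1)
    else:
        # Assumed exponential map: larger x gives a smaller spike time.
        t = tau * np.log(1.0 / np.maximum(x, eps))
    # Round to a discrete step and clip late spikes into the window.
    steps = np.clip(np.round(t), 0, num_steps - 1).astype(np.int64)
    spikes = np.zeros((num_steps,) + x.shape, dtype=np.float32)
    # Write a single 1 along the time axis at each position's spike step.
    np.put_along_axis(spikes, steps[None, ...], 1.0, axis=0)
    return spikes
```

With tau=5.0 and num_steps=10, an input of 0.9 fires at step 1, 0.5 at step 3, and 0.1 is clipped to the final step 9.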
Delta Encoding
Change detection: spikes are emitted when the input changes by more than a threshold between consecutive timesteps.
delta_encode
delta_encode(data: array, threshold: float = 0.1, off_spike: bool = True, padding: bool = True) -> mx.array
Encode continuous data using delta modulation.
Computes temporal differences between consecutive timesteps and generates spikes where the absolute change exceeds a threshold.
For single-step input (no time dimension), no temporal difference can be computed, so the function returns an all-zero spike tensor with a single timestep.
Args:
data: Input data. Either single-step [batch, ...] or
temporal [time, batch, ...] format.
threshold: Minimum absolute change required to emit a spike.
Larger values produce sparser spike trains.
off_spike: If True, emit -1 spikes when the signal
decreases beyond -threshold. If False, only positive
spikes (+1) are generated.
padding: If True, pad the first timestep with zeros so the
output shape matches the input. If False, the output has
one fewer timestep than the input.
Returns:
Spike array. For temporal input with padding=True, shape
matches data shape [time, batch, ...]. With
padding=False, shape is [time - 1, batch, ...].
For single-step input, returns [1, batch, ...].
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding import delta_encode
>>> signal = mx.array([
...     [[0.0, 0.5]],
...     [[0.2, 0.3]],
...     [[0.5, 0.1]],
... ])  # shape [3, 1, 2]
>>> spikes = delta_encode(signal, threshold=0.15)
>>> spikes.shape
(3, 1, 2)
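For temporal input, the documented behavior can be sketched in NumPy. This is an illustration, not mlxsnn's code: the name `delta_encode_sketch` is hypothetical, the strict `>` comparison follows the wording "exceeds", and the single-step branch is omitted.

```python
import numpy as np


def delta_encode_sketch(data, threshold=0.1, off_spike=True, padding=True):
    """Hypothetical delta-modulation sketch for time-first input [T, batch, ...]."""
    x = np.asarray(data, dtype=np.float32)
    diff = x[1:] - x[:-1]  # change between consecutive timesteps
    # ON spikes (+1) where the signal rises by more than the threshold.
    spikes = (diff > threshold).astype(np.float32)
    if off_spike:
        # OFF spikes (-1) where the signal falls by more than the threshold.
        spikes -= (diff < -threshold).astype(np.float32)
    if padding:
        # Prepend an all-zero first step so the output length matches the input.
        spikes = np.concatenate([np.zeros_like(x[:1]), spikes], axis=0)
    return spikes
```

Applied to the example signal above with threshold=0.15, this sketch yields [[0, 0], [1, -1], [1, -1]] for the single batch element: the first channel rises by 0.2 and 0.3, while the second falls by 0.2 twice.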
Direct Encoding
Repeats static data across timesteps without stochastic conversion.
direct_encode
Repeat static data across multiple timesteps (direct encoding).
Each timestep receives an identical copy of the input data. This is the simplest encoding and is useful when the network itself should learn temporal dynamics from a constant input.
Args:
data: Input data of shape [batch, ...], for example [batch, C, H, W] for images.
num_steps: Number of timesteps T to repeat.
Returns:
Tensor of shape [T, batch, ...] (time-first).
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding import direct_encode
>>> data = mx.ones((4, 10))
>>> out = direct_encode(data, num_steps=25)
>>> out.shape
(25, 4, 10)
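Functionally, direct encoding adds a leading time axis and copies the input into every step. A NumPy sketch of that reading (the name `direct_encode_sketch` is hypothetical):

```python
import numpy as np


def direct_encode_sketch(data, num_steps):
    """Hypothetical direct-encoding sketch: [batch, ...] -> [num_steps, batch, ...]."""
    x = np.asarray(data)
    # Insert a time axis of length 1, then copy it num_steps times along that axis.
    return np.repeat(x[None, ...], num_steps, axis=0)
```

Because every step is identical, any temporal structure in the network's response must come from its own state (membrane potentials, recurrence) rather than from the input.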
Repeat Encoding
Tiles an existing spike pattern num_repeats times along the time axis.
repeat_encode
Tile a spike pattern multiple times along the time axis.
Concatenates num_repeats copies of the input along axis 0,
useful for extending short spike sequences to fill a longer
simulation window.
Args:
spikes: Spike tensor of shape [T, batch, ...] (time-first).
num_repeats: Number of times to repeat the pattern.
Returns:
Tensor of shape [T * num_repeats, batch, ...].
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding import repeat_encode
>>> spikes = mx.ones((5, 4, 10))
>>> out = repeat_encode(spikes, num_repeats=3)
>>> out.shape
(15, 4, 10)
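Tiling a whole pattern differs from repeating each step: concatenating copies turns a pattern A, B, C into A, B, C, A, B, C rather than A, A, B, B, C, C. The NumPy sketch below (hypothetical name `repeat_encode_sketch`) contrasts the two on a three-step pattern.

```python
import numpy as np


def repeat_encode_sketch(spikes, num_repeats):
    """Hypothetical sketch: concatenate num_repeats copies along the time axis."""
    s = np.asarray(spikes)
    return np.concatenate([s] * num_repeats, axis=0)


# A 3-step pattern for one batch element and one neuron: spike, silent, silent.
pattern = np.array([1.0, 0.0, 0.0]).reshape(3, 1, 1)
print(repeat_encode_sketch(pattern, 2)[:, 0, 0].tolist())  # [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print(np.repeat(pattern, 2, axis=0)[:, 0, 0].tolist())     # [1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```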
EEG Encoder
Specialized encoder for electroencephalography (EEG) signals.
EEGEncoder
Encode EEG signals into spike trains.
Supports multiple encoding strategies commonly used for neural signal processing in spiking neural networks:
"rate": Amplitude-to-firing-rate mapping via Poisson sampling.
"delta": Temporal difference coding; spikes when consecutive samples differ by more than a threshold.
"threshold_crossing": Multi-level threshold crossing using positive and negative thresholds.
Attributes:
method: Encoding strategy name.
num_steps: Number of output time steps (all methods produce this many).
threshold: Threshold for the delta and threshold-crossing methods.
Examples:
>>> import mlx.core as mx
>>> from mlxsnn.encoding.medical.eeg import EEGEncoder
>>> encoder = EEGEncoder(method="rate", num_steps=50)
>>> signal = mx.random.normal(shape=(4, 64, 256))
>>> spikes = encoder(signal)
>>> spikes.shape
(50, 4, 64)
__init__
Initialize the EEG encoder.
Args:
method: Encoding strategy. One of "rate", "delta",
or "threshold_crossing".
num_steps: Number of output time steps. Used directly by
rate coding; delta and threshold-crossing methods
resample the input to this many steps.
threshold: Spike threshold. For delta coding, a spike is
emitted when the absolute temporal difference exceeds
this value. For threshold crossing, this sets the
magnitude of the positive and negative crossing levels.
Raises:
ValueError: If method is not one of the supported
strategies.
__call__
Encode a continuous EEG signal into a spike train.
Args:
signal: Raw EEG signal with shape [channels, timepoints]
or [batch, channels, timepoints].
Returns:
Spike array in time-first format
[num_steps, batch, channels].
Raises:
ValueError: If signal does not have 2 or 3 dimensions.
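The docstrings fix the shape contract ([channels, timepoints] or [batch, channels, timepoints] in, [num_steps, batch, channels] out) but not the internals. The class below is a hypothetical NumPy sketch of the "threshold_crossing" path only. The linear-interpolation resampling, the treatment of 2-D input as batch size 1, and the crossing rule (+1 when the signal rises above +threshold, -1 when it falls below -threshold) are all assumptions, as is the class name.

```python
import numpy as np


class EEGThresholdCrossingSketch:
    """Hypothetical threshold-crossing encoder; not mlxsnn's implementation."""

    def __init__(self, num_steps=50, threshold=1.0):
        self.num_steps = num_steps
        self.threshold = threshold

    def __call__(self, signal):
        x = np.asarray(signal, dtype=np.float64)
        if x.ndim == 2:
            x = x[None, ...]  # assumption: [channels, T] -> [1, channels, T]
        elif x.ndim != 3:
            raise ValueError(f"expected 2 or 3 dimensions, got {x.ndim}")
        batch, channels, length = x.shape
        # Assumption: resample each channel to num_steps by linear interpolation.
        src = np.linspace(0.0, 1.0, length)
        dst = np.linspace(0.0, 1.0, self.num_steps)
        flat = x.reshape(-1, length)
        res = np.stack([np.interp(dst, src, row) for row in flat])
        res = res.reshape(batch, channels, self.num_steps)
        # Move time to the front: [num_steps, batch, channels].
        res = np.moveaxis(res, -1, 0)
        thr = self.threshold
        prev, cur = res[:-1], res[1:]
        up = (prev <= thr) & (cur > thr)      # upward crossing of +threshold
        down = (prev >= -thr) & (cur < -thr)  # downward crossing of -threshold
        spikes = np.zeros_like(res, dtype=np.float32)
        spikes[1:] = up.astype(np.float32) - down.astype(np.float32)
        return spikes
```

A signal that ramps steadily from -2 to 2 produces a single +1 spike where it passes +threshold and no OFF spikes, since it never falls through -threshold.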