
Encoding

Spike encoding methods convert continuous-valued data into spike trains for SNN processing. All encoders on this page return time-first tensors, with time on axis 0.

Rate Encoding

Poisson rate coding — higher values produce higher spike probability.

rate_encode

rate_encode(data: array, num_steps: int, gain: float = 1.0, offset: float = 0.0, key: array | None = None) -> mx.array

Encode continuous data into Poisson spike trains.

Each input value, after applying gain and offset, is interpreted as a per-timestep firing probability: at every timestep, a spike is drawn independently with that probability.

Args:
    data: Input data with values in [0, 1], shape [batch, ...].
    num_steps: Number of time steps to generate.
    gain: Multiplicative scaling applied to data before encoding.
    offset: Additive offset applied after gain.
    key: MLX random key. If None, the default random state is used.

Returns: Spike trains of shape [num_steps, batch, ...] (time-first).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding import rate_encode
    >>> data = mx.array([[0.8, 0.2], [0.5, 0.5]])
    >>> spikes = rate_encode(data, num_steps=100)
    >>> spikes.shape
    (100, 2, 2)
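The per-timestep sampling described above can be illustrated without MLX. The following is a minimal plain-Python sketch, not the mlxsnn implementation: `rate_encode_sketch` is a hypothetical name, it works on a flat list instead of an `mx.array`, it uses a seeded `random.Random` instead of an MLX key, and it clips the scaled values to [0, 1] (an assumption; how `rate_encode` treats out-of-range values is not documented here).

```python
import random


def rate_encode_sketch(data, num_steps, gain=1.0, offset=0.0, seed=None):
    """Hypothetical rate coder for a flat list of values.

    Returns a time-first list of length num_steps; each entry holds one
    0/1 spike per input value.
    """
    rng = random.Random(seed)
    # Assumption: scaled values are clipped to [0, 1] so they are valid probabilities.
    probs = [min(max(x * gain + offset, 0.0), 1.0) for x in data]
    # Independent Bernoulli draw per value and per timestep.
    return [[1 if rng.random() < p else 0 for p in probs] for _ in range(num_steps)]


spikes = rate_encode_sketch([0.8, 0.2, 0.0, 1.0], num_steps=1000, seed=0)
# Empirical firing rate of each input over the whole window.
rates = [sum(step[i] for step in spikes) / len(spikes) for i in range(4)]
print(rates)
```

Averaging each column over time recovers the input value as a firing rate, so a longer window (larger `num_steps`) gives a less noisy estimate of the input.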

Latency Encoding

Time-to-first-spike encoding — higher values spike earlier.

latency_encode

latency_encode(data: array, num_steps: int, tau: float = 5.0, normalize: bool = True, linear: bool = False) -> mx.array

Encode continuous data using time-to-first-spike latency coding.

Higher input values produce earlier spikes. Each neuron fires exactly once across the time window.

Args:
    data: Input data with values in [0, 1], shape [batch, ...].
    num_steps: Number of time steps in the encoding window.
    tau: Time constant controlling the mapping from value to spike time (only used when linear=False).
    normalize: If True, rescale data to the [0, 1] range.
    linear: If True, use a linear mapping instead of an exponential one.

Returns: Spike trains of shape [num_steps, batch, ...] (time-first). Each spatial position has exactly one spike.

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding import latency_encode
    >>> data = mx.array([[0.9, 0.1], [0.5, 0.5]])
    >>> spikes = latency_encode(data, num_steps=10)
    >>> spikes.shape
    (10, 2, 2)
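The docstring does not state the exact value-to-time mapping. As a hedged illustration only, the sketch below assumes t = tau * ln(1 / x) for the exponential case and t = (1 - x) * (num_steps - 1) for the linear case, with min-max normalization, rounding to the nearest step, and clipping to the last step; mlxsnn may use a different formula. `latency_spike_times` and `to_raster` are hypothetical helpers working on flat Python lists.

```python
import math


def latency_spike_times(data, num_steps, tau=5.0, normalize=True, linear=False):
    """Hypothetical time-to-first-spike mapping: one spike time per value."""
    if normalize:
        lo, hi = min(data), max(data)
        # Min-max rescale to [0, 1]; constant inputs are left unchanged.
        if hi > lo:
            data = [(x - lo) / (hi - lo) for x in data]
    last = num_steps - 1
    times = []
    for x in data:
        x = min(max(x, 0.0), 1.0)
        if linear:
            # Assumed linear map: x = 1 -> step 0, x = 0 -> last step.
            t = (1.0 - x) * last
        elif x <= 0.0:
            t = last  # ln(1 / 0) is infinite, so fire at the very end
        else:
            # Assumed exponential map: larger x gives a smaller delay.
            t = tau * math.log(1.0 / x)
        times.append(min(int(round(t)), last))
    return times


def to_raster(times, num_steps):
    # Time-first raster [num_steps][features] with exactly one 1 per feature.
    return [[1 if t == step else 0 for t in times] for step in range(num_steps)]


times = latency_spike_times([0.9, 0.1, 0.5, 0.5], num_steps=10)
raster = to_raster(times, num_steps=10)
print(times)  # [0, 9, 3, 3]
```

Under these assumptions the largest input fires at step 0, the smallest at the last step, and every column of the raster contains exactly one spike.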

Delta Encoding

Change detection — spikes when input changes exceed a threshold.

delta_encode

delta_encode(data: array, threshold: float = 0.1, off_spike: bool = True, padding: bool = True) -> mx.array

Encode continuous data using delta modulation.

Computes temporal differences between consecutive timesteps and generates spikes where the absolute change exceeds a threshold.

For single-step input (no time dimension), no temporal difference can be computed, so the output is a single-timestep tensor containing no spikes.

Args:
    data: Input data, either single-step [batch, ...] or temporal [time, batch, ...].
    threshold: Minimum absolute change required to emit a spike. Larger values produce sparser spike trains.
    off_spike: If True, emit -1 spikes when the signal decreases by more than threshold. If False, only positive spikes (+1) are generated.
    padding: If True, pad the first timestep with zeros so the output shape matches the input. If False, the output has one fewer timestep than the input.

Returns: Spike array with values in {-1, 0, +1} ({0, +1} when off_spike=False). For temporal input with padding=True, the shape matches the input shape [time, batch, ...]; with padding=False, it is [time - 1, batch, ...]. For single-step input, the shape is [1, batch, ...].

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding import delta_encode
    >>> signal = mx.array([
    ...     [[0.0, 0.5]],
    ...     [[0.2, 0.3]],
    ...     [[0.5, 0.1]],
    ... ])  # shape [3, 1, 2]
    >>> spikes = delta_encode(signal, threshold=0.15)
    >>> spikes.shape
    (3, 1, 2)
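The delta rule can be traced by hand with a minimal plain-Python sketch operating on a time-first list of frames (batch and feature axes flattened into one list per frame). It compares each sample with the immediately preceding one, as the docstring describes. `delta_encode_sketch` is a hypothetical helper, not mlxsnn's code; run on the same values as the example above with threshold 0.15, it shows where the spikes land.

```python
def delta_encode_sketch(signal, threshold=0.1, off_spike=True, padding=True):
    """Hypothetical delta modulation over a time-first list of flat frames."""
    spikes = []
    for prev, curr in zip(signal, signal[1:]):
        row = []
        for a, b in zip(prev, curr):
            diff = b - a
            if diff > threshold:
                row.append(1)  # rise larger than threshold
            elif off_spike and diff < -threshold:
                row.append(-1)  # fall larger than threshold
            else:
                row.append(0)
        spikes.append(row)
    if padding:
        # There is no previous frame for t = 0, so it carries no spikes.
        spikes.insert(0, [0] * len(signal[0]))
    return spikes


signal = [[0.0, 0.5], [0.2, 0.3], [0.5, 0.1]]  # [time, features]
print(delta_encode_sketch(signal, threshold=0.15))
# [[0, 0], [1, -1], [1, -1]]
```

The first feature rises by 0.2 and 0.3, producing two ON spikes; the second falls by 0.2 twice, producing two OFF spikes. With off_spike=False those OFF spikes become zeros.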

Direct Encoding

Repeats static data across timesteps without stochastic conversion.

direct_encode

direct_encode(data: array, num_steps: int) -> mx.array

Repeat static data across multiple timesteps (direct encoding).

Each timestep receives an identical copy of the input data. This is the simplest encoding — useful when the network itself should learn temporal dynamics from constant input.

Args:
    data: Input data of shape [batch, ...], for example [batch, C, H, W] for images.
    num_steps: Number of timesteps T to repeat.

Returns: Tensor of shape [T, batch, ...] (time-first).

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding import direct_encode
    >>> data = mx.ones((4, 10))
    >>> out = direct_encode(data, num_steps=25)
    >>> out.shape
    (25, 4, 10)
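Because no stochastic conversion happens, direct encoding amounts to stacking the same input along a new leading time axis. The hypothetical `direct_encode_sketch` below shows this on nested Python lists; it deep-copies each frame so that mutating one timestep of the list-based output cannot affect the others, which is a property of the sketch, not a claim about mlxsnn.

```python
import copy


def direct_encode_sketch(data, num_steps):
    # One independent copy of the static input per timestep: [T, batch, ...].
    return [copy.deepcopy(data) for _ in range(num_steps)]


frame = [[0.5, 1.0, 0.0]]  # [batch=1, features=3]
out = direct_encode_sketch(frame, num_steps=4)
print(len(out), len(out[0]), len(out[0][0]))  # 4 1 3
```

The first layer therefore receives the same analog values at every step; any spiking behaviour has to come from the network's own neurons.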

Repeat Encoding

Tiles an existing spike pattern N times along the temporal dimension.

repeat_encode

repeat_encode(spikes: array, num_repeats: int) -> mx.array

Tile a spike pattern multiple times along the time axis.

Concatenates num_repeats copies of the input along axis 0, which is useful for extending short spike sequences to fill a longer simulation window.

Args:
    spikes: Spike tensor of shape [T, batch, ...] (time-first).
    num_repeats: Number of times to repeat the pattern.

Returns: Tensor of shape [T * num_repeats, batch, ...].

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding import repeat_encode
    >>> spikes = mx.ones((5, 4, 10))
    >>> out = repeat_encode(spikes, num_repeats=3)
    >>> out.shape
    (15, 4, 10)
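The ordering matters: concatenating whole copies along axis 0 means the sequence restarts after T steps, rather than each timestep being repeated in place. A hypothetical plain-Python sketch of that ordering:

```python
def repeat_encode_sketch(spikes, num_repeats):
    # Whole copies back to back along time: [p0, p1, ..., p0, p1, ...].
    return [list(step) for _ in range(num_repeats) for step in spikes]


pattern = [[1, 0], [0, 1], [0, 0]]  # T = 3 steps, 2 features
tiled = repeat_encode_sketch(pattern, num_repeats=2)
print(len(tiled))  # 6
print(tiled[3] == pattern[0])  # True: the pattern restarts at t = T
```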

EEG Encoder

Specialized encoder for EEG biomedical signals.

EEGEncoder

Encode EEG signals into spike trains.

Supports multiple encoding strategies commonly used for neural signal processing in spiking neural networks:

  • "rate": Amplitude-to-firing-rate mapping via Poisson sampling.
  • "delta": Temporal difference coding; spikes when consecutive samples differ by more than a threshold.
  • "threshold_crossing": Multi-level threshold crossing using positive and negative thresholds.

Attributes:
    method: Encoding strategy name.
    num_steps: Number of output time steps. Rate coding uses it directly; delta and threshold-crossing methods resample the input to this length.
    threshold: Threshold for the delta and threshold-crossing methods.

Examples:
    >>> import mlx.core as mx
    >>> from mlxsnn.encoding.medical.eeg import EEGEncoder
    >>> encoder = EEGEncoder(method="rate", num_steps=50)
    >>> signal = mx.random.normal(shape=(4, 64, 256))
    >>> spikes = encoder(signal)
    >>> spikes.shape
    (50, 4, 64)

__init__

__init__(method: str = 'rate', num_steps: int = 100, threshold: float = 0.5) -> None

Initialize the EEG encoder.

Args:
    method: Encoding strategy. One of "rate", "delta", or "threshold_crossing".
    num_steps: Number of output time steps. Used directly by rate coding; delta and threshold-crossing methods resample the input to this many steps.
    threshold: Spike threshold. For delta coding, a spike is emitted when the absolute temporal difference exceeds this value. For threshold crossing, this sets the magnitude of the positive and negative crossing levels.

Raises: ValueError: If method is not one of the supported strategies.

__call__

__call__(signal: array) -> mx.array

Encode a continuous EEG signal into a spike train.

Args: signal: Raw EEG signal with shape [channels, timepoints] or [batch, channels, timepoints].

Returns: Spike array in time-first format [num_steps, batch, channels].

Raises: ValueError: If signal does not have 2 or 3 dimensions.
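The docstrings do not specify how resampling or threshold crossing is implemented. The sketch below is one plausible reading, written in plain Python on nested lists: it resamples each channel to num_steps by taking the sample at index floor(s * n / num_steps), then emits +1 when the signal rises past +threshold and -1 when it falls past -threshold. The function name, the resampling rule, and the two-level (rather than multi-level) crossing scheme are all assumptions. It does follow the documented shape handling: 2-D [channels, timepoints] input gets a batch axis, and the output is [num_steps, batch, channels].

```python
def eeg_threshold_crossing_sketch(signal, num_steps=100, threshold=0.5):
    """Hypothetical threshold-crossing encoder.

    signal: [channels][timepoints] or [batch][channels][timepoints] nested lists.
    Returns spikes as [num_steps][batch][channels] with values in {-1, 0, +1}.
    """
    if not isinstance(signal[0][0], list):  # 2-D input: add a batch axis
        signal = [signal]
    num_channels = len(signal[0])
    out = [[[0] * num_channels for _ in signal] for _ in range(num_steps)]
    for b, trial in enumerate(signal):
        for c, channel in enumerate(trial):
            n = len(channel)
            # Assumed resampling: floor-index (nearest-previous) sample selection.
            x = [channel[s * n // num_steps] for s in range(num_steps)]
            for t in range(1, num_steps):
                if x[t - 1] <= threshold < x[t]:
                    out[t][b][c] = 1  # upward crossing of +threshold
                elif x[t - 1] >= -threshold > x[t]:
                    out[t][b][c] = -1  # downward crossing of -threshold
    return out


channel = [0.0, 0.6, 0.7, 0.2, -0.6, -0.1, 0.8, 0.0]
spikes = eeg_threshold_crossing_sketch([channel], num_steps=8, threshold=0.5)
print([step[0][0] for step in spikes])  # [0, 1, 0, 0, -1, 0, 1, 0]
```

Only the transitions across a level produce spikes: staying above +0.5 at t = 2 emits nothing, and the dip to -0.6 at t = 4 emits a single OFF spike.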