Datasets

Neuromorphic dataset loaders for event-driven vision and audio tasks.

Optional dependency

Dataset loaders require additional packages:

pip install "mlx-snn[datasets]"  # installs aedat, h5py (quotes keep shells like zsh from globbing the brackets)

DVS-Gesture

IBM DVS128 Gesture dataset — 11 hand gestures captured with a DVS camera.

DVSGestureDataset

DVS128 Gesture dataset loader for mlx-snn.

Loads event data from .npy files, bins into dense frames, and returns mx.array tensors suitable for spiking network input.

Args:
    root: Path to the DVSGesture directory that contains ibmGestureTrain/ and ibmGestureTest/.
    train: If True, load the training split; otherwise the test split.
    num_steps: Number of temporal bins (timesteps T).
    spatial_size: Target spatial resolution (H, W). If different from the native 128x128, frames are resized.
    transform: Optional callable applied to each frame tensor after conversion to mx.array. Receives and must return an mx.array of shape (T, H, W, 2).

Attributes:
    samples: List of (file_path, label) tuples.
    classes: List of class indices [0, 1, ..., 10].
    num_classes: 11.
    sensor_size: Native sensor resolution (128, 128).

Examples:
    >>> ds = DVSGestureDataset("/data/DVSGesture", train=True,
    ...                        num_steps=16)
    >>> len(ds)
    1077
    >>> frames, label = ds[0]
    >>> frames.shape
    [16, 128, 128, 2]
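As an illustration of the transform hook, a transform that collapses per-bin event counts into binary spikes might look like the sketch below. This is a hypothetical example, shown with NumPy operations for portability; the mx.array version uses the same elementwise comparison and cast.

```python
import numpy as np

def binarize(frames):
    """Collapse per-bin event counts to binary spikes (any event -> 1.0)."""
    return (frames > 0).astype(np.float32)
```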

__getitem__

__getitem__(idx: int) -> Tuple[mx.array, int]

Load and bin a single sample.

Args:
    idx: Sample index.

Returns:
    A tuple (frames, label), where frames is an mx.array of shape (T, H, W, 2) and label is an int in [0, 10].

CIFAR10-DVS

CIFAR-10 converted to neuromorphic events via a DVS camera.

CIFAR10DVSDataset

CIFAR10-DVS dataset loader for mlx-snn.

Loads event data from .aedat4 files using the lightweight aedat package, bins into dense frames, and returns mx.array tensors.

Because CIFAR10-DVS has no official train/test split, the loader deterministically assigns 90% of the samples to training and 10% to testing (controlled by split_seed).
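A deterministic split of this kind can be sketched in a few lines of pure Python. This is an illustration of the idea, not the library's actual internals; `assign_split` is a hypothetical name.

```python
import random

def assign_split(paths, split_ratio=0.9, split_seed=42):
    """Deterministically partition file paths into train/test lists."""
    paths = sorted(paths)                     # fix ordering before shuffling
    random.Random(split_seed).shuffle(paths)  # seeded, hence reproducible
    cut = int(len(paths) * split_ratio)
    return paths[:cut], paths[cut:]
```

Because the shuffle is seeded, every run with the same split_seed reproduces the same partition, so the train and test portions never overlap across processes.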

Args:
    root: Path to the CIFAR10DVS directory containing class sub-directories (airplane/, automobile/, etc.).
    train: If True, use the training portion of the split.
    num_steps: Number of temporal bins (timesteps T).
    spatial_size: Target spatial resolution (H, W).
    transform: Optional callable applied to each frame tensor after conversion to mx.array.
    split_ratio: Fraction of data used for training (default 0.9).
    split_seed: Random seed for the train/test split (default 42).

Attributes:
    samples: List of (file_path, label) tuples.
    classes: List of class names.
    num_classes: 10.
    sensor_size: Native sensor resolution (128, 128).

Examples:
    >>> ds = CIFAR10DVSDataset("/data/CIFAR10DVS", train=True,
    ...                        num_steps=10)
    >>> frames, label = ds[0]
    >>> frames.shape
    [10, 128, 128, 2]

__getitem__

__getitem__(idx: int) -> Tuple[mx.array, int]

Load and bin a single CIFAR10-DVS sample.

Args:
    idx: Sample index.

Returns:
    A tuple (frames, label), where frames is an mx.array of shape (T, H, W, 2) and label is an int.

N-MNIST

Neuromorphic MNIST — handwritten digits captured with a saccading DVS sensor.

NMNISTDataset

Neuromorphic MNIST dataset loader for mlx-snn.

Loads event data from .bin files, bins events into dense frames, and returns mx.array tensors suitable for spiking network input.

Args:
    root: Path to the NMNIST directory containing Train/ and Test/ sub-directories.
    train: If True, load the training split; otherwise the test split.
    num_steps: Number of temporal bins (timesteps T).
    spatial_size: Target spatial resolution (H, W). Defaults to the native (34, 34).
    transform: Optional callable applied to each frame tensor after conversion to mx.array.

Attributes:
    samples: List of (file_path, label) tuples.
    classes: List of class indices [0, 1, ..., 9].
    num_classes: 10.
    sensor_size: Native sensor resolution (34, 34).

Examples:
    >>> ds = NMNISTDataset("/data/NMNIST", train=True, num_steps=20)
    >>> len(ds)
    60000
    >>> frames, label = ds[0]
    >>> frames.shape
    [20, 34, 34, 2]

__getitem__

__getitem__(idx: int) -> Tuple[mx.array, int]

Load and bin a single N-MNIST sample.

Args:
    idx: Sample index.

Returns:
    A tuple (frames, label), where frames is an mx.array of shape (T, H, W, 2) and label is an int.

SHD (Spiking Heidelberg Digits)

Spoken digit recognition from cochlea-like spike encoding.

SHDDataset

Spiking Heidelberg Digits dataset loader.

Loads spike times from HDF5 files and bins them into dense binary tensors of shape (T, 700).

Args:
    root: Path to the directory containing shd_train.h5 and shd_test.h5.
    train: If True, load the training split; otherwise the test split.
    num_steps: Number of temporal bins (timesteps T).
    max_time: Maximum spike time in seconds; spikes after this time are discarded. If None, uses the maximum time found in the data.
    num_units: Number of input channels (default 700).
    transform: Optional callable applied to each sample after conversion to mx.array.

Examples:
    >>> from mlxsnn.datasets import SHDDataset
    >>> ds = SHDDataset("./data/shd", train=True, num_steps=100)
    >>> spikes, label = ds[0]
    >>> spikes.shape
    [100, 700]
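The time-binning described above can be sketched with NumPy. This is illustrative only; `bin_spikes` and its argument names are assumptions, not the loader's real internals.

```python
import numpy as np

def bin_spikes(times, units, num_steps, max_time, num_units=700):
    """Bin spike times (seconds) and channel indices into a dense (T, num_units) tensor."""
    keep = times <= max_time                      # discard spikes after max_time
    times, units = times[keep], units[keep]
    # Map each remaining spike time to a temporal bin in [0, num_steps).
    bins = np.clip((times / max_time * num_steps).astype(int), 0, num_steps - 1)
    spikes = np.zeros((num_steps, num_units), dtype=np.float32)
    spikes[bins, units] = 1.0                     # binary spike tensor
    return spikes
```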

__getitem__

__getitem__(idx: int) -> Tuple[mx.array, int]

Get a single sample.

Args:
    idx: Sample index.

Returns:
    A tuple (spikes, label), where spikes is an mx.array of shape (T, num_units) and label is an integer.

Data Utilities

Dataloader

EventDataloader

Simple batch iterator for MLX event-based datasets.

Produces batches of (frames, labels) as mx.array tensors without depending on PyTorch or any external data-loading library.

Args:
    dataset: A dataset object that supports __len__ and __getitem__, returning (mx.array, int) pairs.
    batch_size: Number of samples per batch.
    shuffle: Whether to shuffle the dataset at the start of each epoch.
    drop_last: If True, drop the final incomplete batch.
    seed: Random seed for shuffling reproducibility.

Yields:
    Tuples of (batch_frames, batch_labels), where batch_frames has shape (B, T, H, W, C) and batch_labels has shape (B,).

Examples:
    >>> loader = EventDataloader(dataset, batch_size=32, shuffle=True)
    >>> for frames, labels in loader:
    ...     print(frames.shape, labels.shape)
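The batching amounts to shuffling indices once per epoch, slicing them into chunks, and stacking the samples of each chunk. A minimal sketch of that logic, with a hypothetical `SimpleEventLoader` name and NumPy arrays standing in for mx.array:

```python
import numpy as np

class SimpleEventLoader:
    """Minimal epoch iterator: shuffle indices, slice into batches, stack samples."""

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, seed=0):
        self.dataset, self.batch_size = dataset, batch_size
        self.shuffle, self.drop_last = shuffle, drop_last
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        order = np.arange(len(self.dataset))
        if self.shuffle:
            self.rng.shuffle(order)           # fresh permutation each epoch
        for start in range(0, len(order), self.batch_size):
            chunk = order[start:start + self.batch_size]
            if self.drop_last and len(chunk) < self.batch_size:
                return                        # skip the final incomplete batch
            frames, labels = zip(*(self.dataset[int(i)] for i in chunk))
            yield np.stack(frames), np.array(labels)
```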

Frame Processing

events_to_frames

events_to_frames(events: ndarray, num_steps: int, sensor_size: Tuple[int, int], num_polarities: int = 2) -> np.ndarray

Convert raw events to temporally-binned frames.

Divides the full event time range into num_steps equal bins and accumulates event counts per pixel and polarity into dense frames.

Args:
    events: Structured array or (N, 4) array with columns [x, y, polarity, timestamp]. Timestamps must be sorted in non-decreasing order; polarity values are integers in [0, num_polarities).
    num_steps: Number of temporal bins (T).
    sensor_size: Spatial resolution as (H, W).
    num_polarities: Number of polarity channels (default 2).

Returns:
    frames: np.ndarray of shape (T, H, W, P) with dtype np.float32, where P = num_polarities.

Examples:
    >>> import numpy as np
    >>> events = np.array([[10, 20, 1, 100],
    ...                    [11, 21, 0, 200]], dtype=np.float64)
    >>> frames = events_to_frames(events, num_steps=2,
    ...                           sensor_size=(34, 34))
    >>> frames.shape
    (2, 34, 34, 2)
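A compact NumPy sketch of this binning, as an illustration of the algorithm described above rather than the shipped implementation (`to_frames` is a hypothetical name):

```python
import numpy as np

def to_frames(events, num_steps, sensor_size, num_polarities=2):
    """Accumulate [x, y, polarity, timestamp] events into (T, H, W, P) count frames."""
    x, y, p, t = events[:, 0], events[:, 1], events[:, 2], events[:, 3]
    span = t[-1] - t[0]
    span = span if span > 0 else 1.0              # guard against a zero-length range
    bins = np.clip(((t - t[0]) / span * num_steps).astype(int), 0, num_steps - 1)
    H, W = sensor_size
    frames = np.zeros((num_steps, H, W, num_polarities), dtype=np.float32)
    # np.add.at accumulates counts correctly even when indices repeat.
    np.add.at(frames, (bins, y.astype(int), x.astype(int), p.astype(int)), 1.0)
    return frames
```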

resize_frames

resize_frames(frames: ndarray, target_size: Tuple[int, int]) -> np.ndarray

Resize spatial dimensions of a frame tensor using area-based binning.

Uses simple block averaging (area interpolation) which is appropriate for event count data. No external dependency required.

Args:
    frames: (T, H, W, C) float array.
    target_size: (H_new, W_new).

Returns:
    Resized frames with shape (T, H_new, W_new, C).
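When the native resolution is an integer multiple of the target (e.g. 128 → 64), area-based binning reduces to a reshape followed by a mean over each block. A sketch under that assumption, with a hypothetical `block_average` name:

```python
import numpy as np

def block_average(frames, target_size):
    """Downsample (T, H, W, C) frames by averaging non-overlapping spatial blocks."""
    T, H, W, C = frames.shape
    h, w = target_size
    fh, fw = H // h, W // w                       # block size along each spatial axis
    return frames.reshape(T, h, fh, w, fw, C).mean(axis=(2, 4))
```

Averaging (rather than summing) each block keeps the output on the same scale as the input counts; summing instead would preserve total event counts.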