
Surrogate Gradients

Surrogate gradient functions enable backpropagation through the non-differentiable Heaviside spike function. The forward pass uses a hard threshold; the backward pass uses a smooth approximation.
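To make the forward/backward split concrete, here is a minimal framework-free sketch. The names `heaviside`, `spike_backward`, and `tent` are hypothetical and not part of this API. In a real implementation the pair would normally live inside an autograd primitive (for example a `torch.autograd.Function` in PyTorch). The sketch assumes the input is already `x = U - V_thr` and that a spike fires only when `x > 0`.

```python
from typing import Callable


def heaviside(x: float) -> float:
    """Forward pass on x = U - V_thr. Assumed convention: spike only when x > 0."""
    return 1.0 if x > 0.0 else 0.0


def spike_backward(x: float, grad_output: float,
                   surrogate_grad: Callable[[float], float]) -> float:
    """Backward pass: the true derivative of the step is zero almost everywhere,
    so scale the upstream gradient by a smooth surrogate evaluated at x instead."""
    return grad_output * surrogate_grad(x)


def tent(x: float) -> float:
    """One possible surrogate: a triangle peaking at the threshold (x = 0)."""
    return max(0.0, 1.0 - abs(x))


membrane = [-2.0, -0.5, 0.0, 0.3, 1.5]              # sampled values of U - V_thr
spikes = [heaviside(x) for x in membrane]            # hard 0/1 outputs
grads = [spike_backward(x, 1.0, tent) for x in membrane]
print(spikes)
print(grads)
```

Swapping `tent` for any of the surrogates below changes only the backward pass. The spikes produced in the forward pass stay identical.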

Factory Function

get_surrogate

get_surrogate(name: str, scale: float = 2.0)

Get a surrogate gradient function by name.

Args: name: Name of the surrogate function. One of 'fast_sigmoid', 'arctan', 'straight_through', 'sigmoid', 'triangular'. scale: Scaling parameter controlling gradient sharpness.

Returns: A callable that computes the Heaviside step in the forward pass and a smooth surrogate gradient in the backward pass.

Raises: ValueError: If name is not in the registry.

Available surrogates: 'fast_sigmoid', 'arctan', 'straight_through', 'sigmoid', 'triangular'.
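The lookup described above (name to factory, `ValueError` for unknown names) can be sketched as a dictionary registry. This is a hypothetical illustration, not the library's code. The `_REGISTRY` dict and `register` decorator are assumptions, each factory here returns only the backward (surrogate-gradient) function, and only the two parameter-free surrogates are filled in. Passing `scale` positionally to every factory is also an assumption, since the real factories name that parameter `scale`, `alpha`, or `slope`.

```python
from typing import Callable, Dict

# A surrogate gradient maps x = U - V_thr to the backward-pass value of dS/dU.
SurrogateGrad = Callable[[float], float]
Factory = Callable[[float], SurrogateGrad]

_REGISTRY: Dict[str, Factory] = {}


def register(name: str) -> Callable[[Factory], Factory]:
    """Decorator that stores a factory in the registry under `name`."""
    def wrap(factory: Factory) -> Factory:
        _REGISTRY[name] = factory
        return factory
    return wrap


@register("straight_through")
def _straight_through(scale: float) -> SurrogateGrad:
    return lambda x: 1.0  # scale unused, as documented


@register("triangular")
def _triangular(scale: float) -> SurrogateGrad:
    return lambda x: max(0.0, 1.0 - abs(x))  # scale unused, as documented

# 'fast_sigmoid', 'arctan' and 'sigmoid' would be registered the same way.


def get_surrogate(name: str, scale: float = 2.0) -> SurrogateGrad:
    """Look up a factory by name and build the surrogate with the given scale."""
    try:
        factory = _REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown surrogate {name!r}; available: {sorted(_REGISTRY)}"
        ) from None
    return factory(scale)
```

A decorator-based registry keeps each surrogate's definition next to its name, so adding a new entry does not require editing `get_surrogate` itself.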

Available Surrogates

Fast Sigmoid

Default surrogate. Its gradient uses only an absolute value and a division (no exponentials), giving a good balance of speed and gradient quality.

\[\frac{\partial S}{\partial U} \approx \frac{\alpha}{2\,(1 + \alpha|U - V_{thr}|)^2}, \qquad \alpha = \text{scale}\]

fast_sigmoid_surrogate

fast_sigmoid_surrogate(scale: float = 25.0)

Create a fast sigmoid surrogate gradient function.

The smooth approximation is the rational fast sigmoid:

approx(x) = 0.5 * scale * x / (1 + scale * |x|) + 0.5

Its derivative (which becomes the surrogate gradient) is:

d/dx approx = scale / (2 * (1 + scale * |x|)^2)
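Writing s for scale, this derivative follows from the quotient rule applied on each side of x = 0:

\[\frac{d}{dx}\,\frac{x}{1 + s|x|} = \begin{cases} \dfrac{(1 + sx) - sx}{(1 + sx)^2} = \dfrac{1}{(1 + sx)^2}, & x > 0,\\[4pt] \dfrac{(1 - sx) + sx}{(1 - sx)^2} = \dfrac{1}{(1 - sx)^2}, & x < 0.\end{cases}\]

Both branches equal 1 / (1 + s|x|)^2 and tend to 1 at x = 0, so multiplying by the prefactor 0.5 * s gives

\[\frac{d}{dx}\,\mathrm{approx}(x) = \frac{s}{2\,(1 + s|x|)^2}, \qquad \frac{d}{dx}\,\mathrm{approx}(0) = \frac{s}{2}.\]

With the default s = 25, the peak surrogate gradient at threshold is 12.5.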

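This has the same shape as snnTorch's surrogate.fast_sigmoid(slope=scale), whose backward pass is 1 / (1 + slope * |x|)^2; the gradient above equals that expression multiplied by the constant factor scale / 2.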

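Args: scale: Controls the sharpness of the surrogate gradient. Larger values make the approximation more step-like, so the gradient becomes taller (its peak at threshold is scale / 2) and more concentrated around the threshold. Default 25.0 matches snnTorch's default slope.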

Returns: A callable with Heaviside forward and fast-sigmoid backward.
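A quick way to confirm that the backward formula really is the derivative of `approx` is a finite-difference check. The sketch below is plain Python with the default `scale=25.0`. The helper names `fast_sigmoid_approx`, `fast_sigmoid_grad`, and `central_difference` are hypothetical and not part of this API. Sample points stay away from x = 0, where |x| has a kink.

```python
from typing import Callable


def fast_sigmoid_approx(x: float, scale: float = 25.0) -> float:
    """Smooth stand-in for the step: 0.5 * scale * x / (1 + scale * |x|) + 0.5."""
    return 0.5 * scale * x / (1.0 + scale * abs(x)) + 0.5


def fast_sigmoid_grad(x: float, scale: float = 25.0) -> float:
    """Surrogate gradient: scale / (2 * (1 + scale * |x|)^2)."""
    return scale / (2.0 * (1.0 + scale * abs(x)) ** 2)


def central_difference(f: Callable[[float], float], x: float, h: float = 1e-6) -> float:
    """Numerical derivative of f at x using a symmetric difference quotient."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


for x in (-0.3, -0.01, 0.02, 0.5):
    numeric = central_difference(fast_sigmoid_approx, x)
    analytic = fast_sigmoid_grad(x)
    print(f"x={x:+.2f}  analytic={analytic:.6f}  numeric={numeric:.6f}")
```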

Arctan

Smoother gradient landscape, useful when training is unstable.

arctan_surrogate

arctan_surrogate(alpha: float = 2.0)

Create an arctan surrogate gradient function.

Matches snnTorch's surrogate.atan(alpha) gradient: grad(x=0) = alpha / 2.

Args: alpha: Controls the sharpness of the surrogate gradient. With alpha=2.0, the peak gradient at threshold is 1.0.

Returns: A callable with Heaviside forward and arctan backward.
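The documented peak of alpha / 2 is what the smooth step (1 / pi) * arctan(pi * alpha * x / 2) + 1/2 yields when differentiated. The sketch below assumes that form. As far as we know it is the one snnTorch's atan surrogate is built on, but it is not taken from this library's source, and the function names are hypothetical.

```python
import math


def arctan_approx(x: float, alpha: float = 2.0) -> float:
    """Assumed smooth step: atan(pi * alpha * x / 2) / pi + 1/2, ranging over (0, 1)."""
    return math.atan(math.pi * alpha * x / 2.0) / math.pi + 0.5


def arctan_grad(x: float, alpha: float = 2.0) -> float:
    """Derivative of arctan_approx: (alpha / 2) / (1 + (pi * alpha * x / 2)^2)."""
    return (alpha / 2.0) / (1.0 + (math.pi * alpha * x / 2.0) ** 2)


print(arctan_grad(0.0))              # 1.0, the peak alpha / 2 with the default alpha
print(arctan_grad(0.0, alpha=4.0))   # 2.0
```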

Sigmoid

Standard logistic sigmoid derivative as surrogate.

sigmoid_surrogate

sigmoid_surrogate(slope: float = 25.0)

Create a sigmoid surrogate gradient function.

The backward pass uses the derivative of the sigmoid function, which has a bell-shaped curve centered at the threshold.

Args: slope: Controls the steepness of the sigmoid. Larger values produce a sharper transition.

Returns: A callable with Heaviside forward and sigmoid backward.
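A minimal sketch of the bell-shaped backward pass, assuming it is the exact derivative of sigmoid(slope * x) including the leading slope factor (the text above does not specify the normalization). Under that assumption the peak at threshold is slope / 4, i.e. 6.25 at the default. The helpers `stable_sigmoid` and `sigmoid_grad` are hypothetical; the branch in `stable_sigmoid` keeps `math.exp` from overflowing for large |slope * x|.

```python
import math


def stable_sigmoid(z: float) -> float:
    """Logistic function evaluated without overflowing math.exp for large |z|."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_grad(x: float, slope: float = 25.0) -> float:
    """d/dx sigmoid(slope * x) = slope * s * (1 - s), with s = sigmoid(slope * x)."""
    s = stable_sigmoid(slope * x)
    return slope * s * (1.0 - s)


for x in (-0.2, -0.05, 0.0, 0.05, 0.2):
    print(f"x={x:+.2f}  grad={sigmoid_grad(x):.4f}")
```

Raising `slope` increases the peak and narrows the bell, which is the "sharper transition" described above.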

Triangular (Tent)

Localized gradient with compact support near threshold.

triangular_surrogate

triangular_surrogate(scale: float = 1.0)

Create a triangular (tent) surrogate gradient function.

The backward pass uses a tent function centered at the threshold: grad = max(0, 1 - |x|)

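This provides a positive gradient that peaks at the threshold, decays linearly, and is exactly zero once |x| >= 1, so only inputs within distance 1 of the threshold receive any gradient. This compact support gives a good trade-off between signal quality and stability.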

Args: scale: Unused. Kept for API consistency with other surrogates.

Returns: A callable with Heaviside forward and triangular backward.
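The tent formula is simple enough to check by hand. This sketch (hypothetical name `triangular_grad`) evaluates it at a few points to show the compact support, then uses the midpoint rule to confirm that the tent encloses unit area over [-1, 1].

```python
def triangular_grad(x: float, scale: float = 1.0) -> float:
    """Tent surrogate: max(0, 1 - |x|). `scale` is accepted but ignored, as documented."""
    del scale  # unused; kept only for a uniform factory signature
    return max(0.0, 1.0 - abs(x))


xs = [-1.5, -1.0, -0.25, 0.0, 0.25, 1.0, 1.5]
print([triangular_grad(x) for x in xs])   # [0.0, 0.0, 0.75, 1.0, 0.75, 0.0, 0.0]

# Midpoint rule over [-1, 1]; exact for each linear piece, so the result is ~1.
n = 2000
width = 2.0 / n
area = sum(triangular_grad(-1.0 + (i + 0.5) * width) for i in range(n)) * width
print(area)
```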

Straight-Through Estimator

Simplest surrogate: unit gradient everywhere.

straight_through_surrogate

straight_through_surrogate(scale: float = 1.0)

Create a straight-through estimator surrogate gradient function.

The gradient is passed through unchanged regardless of the input value. This matches snnTorch's StraightThroughEstimator.

Args: scale: Unused. Kept for API consistency with other surrogates.

Returns: A callable with Heaviside forward and identity backward.
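Because the backward pass is the identity, even inputs far from threshold receive gradient. The toy loop below, a hypothetical illustration rather than library code, runs gradient descent on a single membrane value `u` (interpreted as U - V_thr) with a squared-error loss toward a target spike of 1. It assumes the convention that a spike fires only when `u > 0`.

```python
def heaviside(x: float) -> float:
    """Forward pass; assumed convention: spike only when x > 0."""
    return 1.0 if x > 0.0 else 0.0


def straight_through_backward(x: float, grad_output: float) -> float:
    """Identity backward: the upstream gradient passes through unchanged."""
    return grad_output


u = -3.0        # far below threshold
target = 1.0    # we want this neuron to spike
lr = 0.5
for _ in range(10):
    s = heaviside(u)
    grad_s = 2.0 * (s - target)            # d/ds of (s - target)^2
    grad_u = straight_through_backward(u, grad_s)
    u -= lr * grad_u

print(u, heaviside(u))
```

With the triangular surrogate, the same starting point u = -3.0 would receive a gradient of exactly zero and never move. The straight-through estimator trades gradient accuracy for this guaranteed signal.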