Surrogate Gradients
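Surrogate gradient functions enable backpropagation through the non-differentiable Heaviside spike function. The forward pass applies a hard threshold. The step's true derivative is zero everywhere except at the threshold, where it is undefined, so the backward pass substitutes a surrogate gradient, usually the derivative of a smooth approximation of the step.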
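To make the forward/backward split concrete, here is a minimal plain-Python sketch for a single scalar input. It is not the library's autodiff code: the names heaviside, spike_backward, and tent are hypothetical, and the convention that a spike occurs only when x > 0 (x being membrane potential minus threshold) is an assumption borrowed from snnTorch.

```python
from typing import Callable


def heaviside(x: float) -> float:
    # Forward pass: hard threshold on x = membrane potential minus threshold.
    # Assumed convention: spike only when x > 0 (snnTorch's choice).
    return 1.0 if x > 0.0 else 0.0


def spike_backward(x: float, upstream_grad: float,
                   surrogate_grad: Callable[[float], float]) -> float:
    # Backward pass: the true derivative of heaviside is zero almost everywhere,
    # so the chain rule multiplies by the surrogate derivative instead.
    return upstream_grad * surrogate_grad(x)


def tent(x: float) -> float:
    # Example surrogate derivative (a triangular shape), used only for illustration.
    return max(0.0, 1.0 - abs(x))


x = 0.25
print(heaviside(x))                  # 1.0
print(spike_backward(x, 2.0, tent))  # 1.5
```

Any of the surrogates below can be plugged in as surrogate_grad; only the backward pass changes, while the forward pass stays a hard threshold.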
Factory Function
get_surrogate
Get a surrogate gradient function by name.
Args: name: Name of the surrogate function. One of 'fast_sigmoid', 'arctan', 'straight_through', 'sigmoid', 'triangular'. scale: Scaling parameter controlling gradient sharpness (unused by 'triangular' and 'straight_through').
Returns: A callable that computes the Heaviside step in the forward pass and a smooth surrogate gradient in the backward pass.
Raises: ValueError: If name is not in the registry.
Available surrogates: 'fast_sigmoid', 'arctan', 'straight_through', 'sigmoid', 'triangular'.
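A name-based factory like this is commonly built as a dictionary from names to constructor functions. The sketch below assumes that design; get_surrogate_sketch, _REGISTRY, and the two factories are hypothetical, only two of the five names are registered for brevity, and the default scale of 25.0 is an assumption (it is the documented fast-sigmoid default, not necessarily get_surrogate's).

```python
from typing import Callable, Dict

SurrogateGrad = Callable[[float], float]


# Hypothetical factories: each takes `scale` and returns the backward-pass
# derivative as a function of x (distance from threshold).
def _fast_sigmoid(scale: float) -> SurrogateGrad:
    return lambda x: scale / (2.0 * (1.0 + scale * abs(x)) ** 2)


def _straight_through(scale: float) -> SurrogateGrad:
    return lambda x: 1.0  # scale is accepted but unused


# Only two entries for brevity; the documented registry has five names.
_REGISTRY: Dict[str, Callable[[float], SurrogateGrad]] = {
    "fast_sigmoid": _fast_sigmoid,
    "straight_through": _straight_through,
}


def get_surrogate_sketch(name: str, scale: float = 25.0) -> SurrogateGrad:
    # Look the name up and build the surrogate; unknown names raise ValueError.
    try:
        factory = _REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown surrogate {name!r}; available: {sorted(_REGISTRY)}"
        ) from None
    return factory(scale)


grad = get_surrogate_sketch("fast_sigmoid", scale=25.0)
print(grad(0.0))  # 12.5
```

Raising ValueError with the list of valid names, rather than letting the KeyError escape, matches the documented Raises contract and tells the caller what to type instead.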
Available Surrogates
Fast Sigmoid
Default surrogate. Good balance of speed and gradient quality.
fast_sigmoid_surrogate
Create a fast sigmoid surrogate gradient function.
The smooth approximation is the rational fast sigmoid:
approx(x) = 0.5 * scale * x / (1 + scale * |x|) + 0.5
Its derivative (which becomes the surrogate gradient) is:
d/dx approx = scale / (2 * (1 + scale * |x|)^2)
This matches snnTorch's surrogate.fast_sigmoid(slope=scale).
Args: scale: Controls the sharpness of the surrogate gradient. Larger values make approx(x) more step-like and the gradient taller (its peak is scale / 2) and narrower around the threshold. Default 25.0 matches snnTorch's default slope.
Returns: A callable with Heaviside forward and fast-sigmoid backward.
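The two fast-sigmoid formulas can be checked against each other numerically. This plain-Python sketch implements approx(x) and its documented derivative and compares the derivative with a central finite difference; the function names are hypothetical and the code models only the math, not the library's autodiff wrapper.

```python
def fast_sigmoid_approx(x: float, scale: float = 25.0) -> float:
    # Smooth stand-in for the step: 0.5 at x == 0, tends to 0 / 1 as x -> -inf / +inf.
    return 0.5 * scale * x / (1.0 + scale * abs(x)) + 0.5


def fast_sigmoid_grad(x: float, scale: float = 25.0) -> float:
    # Documented surrogate gradient; peaks at scale / 2 when x == 0.
    return scale / (2.0 * (1.0 + scale * abs(x)) ** 2)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    # Numerical estimate of f'(x), used to check the closed-form derivative.
    return (f(x + h) - f(x - h)) / (2.0 * h)


for x in (-0.2, 0.05, 0.3):
    analytic = fast_sigmoid_grad(x)
    numeric = central_difference(fast_sigmoid_approx, x)
    print(f"x={x:+.2f}  analytic={analytic:.6f}  numeric={numeric:.6f}")
```

At the default scale of 25.0 the gradient at threshold is 25.0 / 2 = 12.5, and it falls off as 1 / (1 + scale * |x|)^2 on either side.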
Arctan
Smoother gradient landscape, useful when training is unstable.
arctan_surrogate
Create an arctan surrogate gradient function.
Its gradient matches that of snnTorch's surrogate.atan(alpha):
grad(x) = (alpha / 2) / (1 + (pi * alpha * x / 2)^2)
which peaks at the threshold with grad(x=0) = alpha / 2.
Args: alpha: Controls the sharpness of the surrogate gradient. With alpha=2.0, the peak gradient at threshold is 1.0.
Returns: A callable with Heaviside forward and arctan backward.
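A plain-Python sketch of the arctan surrogate, assuming it follows snnTorch's gradient (alpha / 2) / (1 + (pi * alpha * x / 2)^2), which is the derivative of the shifted arctan (1 / pi) * arctan(pi * alpha * x / 2) + 0.5. The function names are hypothetical, and the default alpha=2.0 is an assumption taken from snnTorch.

```python
import math


def arctan_approx(x: float, alpha: float = 2.0) -> float:
    # Shifted arctan: rises from 0 to 1 around the threshold, equal to 0.5 at x == 0.
    return math.atan(math.pi * alpha * x / 2.0) / math.pi + 0.5


def arctan_grad(x: float, alpha: float = 2.0) -> float:
    # Derivative of arctan_approx; peaks at alpha / 2 when x == 0.
    return (alpha / 2.0) / (1.0 + (math.pi * alpha * x / 2.0) ** 2)


print(arctan_grad(0.0))             # 1.0
print(arctan_grad(0.0, alpha=5.0))  # 2.5
```

With alpha=2.0 the peak gradient is exactly 1.0, so at threshold the upstream gradient passes through unscaled.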
Sigmoid
Standard logistic sigmoid derivative as surrogate.
sigmoid_surrogate
Create a sigmoid surrogate gradient function.
The backward pass uses the derivative of the sigmoid function, which has a bell-shaped curve centered at the threshold.
Args: slope: Controls the steepness of the sigmoid. Larger values produce a sharper transition.
Returns: A callable with Heaviside forward and sigmoid backward.
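A sketch of the sigmoid surrogate, assuming the smooth approximation is the logistic function applied to slope * x, whose derivative is slope * s * (1 - s) with s = logistic(slope * x). The default slope=25.0 mirrors snnTorch's sigmoid surrogate and is an assumption here, as are the names logistic and sigmoid_grad.

```python
import math


def logistic(z: float) -> float:
    # Numerically stable logistic sigmoid.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_grad(x: float, slope: float = 25.0) -> float:
    # d/dx logistic(slope * x) = slope * s * (1 - s): bell-shaped, centred at x == 0,
    # with peak value slope / 4.
    s = logistic(slope * x)
    return slope * s * (1.0 - s)


print(sigmoid_grad(0.0))             # 6.25
print(sigmoid_grad(0.0, slope=4.0))  # 1.0
```

Because the peak is slope / 4, increasing slope both narrows the bell and raises its height; slope=4.0 gives a unit gradient at threshold.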
Triangular (Tent)
Localized gradient with compact support near threshold.
triangular_surrogate
Create a triangular (tent) surrogate gradient function.
The backward pass uses a tent function centered at the threshold: grad = max(0, 1 - |x|)
The gradient peaks at 1 at the threshold, decays linearly to zero at |x| = 1, and is exactly zero beyond that, so only inputs near threshold receive gradient. This bounded, localized shape balances gradient signal against training stability.
Args: scale: Unused. Kept for API consistency with other surrogates.
Returns: A callable with Heaviside forward and triangular backward.
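The tent formula is simple enough to evaluate by hand; this sketch applies it to a few offsets from threshold to show its compact support. The name triangular_grad is hypothetical, and the unused scale argument only mirrors the documented signature.

```python
def triangular_grad(x: float, scale: float = 1.0) -> float:
    # Tent centred at the threshold: 1.0 at x == 0, falling linearly to 0.0 at |x| == 1.
    # `scale` is accepted but ignored, mirroring the documented signature.
    return max(0.0, 1.0 - abs(x))


# x = membrane potential minus threshold
offsets = [-1.5, -0.5, 0.0, 0.25, 1.0, 2.0]
grads = [triangular_grad(v) for v in offsets]
print(grads)  # [0.0, 0.5, 1.0, 0.75, 0.0, 0.0]
```

Neurons whose potential sits one unit or more from the threshold receive no gradient at all, which is the "compact support" property noted above.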
Straight-Through Estimator
Simplest surrogate: unit gradient everywhere.
straight_through_surrogate
Create a straight-through estimator surrogate gradient function.
The gradient is passed through unchanged regardless of the input value.
This matches snnTorch's StraightThroughEstimator.
Args: scale: Unused. Kept for API consistency with other surrogates.
Returns: A callable with Heaviside forward and identity backward.
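A minimal sketch of the straight-through pair: a Heaviside forward and an identity backward. The names ste_forward and ste_backward are hypothetical, and the spike-when-x-greater-than-0 convention is an assumption borrowed from snnTorch.

```python
def ste_forward(x: float) -> float:
    # Heaviside forward pass (assumed convention: spike when x > 0).
    return 1.0 if x > 0.0 else 0.0


def ste_backward(x: float, upstream_grad: float, scale: float = 1.0) -> float:
    # Identity backward pass: the upstream gradient passes through for every x.
    # `x` and `scale` are unused; they are kept only to mirror other surrogates.
    return upstream_grad


for x in (-10.0, 0.0, 0.5):
    print(x, ste_forward(x), ste_backward(x, 0.3))
```

Unlike the tent surrogate, every neuron receives gradient, even one far below threshold. The trade-off is that the gradient carries no information about how close an input is to spiking.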