All the layers implemented in this package can be used in the same way as torch.nn layers.

Main spiking layer


class SpikingLayer(threshold: float = 1.0, threshold_low: Optional[float] = -1.0, membrane_subtract: Optional[float] = None, batch_size: Optional[int] = None, membrane_reset=False)

PyTorch implementation of a spiking neuron with learning enabled. This class is the base class for any layer that needs to implement integrate-and-fire operations.

  • threshold – Spiking threshold of the neuron.

  • threshold_low – Lower bound for membrane potential.

  • membrane_subtract – The amount to subtract from the membrane potential upon spiking. Default is equal to threshold. Ignored if membrane_reset is set.

  • negative_spikes – Implement a linear transfer function through negative spiking. Ignored if membrane_reset is set.

  • batch_size – The batch size. Needed to distinguish between timesteps and batch dimension.

  • membrane_reset – bool; if True, reset the membrane potential to 0 on spiking.

  • layer_name – The name of the layer.
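As a rough illustration of how these parameters interact, the following plain-Python sketch simulates a single integrate-and-fire neuron over a list of input currents. It is not this package's implementation: the function name `integrate_and_fire` is hypothetical, and it assumes that in the subtracting regime several spikes may be emitted in one time step (the integer division used by threshold_subtract below) and that threshold_low clips the membrane potential from below after spiking.

```python
import math
from typing import List, Optional, Tuple


def integrate_and_fire(
    currents: List[float],
    threshold: float = 1.0,
    threshold_low: Optional[float] = -1.0,
    membrane_subtract: Optional[float] = None,
    membrane_reset: bool = False,
) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: simulate one neuron, return spike counts and membrane trace."""
    if membrane_subtract is None:
        membrane_subtract = threshold  # documented default: subtract the threshold
    v = 0.0
    spikes: List[int] = []
    trace: List[float] = []
    for current in currents:
        v += current  # integrate the synaptic input current
        if membrane_reset:
            n = 1 if v >= threshold else 0
            if n:
                v = 0.0  # reset the membrane to 0 on spiking
        else:
            # Assumption: multiple spikes per step, as in threshold_subtract.
            n = math.floor(v / threshold) if v > 0 else 0
            v -= n * membrane_subtract
        if threshold_low is not None:
            v = max(v, threshold_low)  # assumption: clip at the lower bound
        spikes.append(n)
        trace.append(v)
    return spikes, trace
```

With the default subtracting regime, the input [0.6, 0.6, 2.5, -3.0] yields spike counts [0, 1, 2, 0]; with membrane_reset=True the third step emits only one spike, because the membrane is set to 0 instead of being reduced by the threshold.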


Returns the output shape for the pass-through implementation.





reset_states(shape=None, randomize=False)

Reset the state of all neurons in this layer.

synaptic_output(input_spikes: torch.Tensor) → torch.Tensor

This method must be overridden (defined) by the child class. The default implementation is a pass-through.


input_spikes – torch.Tensor input to the layer.


torch.Tensor - synaptic output current
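The override pattern described for synaptic_output can be sketched as follows. `IFBase` and `LinearIF` are hypothetical stand-ins, not classes from this package: `IFBase` assumes an input of shape (time, features), keeps the pass-through synaptic_output, and emits at most one spike per step with subtraction, while the child class swaps in a weighted sum computed by torch.nn.Linear.

```python
import torch
import torch.nn as nn


class IFBase(nn.Module):
    """Hypothetical stand-in for SpikingLayer; input shape assumed (time, features)."""

    def __init__(self, threshold: float = 1.0):
        super().__init__()
        self.threshold = threshold

    def synaptic_output(self, input_spikes: torch.Tensor) -> torch.Tensor:
        return input_spikes  # default: pass-through

    def forward(self, input_spikes: torch.Tensor) -> torch.Tensor:
        current = self.synaptic_output(input_spikes)
        v = torch.zeros_like(current[0])
        out = []
        for current_t in current:  # iterate over the time dimension
            v = v + current_t
            spikes = (v >= self.threshold).float()  # simplification: at most one spike per step
            v = v - spikes * self.threshold  # subtract the threshold on spiking
            out.append(spikes)
        return torch.stack(out)


class LinearIF(IFBase):
    """Hypothetical child class: overrides synaptic_output with a linear layer."""

    def __init__(self, in_features: int, out_features: int, threshold: float = 1.0):
        super().__init__(threshold)
        self.fc = nn.Linear(in_features, out_features, bias=False)

    def synaptic_output(self, input_spikes: torch.Tensor) -> torch.Tensor:
        return self.fc(input_spikes)  # weighted sum of the input spikes
```

Only synaptic_output changes between the two classes; the integrate-and-fire loop in the base class is reused unchanged.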

Auxiliary spiking layers


class Cropping2dLayer(cropping: Union[numpy.ndarray, List, Tuple] = ((0, 0), (0, 0)))

Crops the input image to the rectangle defined by cropping.


cropping – ((top, bottom), (left, right))

get_output_shape(input_shape: Tuple) → Tuple

Returns the output dimensions.


input_shape – (channels, height, width)


(channels, height, width)
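The cropping convention above can be sketched in plain Python. The helper names `cropped_shape` and `crop` are hypothetical, and the image is assumed to be nested lists indexed [channel][row][column].

```python
from typing import List, Sequence, Tuple

Cropping = Sequence[Sequence[int]]  # ((top, bottom), (left, right))


def cropped_shape(input_shape: Tuple[int, int, int],
                  cropping: Cropping = ((0, 0), (0, 0))) -> Tuple[int, int, int]:
    """Output (channels, height, width) after removing the cropped borders."""
    channels, height, width = input_shape
    (top, bottom), (left, right) = cropping
    return (channels, height - top - bottom, width - left - right)


def crop(image: List[List[List[float]]], cropping: Cropping) -> List[List[List[float]]]:
    """Crop every channel; len(...) - n is used so that a border of 0 keeps all rows/columns."""
    (top, bottom), (left, right) = cropping
    return [
        [row[left:len(row) - right] for row in channel[top:len(channel) - bottom]]
        for channel in image
    ]
```

Slicing up to `len(...) - n` instead of `-n` matters because a slice ending at `-0` would be empty.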


class SpikingMaxPooling2dLayer(pool_size: Union[numpy.ndarray, List, Tuple], strides: Optional[Union[numpy.ndarray, List, Tuple]] = None, padding: Union[numpy.ndarray, List, Tuple] = (0, 0, 0, 0))

PyTorch implementation of spiking max pooling.

get_output_shape(input_shape: Tuple) → Tuple

Returns the shape of output, given an input to this layer


input_shape – (channels, height, width)


(channels_out, height_out, width_out)
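A sketch of how such an output shape is typically computed for 2D pooling. Two details are assumptions not stated above: that strides default to pool_size when None, and that the four padding values are ordered (left, right, top, bottom) as in torch.nn.ZeroPad2d. The function name is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def maxpool_output_shape(input_shape: Tuple[int, int, int],
                         pool_size: Sequence[int],
                         strides: Optional[Sequence[int]] = None,
                         padding: Sequence[int] = (0, 0, 0, 0)) -> Tuple[int, int, int]:
    channels, height, width = input_shape
    pool_h, pool_w = pool_size
    # Assumption: strides default to the pool size.
    stride_h, stride_w = strides if strides is not None else pool_size
    # Assumption: padding order (left, right, top, bottom), as in torch.nn.ZeroPad2d.
    left, right, top, bottom = padding
    height_out = (height + top + bottom - pool_h) // stride_h + 1
    width_out = (width + left + right - pool_w) // stride_w + 1
    return (channels, height_out, width_out)  # pooling keeps the channel count
```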


class InputLayer(input_shape: Union[numpy.ndarray, List, Tuple], layer_name='input')

Placeholder layer, typically used to acquire statistics on the input.


input_shape – Input image dimensions

get_output_shape(input_shape: Tuple) → Tuple

Returns the output dimensions.


input_shape – (channels, height, width)


(channels, height, width)

Hybrid layers

The hybrid layers have inputs and outputs of different formats (e.g., they take analog values as inputs and produce spikes as outputs).


class Img2SpikeLayer(image_shape, tw: int = 100, max_rate: float = 1000, norm: float = 255.0, squeeze: bool = False, negative_spikes: bool = False)

Layer to convert images to spikes.

  • image_shape – tuple image shape

  • tw – int Time window length

  • max_rate – maximum firing rate of neurons

  • layer_name – string layer name

  • norm – the supposed maximum value of the input (default 255.0)

  • squeeze – whether to remove singleton dimensions from the input

  • negative_spikes – whether to allow negative spikes in response to negative input
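One common way to turn pixel intensities into spikes is Poisson-like rate coding, sketched below. This is an assumption about the conversion, not a description of Img2SpikeLayer's internals: each time step is assumed to last dt = 1 ms, a pixel spikes with probability (|pixel| / norm) * max_rate * dt clipped to 1, the image is flattened into a list, and the function name is hypothetical.

```python
import random
from typing import List, Optional


def image_to_spikes(pixels: List[float], tw: int = 100, max_rate: float = 1000.0,
                    norm: float = 255.0, negative_spikes: bool = False,
                    dt: float = 1e-3, rng: Optional[random.Random] = None) -> List[List[int]]:
    """Return a [tw][num_pixels] raster of spikes (assumed rate coding)."""
    rng = rng or random.Random()
    raster = []
    for _ in range(tw):
        step = []
        for p in pixels:
            # Assumption: spike probability per step = rate * dt, clipped to 1.
            prob = min(abs(p) / norm * max_rate * dt, 1.0)
            spike = 1 if rng.random() < prob else 0
            if p < 0:
                # Negative input yields negative spikes only if enabled.
                spike = -spike if negative_spikes else 0
            step.append(spike)
        raster.append(step)
    return raster
```

With the defaults (max_rate = 1000 Hz, dt = 1 ms), a pixel equal to norm spikes at every time step and a pixel at half of norm spikes in roughly half of the steps.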


class Sig2SpikeLayer(channels_in, tw: int = 1, norm_level: float = 1, spk_out: bool = True)

Layer to convert analog signals to spikes.

  • channels_in – number of channels in the analog signal

  • tw – int number of time steps for each sample of the signal (upsampling)

  • layer_name – string layer name
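The upsampling idea behind tw can be sketched as below. Everything beyond "tw time steps per sample" is an assumption made for illustration: each sample is held for tw steps (nearest-neighbour upsampling), values are divided by norm_level, and with spk_out=True a spike is drawn with probability equal to the normalized value clipped to [0, 1]. The function name is hypothetical.

```python
import random
from typing import List, Optional


def signal_to_spikes(signal: List[List[float]], tw: int = 1, norm_level: float = 1.0,
                     spk_out: bool = True,
                     rng: Optional[random.Random] = None) -> List[List[float]]:
    """signal is indexed [time][channel]; the result has len(signal) * tw rows."""
    rng = rng or random.Random()
    out = []
    for sample in signal:
        scaled = [x / norm_level for x in sample]  # assumption: normalize by norm_level
        for _ in range(tw):  # assumption: hold each sample for tw steps
            if spk_out:
                # Assumption: spike probability = normalized value clipped to [0, 1].
                out.append([1.0 if rng.random() < min(max(x, 0.0), 1.0) else 0.0
                            for x in scaled])
            else:
                out.append(list(scaled))
    return out
```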

ANN layers

These are utility layers used in the training of ANNs, in order to provide specific features suitable for SNN conversion.


class NeuromorphicReLU(quantize=True, fanout=1, stochastic_rounding=False)

NeuromorphicReLU layer. This layer is NOT used for Sinabs networks; it’s useful while training analogue PyTorch networks for future use with Sinabs.

  • quantize – Whether or not to quantize the output (i.e. floor it to the integer below), in order to mimic spiking behavior.

  • fanout – Useful when computing the number of SynOps of a quantized NeuromorphicReLU. The activity can be accessed through NeuromorphicReLU.activity, and is multiplied by the value of fanout.

  • stochastic_rounding – Upon quantization, whether the value should be rounded stochastically instead of floored. Stochastic rounding is only done during training; in evaluation mode, the value is simply floored.

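The behaviour described for NeuromorphicReLU can be approximated by the sketch below. `NeuromorphicReLUSketch` is a hypothetical class, not this package's code; it uses the straight-through idiom `out + (q - out).detach()` so that the forward value is quantized while the gradient is that of the unquantized ReLU, and it assumes the activity is the summed output multiplied by fanout.

```python
import torch
import torch.nn as nn


class NeuromorphicReLUSketch(nn.Module):
    """Hypothetical approximation of NeuromorphicReLU, for illustration only."""

    def __init__(self, quantize: bool = True, fanout: int = 1,
                 stochastic_rounding: bool = False):
        super().__init__()
        self.quantize = quantize
        self.fanout = fanout
        self.stochastic_rounding = stochastic_rounding
        self.activity = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.relu(x)
        if self.quantize:
            floor = torch.floor(out)
            if self.stochastic_rounding and self.training:
                # Round up with probability equal to the fractional part.
                q = floor + (torch.rand_like(out) < (out - floor)).to(out.dtype)
            else:
                q = floor  # evaluation mode (or no stochastic rounding): floor
            # Straight-through: forward value q, gradient of the plain ReLU.
            out = out + (q - out).detach()
        # Assumption: activity = summed output, multiplied by fanout.
        self.activity = out.sum() * self.fanout
        return out
```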


class QuantizeLayer(quantize=True)

Layer that quantizes the input, i.e. returns floor(input).


quantize – If False, this layer will do nothing.



class SumPool2d(kernel_size, stride=None, ceil_mode=False)

Non-spiking sumpooling layer to be used in analogue Torch models. It is identical to torch.nn.LPPool2d with p=1.

  • kernel_size – the size of the window

  • stride – the stride of the window. Default value is kernel_size

  • ceil_mode – when True, will use ceil instead of floor to compute the output shape

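Sum pooling adds up the values in each window instead of taking their maximum or mean. The plain-Python sketch below works on a single 2D channel given as nested lists; the function name is hypothetical, and the ceil_mode handling (dropping a last window that would start outside the input) mirrors the usual PyTorch pooling rule for zero padding.

```python
import math
from typing import List, Optional, Sequence, Union

Size = Union[int, Sequence[int]]


def sum_pool2d(x: List[List[float]], kernel_size: Size, stride: Optional[Size] = None,
               ceil_mode: bool = False) -> List[List[float]]:
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    if stride is None:
        stride = (kh, kw)  # stride defaults to kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride
    h, w = len(x), len(x[0])

    def out_size(n: int, k: int, s: int) -> int:
        if ceil_mode:
            size = math.ceil((n - k) / s) + 1
            if (size - 1) * s >= n:  # the last window must start inside the input
                size -= 1
            return size
        return (n - k) // s + 1

    out = []
    for i in range(out_size(h, kh, sh)):
        row = []
        for j in range(out_size(w, kw, sw)):
            r0, c0 = i * sh, j * sw
            # Sum the window, clipped to the input borders (relevant for ceil_mode).
            row.append(sum(x[r][c] for r in range(r0, min(r0 + kh, h))
                           for c in range(c0, min(c0 + kw, w))))
        out.append(row)
    return out
```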


Quantization tools



PyTorch-compatible function that applies a floor() operation on the input, while providing a surrogate gradient (equivalent to that of a linear function) in the backward pass.



PyTorch-compatible function that applies stochastic rounding. The input x is quantized to ceil(x) with probability (x - floor(x)), and to floor(x) otherwise. The backward pass is provided as a surrogate gradient (equivalent to that of a linear function).
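Both quantization functions can be sketched with torch.autograd.Function: the forward pass quantizes, and the backward pass returns the incoming gradient unchanged, which is the gradient of the linear function y = x. The class names `FloorSTE` and `StochasticRoundSTE` are hypothetical and are not the names used by this package.

```python
import torch


class FloorSTE(torch.autograd.Function):
    """Hypothetical name: floor() forward, linear (identity) surrogate gradient backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output  # gradient of y = x


class StochasticRoundSTE(torch.autograd.Function):
    """Hypothetical name: ceil(x) with probability x - floor(x), else floor(x)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        floor = torch.floor(x)
        round_up = torch.rand_like(x) < (x - floor)  # True with probability x - floor(x)
        return floor + round_up.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output  # gradient of y = x
```

For example, `FloorSTE.apply(torch.tensor([0.3, 1.7, -0.4]))` gives `[0., 1., -1.]`, while the gradient with respect to every input element is 1.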

Thresholding tools


threshold_subtract(data, threshold=1, window=0.5)

PyTorch-compatible function that returns the number of spikes emitted, given a membrane potential value, in a “threshold subtracting” regime. In other words, the integer division of the input by the threshold is returned. In the backward pass, the gradient is passed through if the membrane potential is at least threshold - window, and is zero otherwise.


threshold_reset(data, threshold=1, window=0.5)

Same as threshold_subtract, except that the potential is reset, rather than subtracted. In other words, only one output spike is possible.
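The forward computations of the two regimes can be compared with a plain-Python sketch. The helper names are hypothetical, the surrogate gradient is left out, and the sketch assumes that a non-positive membrane potential emits no spikes.

```python
import math


def spikes_threshold_subtract(v: float, threshold: float = 1.0) -> int:
    """Subtracting regime: integer division of the potential by the threshold."""
    return math.floor(v / threshold) if v > 0 else 0  # assumption: no spikes for v <= 0


def spikes_threshold_reset(v: float, threshold: float = 1.0) -> int:
    """Reset regime: at most one spike, regardless of how far v exceeds the threshold."""
    return 1 if v >= threshold else 0
```

For a membrane potential of 2.7 and a threshold of 1.0, the subtracting regime reports 2 spikes while the reset regime reports 1.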