Training by backpropagation through time (BPTT)

BPTT is a procedure normally used to train recurrent neural networks. A spiking network, even when it has no recurrent connections, retains a memory of its previous processing steps through the persistence of its membrane potentials. Unlike conventional artificial neural networks, spiking networks therefore carry an internal state that persists over time.

This is why BPTT can be used for more precise (but also much more computationally expensive) training of SNNs on sequential tasks. In sinabs, backpropagation through the spiking network relies on a surrogate gradient method, since the spiking nonlinearity is not differentiable.
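To make the surrogate gradient idea concrete, here is a minimal sketch of a spike nonlinearity whose backward pass uses a boxcar surrogate derivative. The function name, the zero threshold, and the gradient window are illustrative choices for this sketch, not sinabs' actual implementation:

```python
import torch


class SpikeFn(torch.autograd.Function):
    """Heaviside spike in the forward pass, boxcar surrogate in the backward pass."""

    @staticmethod
    def forward(ctx, v_mem):
        ctx.save_for_backward(v_mem)
        # non-differentiable step: spike if the membrane potential is above threshold
        return (v_mem > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v_mem,) = ctx.saved_tensors
        # surrogate derivative: let gradients pass only near the threshold
        surrogate = (v_mem.abs() < 0.5).float()
        return grad_output * surrogate


v = torch.linspace(-1, 1, 5, requires_grad=True)  # [-1, -0.5, 0, 0.5, 1]
spikes = SpikeFn.apply(v)
spikes.sum().backward()
```

Here `spikes` is a hard step of the membrane potential, while `v.grad` is nonzero only for the entry near the threshold, which is what lets gradient descent adjust weights despite the discontinuous forward pass.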

In this notebook, we will train a spiking network directly (without first training an analog network) on the Sequential MNIST task. In Sequential MNIST, the network is shown the 28x28-pixel MNIST digits one row at a time: the input at each step is a single row of 28 pixels, from the first row to the 28th. Only after all 28 rows have been shown does the network predict the digit label.

First, we define the MNIST dataset. Note that pixel values lie between 0 and 1; we interpret each value as the probability that the corresponding input neuron spikes at a given time step.

[1]:
from torchvision import datasets
import torch

torch.manual_seed(0)

class MNIST_Dataset(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.

        # default: row by row, output is [time, channels] = [28, 28]
        # OR, if reading one pixel at a time, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)

        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()

        return spikes, target
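As a quick sanity check on this encoding (a standalone sketch, independent of the dataset class): averaging many Bernoulli draws recovers the original pixel intensity, so the spike rate of each input neuron encodes the pixel value.

```python
import torch

torch.manual_seed(0)

# toy "image" with three pixel intensities
img = torch.tensor([1.0, 0.0, 0.5])

# draw many Bernoulli spike trains, exactly as in __getitem__
n_draws = 10_000
spikes = (torch.rand(n_draws, 3) < img).float()

# the mean spike rate approximates the pixel value
rates = spikes.mean(0)
print(rates)  # roughly [1.0, 0.0, 0.5]
```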
[2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64

dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True)

dataset = MNIST_Dataset(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, drop_last=True)

Training a baseline

We first want to show that this task cannot be solved with comparable accuracy by a memory-less analog network, even though it is sequential. Let us therefore train such a baseline.

[3]:
from torch import nn

ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    # a ReLU after the readout too, so that every layer can later be
    # converted to a spiking layer
    nn.ReLU()
)

Training

[4]:
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(5):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()

        # the ANN sees each row independently, so we flatten the rows into
        # the batch dimension and repeat the label for each of the 28 rows
        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])

        out = ann(img)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

Testing

[5]:
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:

    img = img.reshape([-1, 28])
    out = ann(img)
    # re-group the rows by image and sum the 28 per-row outputs
    out = out.reshape([BATCH_SIZE, 28, 10])
    out = out.sum(1)

    predicted = torch.max(out, dim=1)[1]
    acc = (predicted == target).sum().item() / BATCH_SIZE
    accs.append(acc)

print(sum(accs)/len(accs))

0.4740584935897436

Defining a spiking network

We now build the spiking counterpart of the fully connected network defined above by converting it with sinabs.

[6]:
from sinabs.from_torch import from_model

model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()

Training

Here, we begin training. Note that the internal state of the network (the membrane potentials) must be reset at every iteration, so that each sample starts from a blank state.

[7]:
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()

        out = model.spiking_model(img.to(device))
        # the output of the network is summed over the 28 time steps (rows)
        out = out.sum(1)
        loss = criterion(out, target.to(device))
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

Testing

[8]:
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    model.reset_states()

    out = model(img.to(device))
    out = out.sum(1)

    predicted = torch.max(out, dim=1)[1]
    acc = (predicted == target.to(device)).sum().item() / BATCH_SIZE
    accs.append(acc)

print(sum(accs)/len(accs))

0.6699719551282052

This accuracy, although not very high, serves as a proof of concept: the persistent state of the spiking network (its membrane potentials) can be exploited as a short-term memory for solving sequential tasks, provided the training procedure takes it into account.