Training by backpropagation through time (BPTT)
BPTT is normally used to train recurrent neural networks. A spiking network, even when it is not recurrent, remembers its previous processing steps through the persistence of its membrane potentials. Unlike conventional (analog) neural networks, spiking networks have an internal state that persists over time.
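As a minimal sketch of what "persistent membrane potential" means, the following plain-Python leaky integrate-and-fire neuron keeps a state variable `v` across time steps. The function name `run_leaky_neuron` and the values of `decay` and `threshold` are assumptions made for illustration; they are not taken from sinabs.

```python
def run_leaky_neuron(inputs, decay=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a sequence of inputs.

    `decay` and `threshold` are illustrative values, not sinabs defaults.
    """
    v = 0.0  # membrane potential: the state that persists between time steps
    potentials, spikes = [], []
    for x in inputs:
        v = decay * v + x  # leak, then integrate the new input
        spike = 1 if v >= threshold else 0
        if spike:
            v -= threshold  # reset by subtraction after a spike
        potentials.append(v)
        spikes.append(spike)
    return potentials, spikes


# The same input value (0.6) arrives twice. Only the second one triggers a
# spike, because the membrane potential still remembers the first.
potentials, spikes = run_leaky_neuron([0.6, 0.6, 0.0, 0.0])
print(spikes)  # → [0, 1, 0, 0]
```

The response to an input depends on what the neuron received earlier, even though there are no recurrent connections.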
This is why BPTT can be used for more precise (but also much more computationally expensive) training in SNNs for sequential tasks. In sinabs, backpropagation in the spiking network is accomplished through a surrogate gradient method, since the spiking nonlinearity is not differentiable.
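The idea behind a surrogate gradient can be sketched with a custom PyTorch autograd function: the forward pass applies the non-differentiable step, while the backward pass substitutes a smooth stand-in derivative. The class name `SurrogateSpike` and the "fast sigmoid" derivative used here are assumptions chosen for illustration; this is not the implementation used inside sinabs.

```python
import torch


class SurrogateSpike(torch.autograd.Function):
    """Heaviside step in the forward pass, smooth surrogate in the backward pass."""

    @staticmethod
    def forward(ctx, v_minus_threshold):
        ctx.save_for_backward(v_minus_threshold)
        # Spike (1.0) wherever the membrane potential reaches the threshold.
        return (v_minus_threshold >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v_minus_threshold,) = ctx.saved_tensors
        # Assumed surrogate: derivative of a fast sigmoid, 1 / (1 + |x|)^2.
        surrogate = 1.0 / (1.0 + v_minus_threshold.abs()) ** 2
        return grad_output * surrogate


v = torch.tensor([-0.5, 0.0, 2.0], requires_grad=True)
spikes = SurrogateSpike.apply(v)
spikes.sum().backward()

print(spikes)  # binary spikes from the hard threshold
print(v.grad)  # non-zero gradients supplied by the surrogate
```

The true derivative of the step function is zero almost everywhere, so no learning signal would pass through it; the surrogate is largest near the threshold and decays away from it.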
In this notebook, we will train a spiking network directly (without training an analog network first), on the Sequential MNIST task. In Sequential MNIST, a network is shown the 28x28-pixel MNIST digits one row after the other. The input to the network is a single row of 28 pixels, followed by the second one, etc, until all 28 rows are shown. At this point, the network makes a prediction on the digit label.
First, we define the MNIST dataset. Note that pixel values are between 0 and 1. We turn those values into probabilities of spiking.
from torchvision import datasets
import torch

torch.manual_seed(0)


class MNIST_Dataset(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.

        # default is by row, output is [time, channels] = [28, 28]
        # OR if we want by single item, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)

        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()

        return spikes, target
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64

dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True)

dataset = MNIST_Dataset(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, drop_last=True)
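The dataset turns pixel intensities into spikes by comparing each intensity with a uniform random number. The following sketch, which uses a made-up four-pixel row rather than real MNIST data, illustrates that the resulting spike probability matches the pixel value; the tensor `row` and the sample count `n_samples` are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# A hypothetical row of pixel intensities, already scaled to [0, 1].
row = torch.tensor([0.0, 0.25, 0.5, 1.0])

# One sample: a pixel spikes when a uniform random number falls below its intensity.
single_sample = (torch.rand(row.shape) < row).float()

# Averaging many independent samples approximately recovers the intensities.
n_samples = 10_000
samples = (torch.rand(n_samples, row.numel()) < row).float()
estimated_rates = samples.mean(dim=0)

print(single_sample)
print(estimated_rates)
```

A pixel with intensity 0 never spikes and a pixel with intensity 1 always spikes, because `torch.rand` draws from the half-open interval [0, 1). Each call to `__getitem__` therefore yields a different random spike pattern for the same image.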
Training a baseline
To show that the task genuinely requires memory, we first check that a memory-less analog network, which processes each row independently, cannot solve it with similar accuracy. Let us train such a baseline.
from torch import nn

ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU()
)
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(5):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()

        # every row becomes an independent sample carrying its image's label
        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])

        out = ann(img)
        # out = out.sum(1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item())
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    img = img.reshape([-1, 28])
    out = ann(img)
    # regroup the per-row outputs by image, then sum over the 28 rows
    out = out.reshape([BATCH_SIZE, 28, 10])
    out = out.sum(1)
    # the index of the largest summed output is the predicted digit
    predicted = out.argmax(dim=1)
    acc = (predicted == target).sum().item() / BATCH_SIZE
    accs.append(acc)

print(sum(accs) / len(accs))
Defining a spiking network
We then define a fully connected spiking neural network with four hidden layers, built from the baseline architecture above with `from_model`.
from sinabs.from_torch import from_model

model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()
Here, we begin training. Note that the state of the network must be reset at every iteration, so that membrane potentials left over from one batch do not carry over into the next.
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()

        out = model.spiking_model(img.to(device))

        # the output of the network is summed over the 28 time steps (rows)
        out = out.sum(1)
        loss = criterion(out, target.to(device))
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item())
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    model.reset_states()
    out = model(img.to(device))
    # sum the output spikes over the 28 time steps (rows)
    out = out.sum(1)
    # the index of the most active output neuron is the predicted digit
    predicted = out.argmax(dim=1)
    acc = (predicted == target.to(device)).sum().item() / BATCH_SIZE
    accs.append(acc)

print(sum(accs) / len(accs))
Although this accuracy is not very high, it shows, as a proof of concept, that the persistent state of the spiking network (the membrane potentials) can be exploited as a short-term memory for solving sequential tasks, provided the training procedure takes it into account.
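How a training procedure can "take the state into account" may be easier to see on a toy example. The sketch below unrolls a single leaky integrator (no spiking, no sinabs) over three time steps and backpropagates from the final state only; the leak factor `decay` is an arbitrary illustrative value.

```python
import torch

decay = 0.9  # illustrative leak factor, not a sinabs default
inputs = torch.tensor([1.0, 0.0, 0.0], requires_grad=True)

# Unroll a leaky integrator over three time steps.
v = torch.zeros(())
for t in range(inputs.shape[0]):
    v = decay * v + inputs[t]

# Only the final state is differentiated, yet the gradient reaches every
# time step through the chain of membrane potential updates.
v.backward()
print(inputs.grad)  # gradient per time step: decay**2, decay, 1
```

The gradient with respect to the first input is `decay ** 2`, since that input passed through two leak steps before reaching the final state. With `decay = 0` the neuron would be memory-less and earlier inputs would receive zero gradient, which is the situation of the analog baseline.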