{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST Example\n", "\n", "This tutorial walks you through the different steps involved in building a spiking neural network in `sinabs`.\n", "\n", "Let's start by installing all the necessary packages." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# To keep this tutorial clean and succinct, we redirect the output of the `pip` command to the file `install_log`.\n", "# Use the commented line below instead (without output redirection) if you want to check for installation errors.\n", "# %pip install -r mnist-requirements.txt\n", "\n", "%pip install -r mnist-requirements.txt > install_log" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a model in `sinabs`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a `PyTorch` Model\n", "\n", "`sinabs` is a `PyTorch`-based library, so we start by simply defining our model in `PyTorch`. In this example we instantiate an `nn.Sequential` model with three `Conv2d` layers and two dense (`nn.Linear`) layers."
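] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Where does the `128` in `nn.Linear(128, 500)` come from? The sketch below is plain Python arithmetic (the helper `out_size` is a hypothetical name introduced only for this illustration). It traces the spatial size of a 28×28 `MNIST` image through the convolution and pooling layers of the model defined next, using their kernel sizes and strides (no padding is used). After the last pooling layer each of the 128 channels is 1×1, so `nn.Flatten()` yields 128 features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper: output size of a convolution or pooling layer without padding\n", "def out_size(size, kernel, stride=1):\n", "    return (size - kernel) // stride + 1\n", "\n", "size = 28  # MNIST images are 28x28 pixels\n", "# (kernel, stride, output channels) of each Conv2d and AvgPool2d layer in the model, in order\n", "layers = [(5, 1, 20), (2, 2, 20), (5, 1, 32), (2, 2, 32), (3, 1, 128), (2, 2, 128)]\n", "for kernel, stride, channels in layers:\n", "    size = out_size(size, kernel, stride)\n", "    print(f'{channels} x {size} x {size}')\n", "\n", "flat_features = channels * size * size\n", "print('Features after nn.Flatten():', flat_features)"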
] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "ann = nn.Sequential(\n", "    nn.Conv2d(1, 20, 5, 1, bias=False),\n", "    nn.ReLU(),\n", "    nn.AvgPool2d(2, 2),\n", "    nn.Conv2d(20, 32, 5, 1, bias=False),\n", "    nn.ReLU(),\n", "    nn.AvgPool2d(2, 2),\n", "    nn.Conv2d(32, 128, 3, 1, bias=False),\n", "    nn.ReLU(),\n", "    nn.AvgPool2d(2, 2),\n", "    nn.Flatten(),\n", "    nn.Linear(128, 500, bias=False),\n", "    nn.ReLU(),\n", "    nn.Linear(500, 10, bias=False),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the title of this tutorial states, we will train the above model on the `MNIST` digit classification task. We borrow the `Dataset` definition from `torchvision`. Since we intend to run a spiking neural network simulation, we override this `Dataset` so that it can *optionally* return a `spike raster` instead of an image.\n", "\n", "In this implementation of the `Dataset` we use *rate coding*: at each pixel of the image we generate a series of spikes whose rate is proportional to its gray level."
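] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A stand-alone sketch of the rate coding used in the `__getitem__` method of the dataset defined next, on a toy 2×2 'image' rather than an `MNIST` digit: every pixel becomes an independent Bernoulli spike train whose spike probability per time step is its gray level divided by 255, so its firing rate over the time window approaches that value. The toy image and the window length are assumptions for illustration, and NumPy's `default_rng` with a fixed seed replaces `np.random.rand` only to make the result reproducible." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible\n", "tWindow = 1000  # number of time steps\n", "\n", "# A toy 2x2 gray-level 'image' with values in [0, 255]\n", "img = np.array([[0, 64], [128, 255]], dtype=np.uint8)\n", "\n", "# One Bernoulli draw per time step and pixel: shape [Time, Channel, Height, Width]\n", "spikes = (rng.random((tWindow, 1, *img.shape)) < img / 255.0).astype(float)\n", "\n", "# The empirical firing rate of each pixel approximates its normalized gray level\n", "rates = spikes.mean(axis=0)[0]\n", "print(np.round(rates, 2))"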
] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from torchvision import datasets\n", "\n", "class MNIST_Dataset(datasets.MNIST):\n", "\n", "    def __init__(self, root, train=True, spiking=False, tWindow=100):\n", "        datasets.MNIST.__init__(self, root, train=train, download=True)\n", "        self.spiking = spiking\n", "        self.tWindow = tWindow\n", "\n", "    def __getitem__(self, index):\n", "        # self.data holds the images as uint8 tensors of shape [28, 28]\n", "        img, target = self.data[index], self.targets[index]\n", "\n", "        if self.spiking:\n", "            # Rate coding: one Bernoulli draw per time step and pixel, shape [tWindow, 1, 28, 28]\n", "            img = (np.random.rand(self.tWindow, 1, *img.size()) < img.numpy()/255.0).astype(float)\n", "            img = torch.from_numpy(img).float()\n", "        else:\n", "            # Convert the uint8 tensor to a float tensor with a channel dimension, shape [1, 28, 28]\n", "            img = torch.from_numpy(img.numpy()).float()\n", "            img.unsqueeze_(0)\n", "\n", "        return img, target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by training the `ann` on the `MNIST` image dataset. **Note** that we are not yet using spiking input (`spiking=False`). This is vanilla training for standard image classification." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "# Define the training dataset loader\n", "train_loader = DataLoader(\n", "    MNIST_Dataset('./data', train=True, spiking=False),\n", "    batch_size=128, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We iterate over our data loader `train_loader` and train our parameters using the `Adam` optimizer with a learning rate of `1e-4`. Since the last layer in our network has no activation function, its outputs are raw class scores (logits), and the `cross_entropy` loss, which applies a softmax internally, is a good candidate to train our network."
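] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick check of why no final activation is needed: `F.cross_entropy` takes raw scores (logits) and applies `log_softmax` internally, so it gives the same value as `F.nll_loss` applied to `F.log_softmax` of the logits. The random logits and labels below are arbitrary and only serve to illustrate this equivalence." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "torch.manual_seed(0)\n", "logits = torch.randn(4, 10)  # raw outputs for 4 samples and 10 classes\n", "target = torch.tensor([7, 2, 1, 0])  # arbitrary class labels\n", "\n", "# cross_entropy on logits equals the negative log-likelihood of their log-softmax\n", "ce = F.cross_entropy(logits, target)\n", "manual = F.nll_loss(F.log_softmax(logits, dim=1), target)\n", "print(ce.item(), manual.item())"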
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import tqdm.notebook\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "# This tutorial runs on the CPU: the test and simulation cells below do not move data to another device\n", "device = torch.device(\"cpu\")\n", "ann.to(device)\n", "\n", "try:\n", "    # Load a pre-trained model to save time if you already have one.\n", "    ann.load_state_dict(torch.load(\"mnist_params.pt\"))\n", "except FileNotFoundError:\n", "    # Train the model\n", "    ann.train()\n", "\n", "    optimizer = torch.optim.Adam(ann.parameters(), lr=1e-4)\n", "\n", "    n_epochs = 3\n", "\n", "    for n in tqdm.notebook.tqdm(range(n_epochs)):\n", "        pbar = tqdm.notebook.tqdm(train_loader)\n", "        # Iterate over data\n", "        for data, target in pbar:\n", "            data, target = data.to(device), target.to(device)\n", "            output = ann(data)\n", "            optimizer.zero_grad()\n", "\n", "            # Compute the cross-entropy loss on the raw outputs (logits)\n", "            loss = F.cross_entropy(output, target)\n", "\n", "            # Propagate loss backwards\n", "            loss.backward()\n", "\n", "            # Update weights\n", "            optimizer.step()\n", "\n", "            # Get the index of the largest output (the predicted class)\n", "            pred = output.argmax(dim=1, keepdim=True)\n", "\n", "            # Count the correct predictions in this batch\n", "            correct = pred.eq(target.view_as(pred)).sum().item()\n", "\n", "            pbar.set_postfix({\"loss\": loss.item(), \"accuracy\": correct / len(target)})\n", "\n", "    # Save model parameters\n", "    torch.save(ann.state_dict(), \"mnist_params.pt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training this model on `MNIST` is fairly straightforward, and you should reach an accuracy above `98%` within a small number of epochs. In the script above we only train for 3 epochs!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to test the accuracy of our model, we first define a convenience method to test and report its performance."
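] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For spiking input, the `test` method defined next turns the SNN output into a class prediction by spike counting: it sums the output over the time dimension and takes the `argmax` over the output neurons. A minimal sketch of that decoding step on a made-up output raster of shape `[Time, 10]`, in which neuron 7 fires more often than the others (the raster and the firing probabilities are assumptions for illustration):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "torch.manual_seed(0)\n", "n_steps, n_classes = 200, 10\n", "\n", "# Made-up output raster: neuron 7 fires with probability 0.5 per step, all others with 0.05\n", "probs = torch.full((n_classes,), 0.05)\n", "probs[7] = 0.5\n", "output = (torch.rand(n_steps, n_classes) < probs).float()  # shape [Time, 10]\n", "\n", "# Same decoding as in `test`: spike count per neuron, then argmax over neurons\n", "counts = output.sum(0).squeeze().unsqueeze(0)  # shape [1, 10]\n", "pred = counts.argmax(dim=1, keepdim=True)\n", "print(counts, pred.item())"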
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "# Convenience method to test the model\n", "def test(model, data_loader, num_batches=None):\n", "    model.eval()\n", "    correct = 0\n", "    num_data = 0\n", "    batch_count = 0\n", "\n", "    with torch.no_grad():\n", "        # Iterate over data\n", "        pbar = tqdm.notebook.tqdm(data_loader)\n", "        for data, target in pbar:\n", "            if data_loader.dataset.spiking:\n", "                # Spiking input has shape [Batch, Time, Channel, Height, Width];\n", "                # the spiking model is run on one sample at a time, so only the first one is used\n", "                if data.size(0) > 1:\n", "                    warnings.warn(\"Batch size needs to be 1, only first sample used.\", stacklevel=2)\n", "                data = data[0]\n", "                target = target[0]\n", "            output = model(data)\n", "            if data_loader.dataset.spiking:\n", "                # Sum the output spikes over time: spike count per output neuron, shape [1, 10]\n", "                output = output.sum(0).squeeze().unsqueeze(0)\n", "                target = target.unsqueeze(0)\n", "\n", "            # Get the index of the largest output (the predicted class)\n", "            pred = output.argmax(dim=1, keepdim=True)\n", "            # Count the correct predictions\n", "            correct += pred.eq(target.view_as(pred)).sum().item()\n", "            num_data += len(target)\n", "\n", "            batch_count += 1\n", "            if num_data % 500 == 0:\n", "                pbar.set_postfix({\"Accuracy\": correct / num_data})\n", "            if num_batches and batch_count >= num_batches:\n", "                break\n", "\n", "    print(f'Test set: Accuracy: {correct}/{num_data} ({100. * correct / num_data}%)\\n')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let us test our model (`ann`) on our test dataset to check its performance. Once again, we do this by first defining a data loader."
] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "# Define test dataset loader\n", "test_loader = torch.utils.data.DataLoader(\n", "    MNIST_Dataset('./data', train=False, spiking=False),\n", "    batch_size=5, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now pass our model (`ann`) and the data loader (`test_loader`) to our test function." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c508d6aff2fa4773be4312e3c2a92513", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Test set: Accuracy: 9843/10000 (98.43%)\n", "\n" ] } ], "source": [ "test(ann, test_loader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have a good model (`98%` accuracy) for `MNIST` handwritten digit classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model conversion to SNN\n", "\n", "Up until this point we have only operated on images using standard CNN architectures. Now we look at how to build an equivalent spiking convolutional neural network (`SCNN`).\n", "\n", "`sinabs` has a handy method for this: given a standard CNN model, its `from_model` method converts it into a spiking neural network. It is a *one-liner*!" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from sinabs.from_torch import from_model\n", "\n", "input_shape = (1, 28, 28)\n", "\n", "sinabs_model = from_model(ann, input_shape=input_shape, add_spiking_output=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that this method takes two more parameters in addition to the model to be converted.\n", "\n", "`input_shape` is needed in order to instantiate an SNN with the appropriate number of neurons because, unlike traditional CNNs, SNNs are *stateful*.\n", "\n", "`add_spiking_output` is a boolean flag that specifies whether or not to add a spiking layer as the last layer in the network. This ensures that both the input and the output of our network are in the form of `spikes`.\n", "\n", "Let us now look at the generated SCNN. You should see that the only major difference is that the `ReLU` layers are replaced by `SpikingLayer` layers." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequential(\n", "  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)\n", "  (1): SpikingLayer()\n", "  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", "  (3): Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)\n", "  (4): SpikingLayer()\n", "  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", "  (6): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", "  (7): SpikingLayer()\n", "  (8): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", "  (9): Flatten()\n", "  (10): Linear(in_features=128, out_features=500, bias=False)\n", "  (11): SpikingLayer()\n", "  (12): Linear(in_features=500, out_features=10, bias=False)\n", "  (Spiking output): SpikingLayer()\n", ")" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sinabs_model.spiking_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model validation in sinabs simulation"
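] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Why can a `SpikingLayer` stand in for a `ReLU`? The intuition is rate coding again: a non-leaky integrate-and-fire neuron that receives a constant input `x` at every time step, with a threshold of 1 and reset by subtraction, fires at a rate of about `max(x, 0)` spikes per step for inputs up to 1, so its firing rate approximates a ReLU. The sketch below is a generic integrate-and-fire model in plain `PyTorch`; the threshold, the reset rule and the helper `if_neuron_rates` are assumptions and may not match the exact dynamics of `sinabs`' `SpikingLayer`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def if_neuron_rates(x, n_steps=200, threshold=1.0):\n", "    # Hypothetical helper: one non-leaky integrate-and-fire neuron per element of x,\n", "    # driven by the same constant input x at every time step\n", "    v = torch.zeros_like(x)  # membrane potential\n", "    n_spikes = torch.zeros_like(x)\n", "    for _ in range(n_steps):\n", "        v = v + x  # integrate the input\n", "        spikes = (v >= threshold).float()\n", "        v = v - spikes * threshold  # reset by subtraction\n", "        n_spikes += spikes\n", "    return n_spikes / n_steps  # firing rate in spikes per time step\n", "\n", "x = torch.linspace(-1.0, 1.0, 9)\n", "rates = if_neuron_rates(x)\n", "print(torch.stack([x, rates, torch.relu(x)]))"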
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's test our SCNN model to verify whether the network is in fact \"equivalent\" to the CNN model in terms of its performance. As we did previously, we start by defining a data loader (this time it is going to produce spikes, `spiking=True`) and then pass it to our test method." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "# Time window per sample\n", "tWindow = 200  # in ms, i.e. time steps\n", "\n", "# Define the spiking test dataset loader\n", "test_spike_loader = torch.utils.data.DataLoader(\n", "    MNIST_Dataset('./data', train=False, spiking=True, tWindow=tWindow),\n", "    batch_size=1, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since spiking simulations are significantly slower on a PC, we limit our test to 200 samples here. You can of course test the model on all 10,000 test samples if you want to verify that it does in fact work." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a03bc53d26f94af5a2eda58b30779794", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Test set: Accuracy: 193/200 (96.5%)\n", "\n" ] } ], "source": [ "test(sinabs_model, test_spike_loader, num_batches=200)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that the performance of this auto-generated spiking network (`sinabs_model`) is close to that of the `ann`! Yay!\n", "\n", "You may have noticed that we introduced a new free parameter, `tWindow`. This is a critical parameter that determines whether or not your SNN is going to work well. The longer `tWindow` is, the more input spikes we produce and, in general, the better the network performs, at the cost of a longer simulation. Feel free to experiment with this parameter and see how it changes your network's performance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualisation of a specific example" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "# Get one sample from the data loader\n", "for img, label in test_spike_loader:\n", "    break\n", "img = img[0]  # img now has dimensions [Time, Channel, Height, Width]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize this data, just so we know what to expect. We can do this by collapsing (summing over) the time dimension of the spike raster returned by the data loader." ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAANzklEQVR4nO3dbYxc5XnG8etivX4HBWPqumBqcKlaCylOtLHbgCJa1MS4aoG0RbZa5Ei0G1VQJRVSS8mH0KqtaBKIoqQNWoIVp0pIiIBgqYSEWkEoFbJYUwM2kNpBRthdbIMTjHH9fvfDHtAG9jyznnf7/v+k1cyee86cW8e+9pyZZ848jggBOPOd1esGAHQHYQeSIOxAEoQdSIKwA0lM6+bGpntGzNScbm4SSOWw3tLROOLJai2F3fZKSV+SNCDpaxFxR+nxMzVHK3xVK5sEULApNtbWmj6Ntz0g6V8lXS1pqaQ1tpc2+3wAOquV1+zLJe2IiJci4qikb0u6pj1tAWi3VsJ+gaRXJvy+q1r2C2wP2x61PXpMR1rYHIBWdPzd+IgYiYihiBga1IxObw5AjVbCvlvSogm/X1gtA9CHWgn7U5IutX2x7emSVkva0J62ALRb00NvEXHc9s2SfqDxobd1EbGtbZ0BaKuWxtkj4hFJj7SpFwAdxMdlgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0m0NGWz7Z2S3pR0QtLxiBhqR1MA2q+lsFd+JyJea8PzAOggTuOBJFoNe0j6oe3Ntocne4DtYdujtkeP6UiLmwPQrFZP46+IiN22f0nSY7ZfjIgnJj4gIkYkjUjSOZ4XLW4PQJNaOrJHxO7qdq+khyQtb0dTANqv6bDbnmP77LfvS/qopK3tagxAe7VyGr9A0kO2336eb0XEo23pCkDbNR32iHhJ0vvb2AuADmLoDUiCsANJEHYgCcIOJEHYgSTacSFMCmO3fLi2Nu9j/1tcd9+bc4r1QwdmFuuLv1P+mzz4xtHamp98prgu8uDIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM4+RbcN31dbW332zzq78Y+Vy5uP1I+zf2XPVcV1F896vVjfdfjcYv31I7OL9V+ZdaC29r7BQ8V1T4aL9X1Hzy7WdxyYX1/8wvnFdaf/YLRYPx1xZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBzRvUlazvG8WOHyuG+/euuPVtTW9n6o/Ddz/pbyPv7Zrze4Xv2D5XH8z132YG1t5ezylFvfe2tusf4Hs+vHySVpwOXeT8TJYr3kPw6Ve/v92QeL9ZOq3++/+fifF9dd8qf/Xaz3q02xUQdi/6QfUODIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcD37FM15YFNt7eIHWnvu8lXZjX35kqtra/+0bEFx3XM2vVKs3zV0YbF++H0DxfqMA/Xj7LNePVxcd9qO8vfxL9v8ULG+cGBWbW3OaH3tTNXwyG57ne29trdOWDbP9mO2t1e35W84ANBzUzmN/7qkle9adqukjRFxqaSN1e8A+ljDsEfEE5L2v2vxNZLWV/fXS7q2vW0BaLdmX7MviIix6v6rkmpfGNoeljQsSTNV/r4yAJ3T8rvxMX4lTe0VBxExEhFDETE0qBmtbg5Ak5oN+x7bCyWput3bvpYAdEKzYd8gaW11f62kh9vTDoBOafia3fZ9kq6UNN/2LkmflXSHpPtt3yjpZUnXd7JJlB1/aWdtbXahJknHGzz3rN3lse5OjlaP3fThYv2CgfJ7QN89eF79ut/aUVz3RLF6emoY9ohYU1M6Pb+FAkiKj8sCSRB2IAnCDiRB2IEkCDuQBJe4omemXbK4WP/O33y+WB/wnGL9zn9ZXVubt+fJ4rpnIo7
sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zomRf/6peL9SXTyhfQ/vRYecrm+Zt/XltrfiLp0xdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2dNSh61bU1jb98Z0N1i6Ps9/4l39drM945qkGz58LR3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJxdnTUrpX1V47PHyh/7/uf7byyWJ/1+LZiPeM16yUNj+y219nea3vrhGW3295te0v1s6qzbQJo1VRO478uaeUky78YEcuqn0fa2xaAdmsY9oh4QtL+LvQCoINaeYPuZtvPVqf559Y9yPaw7VHbo8d0pIXNAWhFs2H/qqQlkpZJGpNUe0VDRIxExFBEDA1qRpObA9CqpsIeEXsi4kREnJR0j6Tl7W0LQLs1FXbbCyf8ep2krXWPBdAfGo6z275P0pWS5tveJemzkq60vUxSSNop6ZOdaxH97KyZM4v1333/C7W1sePl733f/Q+XFuvTD3G9+qloGPaIWDPJ4ns70AuADuLjskAShB1IgrADSRB2IAnCDiTBJa5oyfZ/XFasf/+iu2trH99xXXHd6Y8ytNZOHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2VF0ZNWHivWfrPm3Yn3LkWO1tTf+/qLiutO0r1jHqeHIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6e3LRLFhfrn/nKPcX6gMvHiz95sv5bxi/ZuLm4LtqLIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+xnO08r/xL/9vReL9ctnnCzW7/75omL91/75cG2t/Mxot4ZHdtuLbP/I9vO2t9n+VLV8nu3HbG+vbs/tfLsAmjWV0/jjkm6JiKWSfkvSTbaXSrpV0saIuFTSxup3AH2qYdgjYiwinq7uvynpBUkXSLpG0vrqYeslXduhHgG0wSm9Zre9WNIHJG2StCAixqrSq5IW1KwzLGlYkmZqdtONAmjNlN+Ntz1X0gOSPh0RBybWIiIkxWTrRcRIRAxFxNCgZrTULIDmTSnstgc1HvRvRsSD1eI9thdW9YWS9namRQDt0PA03rYl3SvphYi4a0Jpg6S1ku6obh/uSIdoyVlLFhfrf3fe/S09/8iX/7BYP3/rky09P9pnKq/ZL5d0g6TnbG+plt2m8ZDfb/tGSS9Lur4jHQJoi4Zhj4gfS3JN+ar2tgOgU/i4LJAEYQeSIOxAEoQdSIKwA0lwiesZ4KzLfqO29okHv9/Sc182cnOxftHdjKOfLjiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLOfAXbcUP/FvtfPfaO47rGY9AuG3rHwv46UN95gffQPjuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7KeBQ9etKNYfXf352tqJKE+5NeiBcv3gsWIdpw+O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQxFTmZ18k6RuSFkgKSSMR8SXbt0v6C0n7qofeFhGPdKrRzF5fWh4LXzI4t7Z2Ik4W1338/8p/733sRLHO1eynj6l8qOa4pFsi4mnbZ0vabPuxqvbFiPhC59oD0C5TmZ99TNJYdf9N2y9IuqDTjQFor1N6zW57saQPSNpULbrZ9rO219me9LuRbA/bHrU9ekwNvuIIQMdMOey250p6QNKnI+KApK9KWiJpmcaP/HdOtl5EjETEUEQMDWpG6x0DaMqUwm57UONB/2ZEPChJEbEnIk5ExElJ90ha3rk2AbSqYdhtW9K9kl6IiLsmLF844WHXSdra/vYAtMtU3o2/XNINkp6zvaVadpukNbaXaXz0ZaekT3agP0g6cn55+KzknjcWFesbPlI/3bMkxWvbmt42+stU3o3/sSRPUmJMHTiN8Ak6IAnCDiRB2IEkCDuQBGEHkiDsQBKOLk65e47nxQpf1bXtAdlsio06EPsnGyrnyA5kQdiBJAg7kARhB5Ig7EAShB1IgrADSXR1nN32Pkk
vT1g0X9JrXWvg1PRrb/3al0RvzWpnb78aEedPVuhq2N+zcXs0IoZ61kBBv/bWr31J9NasbvXGaTyQBGEHkuh12Ed6vP2Sfu2tX/uS6K1ZXemtp6/ZAXRPr4/sALqEsANJ9CTstlfa/ontHbZv7UUPdWzvtP2c7S22R3vcyzrbe21vnbBsnu3HbG+vbiedY69Hvd1ue3e177bYXtWj3hbZ/pHt521vs/2panlP912hr67st66/Zrc9IOl/JP2epF2SnpK0JiKe72ojNWzvlDQUET3/AIbtj0g6KOkbEXFZtexzkvZHxB3VH8pzI+Jv+6S32yUd7PU03tVsRQsnTjMu6VpJn1AP912hr+vVhf3WiyP7ckk7IuKliDgq6duSrulBH30vIp6QtP9di6+RtL66v17j/1m6rqa3vhARYxHxdHX/TUlvTzPe031X6KsrehH2CyS9MuH3Xeqv+d5D0g9tb7Y93OtmJrEgIsaq+69KWtDLZibRcBrvbnrXNON9s++amf68VbxB915XRMQHJV0t6abqdLUvxfhrsH4aO53SNN7dMsk04+/o5b5rdvrzVvUi7LslTZxt8MJqWV+IiN3V7V5JD6n/pqLe8/YMutXt3h73845+msZ7smnG1Qf7rpfTn/ci7E9JutT2xbanS1otaUMP+ngP23OqN05ke46kj6r/pqLeIGltdX+tpId72Msv6JdpvOumGVeP913Ppz+PiK7/SFql8XfkfyrpM73ooaavSyQ9U/1s63Vvku7T+GndMY2/t3GjpPMkbZS0XdJ/SprXR739u6TnJD2r8WAt7FFvV2j8FP1ZSVuqn1W93neFvrqy3/i4LJAEb9ABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/D54eEkh/BQOWAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "\n", "plt.imshow(img.sum(0)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now take this data (including the time dimension), and pass it to the Sinabs SNN model." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "snn_output = sinabs_model(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us now display the output in time." ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'Time')" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAS7klEQVR4nO3dfZBl9V3n8fcnw2RGEIwQjAmgkCyyYgiBdMiDCaUhGsAIymYjqbirbmqnLONu2LhlobEUtfwjalKuWz7UrGFFjcSVQEm5CU8ak7VqAxnGmTAwjCEYAsPDJBiFEAsG+PrHPYOXtrvn9u3+3XOH835V3ep7T597ft/63dP30+fxl6pCkjRcz+u7AElSvwwCSRo4g0CSBs4gkKSBMwgkaeAO67uAcc/PptrMEX2XIUmHjEf5yper6ti1LGOugmAzR/CanNN3GZJ0yLiprrpnrctw15AkDZxBIEkDZxBI0sAZBJI0cAaBJA2cQSBJA9c0CJK8J8muJLcnuaRlW5Kk6TQLgiQvB/4zcBZwOvDWJP+mVXuSpOm03CL4duDmqvpaVT0JfBK4qGF7kqQptAyCXcAbkxyT5HDgfOCExTMl2ZJkW5Jt+3m8YTmSpKU0u8VEVe1O8n7gBuAxYAfw1BLzbQW2AhyVox0uTZJmrOnB4qr6UFW9qqrOBr4C/G3L9iRJq9f0pnNJvqmq9iX5FkbHB17bsj1J0uq1vvvoR5McA+wH3l1V/9C4PUnSKjUNgqp6Y8vlS5LWziuLJWngDAJJGjiDQJIGziCQpIEzCCRp4AwCSRo4g0CSBs4gkKSBMwgkaeAMAkkaOINAkgbOIJCkgTMIJGngmgZBkv+W5PYku5JcmWRzy/YkSavXLAiSHAf8V2Chql4ObAAubtWeJGk6rXcNHQZ8XZLDgMOB+xu3J0lapWZBUFV7gV8Hvgg8APxjVd2weL4kW5JsS7JtP4+3KkeStIyWu4a+EbgQOAl4CXBEkh9ePF9Vba2qhapa2MimVuVIkpbRctfQm4G/q6ovVdV+4Grg9Q3bkyRNoWUQfBF4bZLDkwQ4B9jdsD1J0hRaHiO4GbgK2A7c1rW1tVV7kqTpHNZy4VX1C8AvtGxDkrQ2XlksSQNnEEjSwBkEkjRwBoEkDZxBIEkDZxAs4fr7d3L9/Tv7LmNV+qz5UOirWfbPPK4
/81jTas26/nnqs9a1GASSNHAGgSQNnEEgSQNnEEjSwBkEkjRwBoEkDZxBIEkD13KEslOS7Bh7PJLkklbtSZKm0+w21FW1B3glQJINwF7gmlbtSZKmM6tdQ+cAn6+qe2bUniRpQk0HphlzMXDlUr9IsgXYArCZw2dUjiTpgOZbBEmeD1wA/OlSv6+qrVW1UFULG9nUuhxJ0iKz2DV0HrC9qh6aQVuSpFWaRRC8g2V2C0mS+tc0CJIcAXwPcHXLdiRJ02t6sLiqHgOOadmGJGltvLJYkgbOIJCkgTMIJGngDAJJGjiDQJIGLlXVdw3POCpH12tyTt9lSNIh46a66taqWljLMtwikKSBMwgkaeAMAkkaOINAkgbOIJCkgTMIJGngWt999AVJrkpyZ5LdSV7Xsj1J0uq1HqryfwDXVdXbupHKHItSkuZMsyBI8g3A2cCPAlTVE8ATrdqTJE2n5a6hk4AvAf87yd8k+b1uoJpnSbIlybYk2/bzeMNyJElLaRkEhwFnAr9TVWcAjwGXLp7JweslqV8tg+A+4L6qurl7fRWjYJAkzZFmQVBVDwL3Jjmlm3QOcEer9iRJ02l91tB/AT7cnTF0N/BjjduTJK1S68HrdwBruj2qJKktryyWpIEzCCRp4AwCSRo4g0CSBu6gB4uTnAb82+7l7qra1bYkSdIsLRsE3b2C/gw4AfgsEOC0JF8ELqyqR2ZToiSppZV2Df0ysA04uap+sKp+ADgZ+AzwKzOoTZI0AyvtGnoz8IqqevrAhKp6OsnPArc1r0ySNBMrbRE8UVVPLp7YTfM2oZL0HLHSFsHmJGcwOjYwLuBtQiXpuWKlIHgQ+OAKv5MkPQcsGwRV9V0zrEOS1JOVTh+9aKU3VtXV61+OJGnWVto19P0r/K6AgwZBki8AjwJPAU9WlXcilaQ5s9KuofUaO+C7q+rL67QsSdI6815DkjRwrYOggBuS3Jpky1IzJNmSZFuSbfu9PEGSZq71UJVvqKq9Sb4JuDHJnVX1qfEZqmorsBXgqBxdjeuRJC0yURAkeT1w4vj8VfUHB3tfVe3tfu5Lcg1wFvCpld8lSZqlSW5D/YfAy4AdjM7+gdEunxWDIMkRwPOq6tHu+fcCv7SmaiVJ626SLYIF4NSqWu1umxcB1yQ50M4fV9V1q1yGJKmxSYJgF/DNwAOrWXBV3Q2cPk1RkqTZmSQIXgjckeQWxu46WlUXNKtKkjQzkwTBZa2LkCT156BBUFWfTPIi4NXdpFuqal/bsiRJs3LQC8qSvB24Bfj3wNuBm5O8rXVhkqTZmGTX0PuAVx/YCkhyLHATcFXLwiRJszHJLSaet2hX0MMTvk+SdAiYZIvguiTXA1d2r38I+Fi7kiRJs7RiEGR0NdhvMjpQ/IZu8taquqZ1YZKk2VgxCKqqknysqk5jgoFoJEmHnkn29W9P8uqDzyZJOhRNcozgNcA7k9wDPAaE0cbCK5pWJkmaiUmC4C3Nq5Ak9WaSIHCwGEl6DpskCP4vozAIsBk4CdgDfMckDSTZAGwD9lbVW6esU5LUyCT3Gjpt/HWSM4GfWEUb7wF2A0etrjRJ0iys+grhqtrO6ADyQSU5Hvg+4PdW244kaTYmGaryvWMvnwecCdw/4fJ/A/hp4MgVlr8F2AKwmcMnXKwkab1MskVw5NhjE6NjBhce7E1J3grsq6pbV5qvqrZW1UJVLWxk0wTlSJLW0yTHCH4RIMnhVfW1VSz7O4ELkpzP6CDzUUn+qKp+eLpSJUktTDIeweuS3AHc2b0+PclvH+x9VfUzVXV8VZ0IXAz8pSEgSfNnkl1Dv8HoorKHAapqJ3B2w5okSTM0yXUEVNW9oxuRPuOp1TRSVX8F/NVq3iNJmo1JguDeJK8HKslG/uW6AEnSc8Aku4Z+HHg3cBywF3hl91qS9BwwyVlDXwbeOYNaJEk9WDYIkvz8Cu+rqvrlBvVIkmZspS2Cx5aYdgTwLuAYwCCQpOeAZYOgqj5w4Hm
SIxkdJP4x4CPAB5Z7n55brr9/JwBvecnpPVcircx1dXoHG7z+aOC9jI4RXAGcWVVfmUVhkqTZWOkYwa8BFwFbgdOq6qszq0qSNDMrnT76U8BLgJ8D7k/ySPd4NMkjsylPktTaSscIVj1WgSTp0OOXvSQNnEEgSQNnEEjSwDULgiSbk9ySZGeS25P8Yqu2JEnTm+g21FN6HHhTVX21u2vpXyf5eFV9umGbkqRVahYEVVXAgWsPNnaPatWeJGk6TY8RJNmQZAewD7ixqm5eYp4tSbYl2bafx1uWI0laQtMgqKqnquqVwPHAWUlevsQ8W6tqoaoWNrKpZTmSpCXM5KyhqvoH4BPAubNoT5I0uZZnDR2b5AXd868Dvge4s1V7kqTptDxr6MXAFUk2MAqc/1NVf96wPUnSFFqeNfRZ4IxWy5ckrQ+vLJakgTMIJGngDAJJGjiDQJIGziCQpIFrefroqn3bK74Gt/Vdhca95SWn99Lu9ffv7LV9tXH9/TubfaauK9Nzi0CSBs4gkKSBMwgkaeAMAkkaOINAkgbOIJCkgWt5G+oTknwiyR3d4PXvadWWJGl6La8jeBL4qaranuRI4NYkN1bVHQ3blCStUrMtgqp6oKq2d88fBXYDx7VqT5I0nZkcI0hyIqOxCVYcvP5LDz81i3IkSWOaB0GSrwc+ClxSVY8s/v344PXHHrOhdTmSpEWaBkGSjYxC4MNVdXXLtiRJ02l51lCADwG7q+qDrdqRJK1Nyy2C7wT+A/CmJDu6x/kN25MkTaHl4PV/DaTV8iVJ68MriyVp4AwCSRo4g0CSBs4gkKSBMwgkaeAMAkkauFRV3zU846gcXa/JOX2XIUmHjJvqqluramEty3CLQJIGziCQpIEzCCRp4AwCSRo4g0CSBs4gkKSBazkeweVJ9iXZ1aoNSdLatdwi+H3g3IbLlyStg2ZBUFWfAv6+1fIlSeuj2cA0k0qyBdgCsJnDe65Gkoan94PFVbW1qhaqamEjm/ouR5IGp/cgkCT1yyCQpIFrefrolcD/B05Jcl+Sd7VqS5I0vWYHi6vqHa2WLUlaP+4akqSBMwgkaeAMAkkaOINAkgbOIJCkgTMIJGngDAJJGjiDQJIGziCQpIEzCCRp4AwCSRo4g0CSBs4gkKSBaxoESc5NsifJXUkubdmWJGk6Lccj2AD8FnAecCrwjiSntmpPkjSdllsEZwF3VdXdVfUE8BHgwobtSZKm0DIIjgPuHXt9XzftWZJsSbItybb9PN6wHEnSUno/WFxVW6tqoaoWNrKp73IkaXBaBsFe4ISx18d30yRJc6RlEHwGODnJSUmeD1wMXNuwPUnSFFoOXv9kkp8Ergc2AJdX1e2t2pMkTadZEABU1ceAj7VsQ5K0Nr0fLJYk9csgkKSBMwgkaeAMAkkaOINAkgYuVdV3Dc9I8iiwp+86DuKFwJf7LmIC1rm+rHN9Wef6OaWqjlzLApqePjqFPVW10HcRK0mybd5rBOtcb9a5vqxz/STZttZluGtIkgbOIJCkgZu3INjadwETOBRqBOtcb9a5vqxz/ay5xrk6WCxJmr152yKQJM2YQSBJAzcXQZDk3CR7ktyV5NK+6zkgyQlJPpHkjiS3J3lPN/2yJHuT7Oge589BrV9IcltXz7Zu2tFJbkzyue7nN/Zc4yljfbYjySNJLpmH/kxyeZJ9SXaNTVuy/zLym936+tkkZ/ZY468lubOr45okL+imn5jkn8b69HdnUeMKdS77GSf5ma4v9yR5S891/slYjV9IsqOb3md/Lvc9tH7rZ1X1+mA0VsHngZcCzwd2Aqf2XVdX24uBM7vnRwJ/C5wKXAb8977rW1TrF4AXLpr2q8Cl3fNLgff3Xeeiz/1B4FvnoT+Bs4EzgV0H6z/gfODjQIDXAjf3WOP3Aod1z98/VuOJ4/PNQV8u+Rl3f087gU3ASd13wYa+6lz0+w8APz8H/bnc99C6rZ/zsEVwFnBXVd1dVU8AHwEu7Lk
mAKrqgara3j1/FNgNHNdvVatyIXBF9/wK4Af6K+VfOQf4fFXd03chAFX1KeDvF01erv8uBP6gRj4NvCDJi/uosapuqKonu5efZjQkbK+W6cvlXAh8pKoer6q/A+5i9J3Q3Ep1JgnwduDKWdSykhW+h9Zt/ZyHIDgOuHfs9X3M4ZdtkhOBM4Cbu0k/2W12Xd73LpdOATckuTXJlm7ai6rqge75g8CL+iltSRfz7D+yeetPWL7/5nWd/U+M/hM84KQkf5Pkk0ne2FdRY5b6jOe1L98IPFRVnxub1nt/LvoeWrf1cx6CYO4l+Xrgo8AlVfUI8DvAy4BXAg8w2oTs2xuq6kzgPODdSc4e/2WNthnn4lzhjMawvgD4027SPPbns8xT/y0lyfuAJ4EPd5MeAL6lqs4A3gv8cZKj+qqPQ+AzXuQdPPsfld77c4nvoWesdf2chyDYC5ww9vr4btpcSLKRUed/uKquBqiqh6rqqap6GvhfzGhTdiVVtbf7uQ+4hlFNDx3YJOx+7uuvwmc5D9heVQ/BfPZnZ7n+m6t1NsmPAm8F3tl9IdDtanm4e34ro33v39ZXjSt8xnPVlwBJDgMuAv7kwLS++3Op7yHWcf2chyD4DHBykpO6/xQvBq7tuSbgmf2EHwJ2V9UHx6aP72/7QWDX4vfOUpIjkhx54DmjA4i7GPXjj3Sz/QjwZ/1U+K8867+teevPMcv137XAf+zOzngt8I9jm+gzleRc4KeBC6rqa2PTj02yoXv+UuBk4O4+auxqWO4zvha4OMmmJCcxqvOWWde3yJuBO6vqvgMT+uzP5b6HWM/1s4+j4EscFT+f0ZHwzwPv67uesbrewGhz67PAju5xPvCHwG3d9GuBF/dc50sZnXmxE7j9QB8CxwB/AXwOuAk4eg769AjgYeAbxqb13p+MgukBYD+jfarvWq7/GJ2N8Vvd+nobsNBjjXcx2h98YP383W7ef9etCzuA7cD399yXy37GwPu6vtwDnNdnnd303wd+fNG8ffbnct9D67Z+eosJSRq4edg1JEnqkUEgSQNnEEjSwBkEkjRwBoEkDZxBoEFKcszYnSQfHLsz5leT/Hbf9Umz5OmjGrwklwFfrapf77sWqQ9uEUhjknxXkj/vnl+W5Iok/y/JPUkuSvKrGY37cF132T9JXtXdiOzWJNfP4k6k0noyCKSVvQx4E6Ob5P0R8ImqOg34J+D7ujD4n8DbqupVwOXAr/RVrDSNw/ouQJpzH6+q/UluYzSYznXd9NsYDVZyCvBy4MbRLWHYwOi2BdIhwyCQVvY4QFU9nWR//ctBtacZ/f0EuL2qXtdXgdJauWtIWps9wLFJXgej2wUn+Y6ea5JWxSCQ1qBGw6u+DXh/kp2M7gz5+l6LklbJ00claeDcIpCkgTMIJGngDAJJGjiDQJIGziCQpIEzCCRp4AwCSRq4fwaj5ZIHwYhQQAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.pcolormesh(snn_output.T.detach())\n", "\n", "plt.ylabel(\"Neuron ID\")\n", "plt.yticks(np.arange(10) + 0.5, np.arange(10));\n", "plt.xlabel(\"Time\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the majority of spikes are emitted by the output neuron corresponding to the digit 7, which is the correct inference." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## More analysis of the SNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Synaptic Operations\n", "\n", "One property of your model that you might be interested in is the total number of synaptic operations (SynOps) required for an inference. The `Network` class provides a handy method to compute the total number of synaptic operations for the last inference performed by the model.\n", "\n", "For instance, to look at the synaptic operations per layer for the recognition of the digit `7` above, we run the `get_synops` method." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "       Fanout_Prev        In      SynOps     SynOps/s  Time_window\n", "Layer\n", "0            500.0  14452.00   7226000.0   36130000.0        200.0\n", "3            800.0  37321.25  29857000.0  149285000.0        200.0\n", "6           1152.0   6589.00   7590528.0   37952640.0        200.0\n", "10           500.0    871.25    435625.0    2178125.0        200.0\n", "12            10.0   1301.00     13010.0      65050.0        200.0" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sinabs_model.get_synops()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Comparing the activity of the ANN and the SNN\n", "\n", "Below, `plot_comparison` plots the activity of the CNN against that of the SNN for the layers given in `name_list`: layer `1` and layer `11`, the first and the last hidden `SpikingLayer` in the model printed above." ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['1']\n", "['11']\n" ] } ], "source": [ "# Get one sample to use\n", "for data, label in test_spike_loader:\n", "    break\n", "\n", "cnn_act, spk_act = sinabs_model.plot_comparison(data[0], compute_rate=True, name_list=['1'])\n", "plt.figure()\n", "cnn_act, spk_act = sinabs_model.plot_comparison(data[0], compute_rate=True, name_list=['11'])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }