{ "cells": [ { "cell_type": "markdown", "id": "crucial-desert", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# The Hopfield Network" ] }, { "cell_type": "markdown", "id": "expanded-specification", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "The Hopfield network is an artificial neural network that stores a set of patterns as attractors of its dynamics. The state of the network is represented by a vector $S_i$ for $i = 1, \\cdots, N$, where $N$ is the number of nodes in the network. We will consider nodes that can only take discrete values $\\pm 1$. The dynamics of the network is described by the update rule:\n", "\\begin{equation}\n", "S_i (t+1) = \\operatorname{sgn} \\Big( \\sum_j J_{ij} S_j (t) \\Big)\n", "\\end{equation}\n", "Here the matrix $J_{ij}$ represents the connections between the nodes.\n", "\n", "There are two ways to implement the update of all nodes. One way is to update all nodes together at the same time, which is called a \"synchronous\" update. The other way is to update the nodes one at a time, which is called \"asynchronous\". These are not equivalent because in the synchronous update, the new value of each node depends on the old values of all nodes, whereas in the asynchronous update, even if we sweep over all nodes in each round of updates, the new value of a node will depend on the new values of the nodes that have already been updated in the same round (and the old values of those not yet updated). The latter scheme may be considered more realistic biologically. In order to avoid bias, we will choose a random order to update the nodes, which introduces some stochasticity in the dynamics." ] }, { "cell_type": "markdown", "id": "0d431e68-a89b-42de-a7a0-baa60fb5d76f", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "A pattern is a state of the network where each node takes a specific value $\\xi_i = \\pm 1$. 
For a given pattern $\\xi_i$ to be stored in the network, it has to be a stable steady state of the dynamics. In that case, if we initialize the network at a state close to the pattern (a \"seed\" state), the network will follow its dynamics and converge to the steady state, thus retrieving the stored pattern. Hopfield found a simple form of $J_{ij}$, known as the Hebbian rule, that allows the network to store multiple patterns $\\xi_i^\\mu$ for $\\mu = 1, \\cdots, M$ (think of each $\\xi^\\mu$ as a vector with components $\\xi_i^\\mu$):\n", "\\begin{equation}\n", "J_{ij} = \\frac{1}{N} \\sum_{\\mu=1}^{M} \\xi^\\mu_i \\xi^\\mu_j\n", "\\end{equation}\n", "With this connection matrix, and as long as the number of patterns $M$ is small compared with $N$, each pattern $\\xi^\\mu$ is a locally stable steady state. Therefore, if we provide the network with a seed that is close to one of the patterns, the network will retrieve the full pattern on its own." ] }, { "cell_type": "markdown", "id": "former-brush", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Let us implement the Hopfield network as follows. First, we will write a base class to represent a general network with an arbitrary connection matrix, whose dynamics is given by the update rule above. A network in which the nodes feed their outputs back to one another like this is called \"recurrent\" (more precisely, a recurrent network is one whose connectivity graph contains cycles)."
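] }, { "cell_type": "markdown", "id": "hebbian-fixed-point-sketch", "metadata": {}, "source": [ "Before writing the classes, the two equations above can be checked directly with a few lines of NumPy. The cell below is a minimal standalone sketch that the rest of the notebook does not use; the helper names `hebbian` and `check_fixed_points` and the random test patterns are illustrative choices. It stores a few random $\\pm 1$ patterns with the Hebbian rule for $J_{ij}$ and checks whether each pattern is left unchanged by one synchronous update. With many more nodes than patterns (here $M/N = 0.01$) the crosstalk between patterns is small, so every pattern is expected to be a steady state; as $M/N$ grows, some patterns stop being stable." ] }, { "cell_type": "code", "execution_count": null, "id": "hebbian-fixed-point-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# illustrative helpers for this sketch (not part of the notebook's classes)\n", "def hebbian(patterns):\n", "    # J_ij = (1/N) sum_mu xi_i^mu xi_j^mu for an (M, N) array of +/-1 patterns\n", "    return patterns.T @ patterns / patterns.shape[1]\n", "\n", "def check_fixed_points(n_nodes, n_patterns, seed=0):\n", "    # store random patterns, then test whether each one is unchanged\n", "    # by one synchronous update S_i <- sgn(sum_j J_ij S_j)\n", "    rng = np.random.default_rng(seed)\n", "    xis = rng.choice([-1, 1], size=(n_patterns, n_nodes))\n", "    J = hebbian(xis)\n", "    return [bool(np.array_equal(np.sign(J @ xi), xi)) for xi in xis]\n", "\n", "stable = check_fixed_points(n_nodes=500, n_patterns=5)\n", "print(stable)"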
] }, { "cell_type": "code", "execution_count": 1, "id": "capable-frequency", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "id": "forbidden-journey", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "class RecurrentNetwork:\n", "    \"\"\"\n", "    base class of neural network with recurrent connection matrix\n", "    \"\"\"\n", "\n", "    def __init__(self, N, sync=False):\n", "        \"\"\"\n", "        declare internal variables.\n", "        inputs:\n", "        N: int, number of neurons forming the network.\n", "        sync: bool, whether to update synchronously.\n", "        \"\"\"\n", "        self.N = int(N) # number of neurons\n", "        self.state = np.zeros(self.N, dtype=int) # state of neurons, S_i\n", "        self.connect = np.zeros((self.N,self.N)) # connection matrix, J_{ij}\n", "        self.sync = sync\n", "\n", "    def set_network(self, state=None, connect=None):\n", "        \"\"\"\n", "        set state and connection of the network.\n", "        inputs:\n", "        state: 1-d array, state of neurons to be set.\n", "        connect: 2-d array, connection matrix to be set.\n", "        \"\"\"\n", "        if state is not None:\n", "            self.state = np.array(state, dtype=int) # copy, so later updates do not modify the caller's array\n", "        if connect is not None:\n", "            self.connect = np.array(connect, dtype=float)\n", "\n", "    def update1(self, i):\n", "        \"\"\"\n", "        update one neuron of given index, as well as its effect on others.\n", "        inputs:\n", "        i: int, index of neuron to be updated.\n", "        outputs:\n", "        b: bool, whether the state of given neuron is changed.\n", "        \"\"\"\n", "        s = np.sign(self.input[i]).astype(int)\n", "        if s != self.state[i]:\n", "            # shift every local field by J_{ji} times the change of S_i;\n", "            # the change is 2s for a flip, but only s if S_i was 0 (e.g. a gray seed pixel)\n", "            self.input += (s - self.state[i]) * self.connect[:,i]\n", "            self.state[i] = s\n", "            return True\n", "        else:\n", "            return False\n", "\n", "    def updateN(self, size=None, replace=True):\n", "        \"\"\"\n", "        update many neurons in random sequential order.\n", "        inputs:\n", "        size: number of neurons to update in a sweep, default is `self.N`.\n", "        replace: whether to sample neurons with replacement.\n", "        outputs:\n", "        flip_any: bool, whether the state of any neuron is changed.\n", "        \"\"\"\n", "        if size is None:\n", "            size = self.N\n", "        seq = np.random.choice(np.arange(self.N), size=size, replace=replace)\n", "        self.input = np.dot(self.connect, self.state) # local fields, sum_j J_ij S_j\n", "        flip_any = False\n", "        if self.sync: # update synchronously\n", "            new_state = np.sign(self.input[seq]).astype(int)\n", "            flip_any = bool(np.any(new_state != self.state[seq]))\n", "            self.state[seq] = new_state\n", "        else: # update asynchronously\n", "            for i in seq:\n", "                flip = self.update1(i)\n", "                flip_any = (flip_any or flip)\n", "        return flip_any" ] }, { "cell_type": "markdown", "id": "approved-enclosure", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Then we will define a derived class to implement the Hopfield model by specifying the connection matrix. We will also define a method `Hopfield.overlap()` that calculates the overlap $m^\\mu = \\frac{1}{N} \\sum_i \\xi_i^\\mu S_i$ between the state of the network and each of the given patterns."
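] }, { "cell_type": "markdown", "id": "overlap-sketch", "metadata": {}, "source": [ "As a quick illustration of what the overlap measures, here is a small standalone example; the helper names `pattern_overlap` and `overlap_examples` and the random pattern are illustrative and not part of the notebook's classes. The overlap of a state with a pattern is $+1$ when the two agree on every node, $-1$ when every sign is flipped, and close to $0$ for an unrelated random state, because the products $\\xi_i S_i$ then cancel on average. Flipping a fraction $f$ of the nodes gives $1 - 2f$." ] }, { "cell_type": "code", "execution_count": null, "id": "overlap-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# illustrative helpers (not part of the notebook's classes)\n", "def pattern_overlap(state, pattern):\n", "    # m = (1/N) sum_i xi_i S_i for a single pattern\n", "    return np.dot(state, pattern) / len(pattern)\n", "\n", "def overlap_examples(n_nodes=1000, n_flip=100, seed=1):\n", "    rng = np.random.default_rng(seed)\n", "    xi = rng.choice([-1, 1], size=n_nodes)  # a random +/-1 pattern\n", "    noisy = xi.copy()\n", "    noisy[:n_flip] *= -1  # flip the signs of the first n_flip nodes\n", "    other = rng.choice([-1, 1], size=n_nodes)  # an unrelated random state\n", "    return {\n", "        'identical': pattern_overlap(xi, xi),  # exactly 1\n", "        'inverted': pattern_overlap(-xi, xi),  # exactly -1\n", "        'noisy': pattern_overlap(noisy, xi),  # 1 - 2 * n_flip / n_nodes = 0.8\n", "        'random': pattern_overlap(other, xi),  # close to 0\n", "    }\n", "\n", "print(overlap_examples())"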
] }, { "cell_type": "code", "execution_count": 3, "id": "operating-croatia", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [], "source": [ "class Hopfield(RecurrentNetwork):\n", "    \"\"\"\n", "    Hopfield network that stores patterns as attractors.\n", "    \"\"\"\n", "\n", "    def __init__(self, patterns, **kwargs):\n", "        \"\"\"\n", "        create Hopfield network from a set of patterns to be stored.\n", "        inputs:\n", "        patterns: 2-d array, each row is a pattern.\n", "        \"\"\"\n", "        N = np.shape(patterns)[1] # size of network\n", "        RecurrentNetwork.__init__(self, N, **kwargs)\n", "        self.p = 0 # number of stored patterns\n", "        self.patterns = np.zeros((0,N)) # stored patterns, xi_i^mu\n", "        self.store(patterns)\n", "\n", "    def store(self, new_patterns):\n", "        \"\"\"\n", "        store new patterns into the connection matrix with the Hebbian rule.\n", "        inputs:\n", "        new_patterns: 2-d array, each row is a new pattern.\n", "        outputs:\n", "        connect: 2-d array, updated symmetric connection matrix.\n", "        \"\"\"\n", "        new_patterns = np.reshape(new_patterns, (-1,self.N)).astype(float)\n", "        self.p += new_patterns.shape[0]\n", "        self.patterns = np.vstack([self.patterns, new_patterns])\n", "        self.connect += np.dot(new_patterns.T, new_patterns) / self.N # J_ij += sum_mu xi_i^mu xi_j^mu / N\n", "        return self.connect\n", "\n", "    def overlap(self, mu=None):\n", "        \"\"\"\n", "        calculate overlap between network state and each pattern.\n", "        inputs:\n", "        mu: int or 1-d array, index of pattern(s), if None, calculate for all.\n", "        outputs:\n", "        m_mu: real or 1-d array, overlap for given pattern(s).\n", "        \"\"\"\n", "        if mu is None:\n", "            mu = np.arange(self.p)\n", "        m_mu = np.dot(self.state, self.patterns[mu].T) / self.N\n", "        return m_mu" ] }, { "cell_type": "markdown", "id": "brief-service", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Let us test the Hopfield network's ability to retrieve stored patterns. First, let us select some patterns for the network to store.
To better visualize the process, we will use images as our patterns. These will be black and white images, such that each pixel takes a value 0 (black) or 1 (white). Each pixel corresponds to one node in the network. Let us load some images." ] }, { "cell_type": "code", "execution_count": 4, "id": "treated-paste", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img1 = plt.imread('source/AlbertGator.png')\n", "img2 = plt.imread('source/AlbertEinstein.png')\n", "img_list = [img1, img2]\n", "M = len(img_list)\n", "\n", "fig, ax = plt.subplots(1,M, figsize=(2*M,2))\n", "for m in range(M):\n", "    ax[m].imshow(img_list[m], cmap='Greys_r') # pixel values: 0 = black, 1 = white\n", "    ax[m].axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "surprised-battery", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "We have to convert each 2D image to a 1D vector, which will be a pattern to be stored. This is done simply by flattening the array. We also have to map the pixel values 0 (black) and 1 (white) to the node values $+1$ and $-1$, respectively, i.e., $\\xi_i = 1 - 2 x_i$ for pixel value $x_i$." ] }, { "cell_type": "code", "execution_count": 5, "id": "6d8cbb52-ddeb-4df6-a022-c62adff466a9", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "def img2vec(img):\n", "    \"\"\"\n", "    convert a 2D black & white image to a 1D vector of +/-1.\n", "    input:\n", "    img: 2d-array of binary numbers (0 for black and 1 for white);\n", "    an image stored with color channels has an extra dimension, which is\n", "    redundant for a black & white image and will be removed.\n", "    output:\n", "    vec: 1d-array of +/-1 values (black -> +1, white -> -1).\n", "    \"\"\"\n", "    if img.ndim > 2: # image stored with color channels\n", "        img = img[:,:,0] # keep the first channel only\n", "    vec = np.reshape(img, -1) # flatten array\n", "    vec = 1-2*vec # convert from 0/1 to +/-1\n", "    return vec\n", "\n", "def vec2img(vec, shape):\n", "    \"\"\"\n", "    convert a 1D vector of +/-1 to a 2D black & white image.\n", "    input:\n", "    vec: 1d-array of +/-1 values (a value 0 is shown as gray, 0.5).\n", "    shape: 2-tuple, shape of the 2d-array.\n", "    output:\n", "    img: 2d-array of binary numbers (0 for black and 1 for white).\n", "    \"\"\"\n", "    vec = (1-vec)/2 # convert from +/-1 back to 0/1\n", "    img = np.reshape(vec, shape)\n", "    return img" ] }, { "cell_type": "markdown", "id":
"a090ad5f-db3c-4533-bdcd-35763c364dce", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Once the images have been converted to patterns, we collect them in a 2D array (here a list of vectors), where each row is a pattern. Then we can create the Hopfield network from these patterns." ] }, { "cell_type": "code", "execution_count": 6, "id": "essential-thomas", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "shape = img1.shape # shape of image (shape of 2d-array)\n", "N = np.prod(shape) # size of network (length of 1d-vector)\n", "\n", "patterns = []\n", "for img in img_list:\n", "    vec = img2vec(img) # convert image to vector\n", "    patterns.append(vec) # add vector as pattern\n", "\n", "net = Hopfield(patterns) # create Hopfield network" ] }, { "cell_type": "markdown", "id": "instrumental-relief", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "To retrieve a pattern, let us provide a seed, which can be a small section of the image we want to retrieve. Pixels outside the section are set to gray (0.5), which `img2vec` maps to the node value 0, so they carry no information about either pattern."
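] }, { "cell_type": "markdown", "id": "partial-seed-sketch", "metadata": {}, "source": [ "Seeding does not depend on the patterns being images. As a standalone sketch with synthetic data (the function `retrieve_from_partial_seed`, the network size, and the 20% fragment are illustrative choices, not part of the notebook), the cell below stores three random patterns with the Hebbian rule, seeds the network with the first 20% of the first stored pattern, sets all other nodes to 0, and then runs random sequential (asynchronous) updates. Nodes set to 0 contribute nothing to the local fields at first, so the fragment alone decides which pattern is approached; the final overlap with the first pattern is expected to come out close to 1." ] }, { "cell_type": "code", "execution_count": null, "id": "partial-seed-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# illustrative sketch (not part of the notebook's classes)\n", "def retrieve_from_partial_seed(n_nodes=400, n_patterns=3, known=0.2, sweeps=5, seed=2):\n", "    rng = np.random.default_rng(seed)\n", "    xis = rng.choice([-1, 1], size=(n_patterns, n_nodes))\n", "    J = xis.T @ xis / n_nodes  # Hebbian connection matrix\n", "    n_known = int(known * n_nodes)\n", "    S = np.zeros(n_nodes, dtype=int)  # unknown nodes start at 0, like gray pixels\n", "    S[:n_known] = xis[0, :n_known]  # seed: a fragment of the first pattern\n", "    for _ in range(sweeps):\n", "        for i in rng.permutation(n_nodes):  # random sequential (asynchronous) order\n", "            # the local field already includes nodes updated earlier in this sweep\n", "            S[i] = int(np.sign(J[i] @ S))\n", "    return xis @ S / n_nodes  # final overlap with each stored pattern\n", "\n", "m_final = retrieve_from_partial_seed()\n", "print(np.round(m_final, 3))"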
] }, { "cell_type": "code", "execution_count": 7, "id": "welsh-portugal", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAAB7CAYAAABUx/9/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAF40lEQVR4nO3dTUgUfxzH8c8+jLmy5CYpGwYZC2InUQ8bHcSiIDLoYllQxzDxkCfx3E0QJRK9hIqHpYgOYRIRFgpBhof1kInRBrHiA+EquK7tw0yH//8/aLs7rv5Xd9zv5wULjvPgD9+Os/MAa9E0DSSDNdcDoMPD2IIwtiCMLQhjC8LYgtiNZg4MDPC87IhpbW21pJvHPVsQxhaEsQVhbEEYWxDGFoSxBWFsQRhbEMYWhLEFYWxBGFsQxhaEsQVhbEEYWxDGFoSxBWFsQRhbEMYWhLEFYWxBGFsQxhaEsQVhbEEYWxDGFoSxBWFsQRhbEMYWhLEFYWxBGFsQxhaEsQVhbEEYWxDGFoSxBWFsQRhbEMYWhLEFYWxBGFsQxhaEsQVhbEEYWxDGFoSxBWFsQRhbEMYWxPDjGfOBzWZL+f1EIoFEIpHR+pqmQVVVw+Xsdjs0zdyfZpn3se/fvw+LJfnjKcfHx3H58uVd15+YmIDf78fDhw8NlwuFQpiamkIgENj3WA9a3sdO59KlSzv2bFVVoShK0nIXL1403GPj8TgsFgusVvMfEc0/wgPyXyCr1YpoNIrS0tKUy6mqmhS7rKwMkUgEkUgENpvtSIQGBO/ZABAIBNDY2IhYLIa1tbWM1/v16xdqamr06ZmZGRQUFBzACLNLbGy/34+WlhbMzc3t+P78/Dxqa2uxsbGRdt3i4mK8fv1an7bbj8av8WiMMotGR0fx7t07zM/P4/Pnz0nzPR4PFEXBs2fPcOrUKQDA8PAwhoaGAABnzpzB8+fP4fF4DnXc2SAi9tOnT7G5uQkA8Pl8mJqa2jG/qKgIvb29aGlpAQD09fXhxo0bKCwsBAC43W5cuHBB/9rr9errPnnyRD+mm/34bTF6pzkwMGDuE8cMPHjwAKWlpbh58yZGR0exsLCQtMzJkyextLQEu92OsbEx2Gw2NDQ04NixYym3GY/HMT4+Dk3TcO3aNbS3t0NRFLjdbv0PJFdaW1uTzzP/lfd79pcvXxCNRtHU1AS/358ydiwWw+zsLJqbm9HY2AgAWF1dTYodiUQQCASwtbWl/1u/desWurq6oCgK3r59y/PsXKquroaqqoYXUNbX11FbW4tgMIgXL15AVVUsLCwgGo3uWO7bt29ob2+H2+3e8QbtqMj72JmKx+Nwu936dENDA9bX1xGPxwEADocD9+7dw/T0dK6G+L+Z991Ejn39+hVXrlzRp9va2tDf369fU8/02rqZMHYaZWVlePPmjT7d3d0Nu92e9DL7zY/tGDtLzp07hw8fPuR6GIbExR4bG0N3d3fWtnfixAlEo1HMzMygvr4+a9s9COJiFxUV6efCV69eRSgUSnoVFxdnvL0fP35AURQUFBSkvXduFuJit7W1wWq14vHjx3A6nXC5XEkvq9WKT58+4fz587tuz+VypbxfbkbiYs/OzmJwcBCvXr0yXK6yshJDQ0O4fv36IY3s4ImLDQDT09N4//59ynk9PT3o7e2Fw+FAVVUVHj16hKampkMe4cHgRZW/dHR0YGNjQz+u19TUoLOzE1VVVfj58ydGRkZSrjc5OYlgMHiYQ90zxs5AXV0d6urqEAwGEYvFEIvF8PLlS9y+fRvAP4eGO3fu4O7du6ioqMjtYA0w9l+8Xm/a
25SnT5+Gz+fD1tYWlpeX4fP5sLS0BK/Xa/iwg1kw9jbhcBgfP37cdbnCwkJMTk4iHA6jvLx818eMzULkG7RUEokEnE5nxuF+//69p+XNgLH3YXNzM+cPKewHY+/RysoKnE5nroexL4y9B3Nzczh79uyRutO1Xd4/g7a2tpb2uOpwOFBeXg4A0DQN379/h8fjSXv5MxwOY3FxMe3POn78eM4fKxb9DJrL5TKcv7q6qn9dUlKCUChkuHxJSUk2hpUT/DcuCGMLwtiCMLYgjC0IYwvC2IIwtiCMLQhjC8LYgjC2IIwtCGMLwtiCMLYgjC0IYwvC2IIwtiCGT5dSfuGeLQhjC8LYgjC2IIwtCGML8geOrMk9HNsdegAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "section = np.zeros(shape) + 0.5 # default value 0.5 for gray\n", "section[55:80,45:70] = img1[55:80,45:70] # choose small section of image\n", "seed = img2vec(section) # transform section of image to seed pattern\n", "net.set_network(state=seed) # set network state to the seed pattern\n", "\n", "plt.figure(figsize=(2,2))\n", "plt.imshow(section, cmap='Greys_r')\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "formed-theater", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Now we are ready to run the network. Each call to `updateN` first recomputes all local fields $\\sum_j J_{ij} S_j$ with one matrix-vector product, so to keep this cost low we update 1000 randomly chosen nodes (sampled with replacement) per call instead of one. Recording the full network state after every step is unnecessary: to see how the network approaches one of the patterns, it is enough to record the overlap between the network state and each pattern." ] }, { "cell_type": "code", "execution_count": 8, "id": "therapeutic-leadership", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "T = 50 # number of steps, each of 1000 single-node updates\n", "overlap_hist = [net.overlap()] # overlaps before any update\n", "for t in range(T):\n", "    flip = net.updateN(size=1000)\n", "    overlap_hist.append(net.overlap())" ] }, { "cell_type": "markdown", "id": "497f681f-71f8-4c87-9c85-92fdae359429", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Let us visualize the final state of the network." ] }, { "cell_type": "code", "execution_count": 9, "id": "06fe2ebf-cece-4156-907d-927de47b729a", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img = vec2img(net.state, shape) # convert vector to image\n", "plt.figure(figsize=(2,2))\n", "plt.imshow(img, cmap='Greys_r')\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "italic-century", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "We have successfully retrieved the full image!" ] }, { "cell_type": "markdown", "id": "b5d82400-5422-4252-a5ef-a1f73c9556fc", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Let us plot how the overlaps have changed over time." ] }, { "cell_type": "code", "execution_count": 10, "id": "clear-archives", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "overlap_hist = np.asarray(overlap_hist) # shape (T+1, M)\n", "plt.figure()\n", "for mu in range(M):\n", "    plt.plot(overlap_hist[:,mu], label=r'$m^{%d}$' % (mu+1))\n", "plt.xlim(xmin=0)\n", "plt.xlabel(r'time $t$ / (1000 updates)')\n", "plt.ylabel(r'overlap $m^{\\mu}$')\n", "plt.legend(loc='upper left')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "worse-morocco", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "The network moves quickly towards pattern 1, i.e., it retrieves the first image. This is expected, because the seed we provided at the beginning was a part of this very image." ] }, { "cell_type": "markdown", "id": "brilliant-hanging", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "Since our patterns are images, it may be fun to watch how the image changes over time as the network runs. So let us record the retrieval process as a movie, which can be done as follows."
] }, { "cell_type": "code", "execution_count": 11, "id": "fc260a46-cbb7-40c3-8e7c-055ae89780c8", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.animation as anim\n", "\n", "plt.rcParams[\"animation.html\"] = \"jshtml\"\n", "fig, ax = plt.subplots(figsize=(2,2))\n", "ax.axis('off')\n", "\n", "net.set_network(state=seed)\n", "frame = vec2img(net.state, shape)\n", "img = ax.imshow(frame, cmap='Greys_r', interpolation='none')\n", "\n", "def animate(t):\n", "    if t == 0:\n", "        net.set_network(state=seed) # restart from the seed whenever the movie is rendered\n", "    else:\n", "        net.updateN(size=1000) # 1000 single-node updates per frame\n", "    img.set_data(vec2img(net.state, shape))\n", "    return (img,)\n", "\n", "mov = anim.FuncAnimation(fig, animate, frames=100, interval=100) # 100 frames, 100 ms apart\n", "# mov.save('figures/retrieve_img1.mp4', fps=10, extra_args=['-vcodec', 'libx264'])\n", "plt.close()" ] }, { "cell_type": "code", "execution_count": 12, "id": "8b25e6db-565b-430d-b976-ba879143f03c", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mov" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }