6.1 Episodic Memory

In the previous chapters, learning was based on gradually adjusting the weights of a neural network. Humans, however, can learn from a single experience, without repetition. One could argue that such “one-shot” learning could be achieved simply by using a high learning rate. However, a high learning rate can lead to catastrophic forgetting, where the network loses previously learned associations.

Consider, for example, the Rumelhart Semantic Network from the previous chapter. If we first train the network to associate various birds with flying, and then train it on a single penguin example with a very high learning rate, the network will forget the previous association between birds and flying. In general, we can mitigate this problem with interleaved training, where we mix examples from the bird category in with the penguin example. However, this does not reflect human learning: we can learn from a single example without forgetting previous associations.

McClelland et al. (1995) proposed two complementary learning systems: a slow system that learns gradually through repetition (in the form of weight adjustments in the neocortex) and a fast system that learns from single experiences (in the form of episodic memory in the hippocampus). Here, we will explore how such an episodic memory system can be modeled in PsyNeuLink.

Installation and Setup

If the following code fails, you might have to restart the kernel/session and run it again. This is a known issue when installing PsyNeuLink in Google Colab.

%%capture
%pip install psyneulink
%pip install torch

import random

import psyneulink as pnl
import torch
from torch import nn
import matplotlib.pyplot as plt

Episodic Memory - Torch Implementation

Here, we implement a simple episodic memory module using PyTorch. The module stores a set of keys and values. When queried with a key, it returns a weighted sum of the values, where the weights are determined by the similarity between the query key and the stored keys.
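Conceptually, retrieval is a similarity-weighted readout: compute the dot product between the query key and every stored key, turn those similarities into weights with a softmax, and sum the stored values with those weights. The following toy calculation sketches this idea in plain PyTorch (the tensors are made-up illustration values, not part of the module below):

import torch

# Three stored key-value pairs (illustrative values only)
keys = torch.tensor([[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.]])
values = torch.tensor([[10., 0.],
                       [0., 10.],
                       [5., 5.]])

query = torch.tensor([1., 0., 0.])            # matches the first stored key exactly
similarities = keys @ query                   # dot product with each stored key
weights = torch.softmax(similarities, dim=0)  # normalize similarities to sum to 1
retrieved = weights @ values                  # similarity-weighted sum of the values
print(weights)
print(retrieved)

The module below implements the same idea, with the added twist that small similarities are masked out before the softmax.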

def safe_softmax(t, threshold=0.001, **kwargs):
    """
    Softmax function that masks out values whose magnitude is below a threshold.
    Extra keyword arguments (e.g. dim) are accepted for interface compatibility
    but ignored; the softmax is computed over the whole tensor.
    """

    v = t

    # Apply mask: only keep values whose magnitude exceeds the threshold
    v = torch.where(abs(v) > threshold, v, torch.tensor(-torch.inf, device=v.device))

    # Shift by the global max to avoid extreme values (numerical stability)
    v = v - torch.max(v)

    # Exponentiate (masked entries become exp(-inf) = 0)
    masked_exp = torch.exp(v)

    # Normalize to sum to 1; if everything was masked, return the all-zero tensor
    if not masked_exp.any():
        return masked_exp
    else:
        return masked_exp / torch.sum(masked_exp)


class EMModule(nn.Module):
    """
    The EM module is a key-value memory that stores a set of keys and values.
    When queried with a key, it returns a weighted sum of the values, where the weights
    are determined by the similarity between the query key and the stored keys.
    """

    def __init__(self) -> None:
        super().__init__()
        # Current index to store the next key-value pair to
        self.index = 0

        # Keys and values to store
        self.keys = None
        self.values = None

    def get_match_weights(self, key: torch.Tensor) -> torch.Tensor:
        """
        Get the match weights between the provided key and the stored keys, using the dot product.
        """
        return torch.einsum('b a, c a -> c b', self.keys, key)

    def forward(self, key: torch.Tensor) -> torch.Tensor:
        """
        Get the weighted sum of the stored values, using the provided key.
        """
        matched_weights = self.get_match_weights(key)
        return torch.einsum('a b, c a -> c b', self.values, safe_softmax(matched_weights, dim=-1))

    def prep(self, key, value):
        """
        Prepare the memory by appending a key-value pair to the stored keys and values
        (used to pre-fill the memory slots).
        """
        if self.keys is None:
            self.keys = key
        else:
            self.keys = torch.cat((self.keys, key), dim=0)
        if self.values is None:
            self.values = value
        else:
            self.values = torch.cat((self.values, value), dim=0)

    def reset(self):
        """
        Reset the memory by setting the keys and values to None.
        """
        self.keys = None
        self.values = None
        self.index = 0

    def write(self, key, value):
        """
        Write a key-value pair to the memory slot at the current index
        (overwriting whatever is stored there) and advance the index.
        """
        if self.keys is None:
            self.keys = key
        else:
            self.keys[self.index] = key
        if self.values is None:
            self.values = value
        else:
            self.values[self.index] = value
        self.index += 1
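As a quick sanity check of the masking behavior (the input values here are chosen purely for illustration), safe_softmax drops entries whose magnitude does not exceed the threshold and normalizes the rest:

# With the default threshold of 0.001, the third entry is masked out and contributes 0
print(safe_softmax(torch.tensor([1.0, 0.5, 0.0001])))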

Let’s inspect the EMModule class to understand how it works:

em = EMModule()

def plot(em_module):
    # Plot keys
    if em_module.keys is not None:
        plt.figure(figsize=(8, 4))
        plt.imshow(em_module.keys.detach().numpy(), cmap='viridis', aspect='auto')
        plt.colorbar(label="Key Value")
        plt.title("Keys Heatmap")
        plt.xlabel("Features")
        plt.ylabel("Memory Slots")
        plt.show()
    else:
        print(
            "No keys found in the memory. Please add keys before plotting.")

    # Plot values
    if em_module.values is not None:
        plt.figure(figsize=(8, 4))
        plt.imshow(em_module.values.detach().numpy(), cmap='viridis', aspect='auto')
        plt.colorbar(label="Value")
        plt.title("Values Heatmap")
        plt.xlabel("Features")
        plt.ylabel("Memory Slots")
        plt.show()
    else:
        print(
            "No values found in the memory. Please add values before plotting.")


plot(em)

Let’s start by adding some empty key-value pairs to the memory. We assume both keys and values are one-hot encoded (categorical) vectors. Here, we use 5-dimensional vectors and initialize the memory with 10 key-value pairs filled with 0.01:

memory_capacity = 10

em.reset()

for _ in range(memory_capacity):
    em.prep(torch.tensor([[.01] * 5], dtype=torch.float),
            torch.tensor([[.01] * 5], dtype=torch.float))

plot(em)

All the keys and values are initialized with 0.01. Let’s now write a key-value pair to the memory:

em.write(torch.tensor([1, 0, 0, 0, 0], dtype=torch.float), torch.tensor([0, 1, 0, 0, 0], dtype=torch.float))
plot(em)

We can see that for both the keys and the values, the first memory slot has been updated. Let’s see what happens when we query the memory:

res = em.forward(torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float))
print(res)

Note that while the returned vector has its highest value at the second entry, it is not exactly the value we stored. Let’s explore what is going on. The memory first matches the query key against the stored keys. But how exactly does it do that? Let’s inspect the match weights:

memorized_keys = em.keys
query_key = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float)

match_weights = torch.einsum('b a, c a -> c b', memorized_keys, query_key)
print(match_weights)

torch.einsum('b a, c a -> c b', memorized_keys, query_key) calculates the dot product between the query key and each stored key. The result is a 1x10 vector that indicates the “overlap” between the query key and each stored key. Since the first stored key is the same as the query key (both [1, 0, 0, 0, 0]), the dot product is 1.0. The other entries have a dot product of 0.01, since the memory was initialized with this value.
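If the einsum notation feels opaque, note that this particular pattern is equivalent to an ordinary matrix product of the query with the transposed key matrix. A quick check, reusing the tensors from the cell above:

# einsum('b a, c a -> c b', keys, query) is the same as query @ keys.T
print(torch.allclose(match_weights, query_key @ memorized_keys.T))  # expected: True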

In the next step, we calculate the softmax of the match weights. This normalizes the match weights so that they sum to 1.0.

match_weights_sm = safe_softmax(match_weights, dim=-1)
print(match_weights_sm)

Here, we see that the softmax “flattens” the vector while keeping the highest value at the first entry.
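As a quick check (reusing match_weights_sm from above), the weights sum to 1, and nothing was masked out here, because the 0.01 fillers still exceed the default threshold of 0.001:

print(match_weights_sm.sum())         # expected: ~1.0
print((match_weights_sm == 0).sum())  # expected: 0 (no entries were masked)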

To retrieve from the memory, we use another einsum, multiplying the stored memory values (em.values) with the match weights. This results in a weighted sum of the stored values, where the weight for each memory slot is determined by the similarity between the query key and that slot’s stored key:

retrieved_value = torch.einsum('a b, c a -> c b', em.values, match_weights_sm)
print(retrieved_value)
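To make the “weighted sum” reading explicit, the same result can be computed as an ordinary matrix product, or as an explicit loop over the memory slots. A small sketch, reusing the tensors from above:

# Equivalent matrix product: weights (1 x 10) times values (10 x 5)
print(torch.allclose(retrieved_value, match_weights_sm @ em.values))  # expected: True

# Equivalent explicit weighted sum over the 10 memory slots
manual = sum(match_weights_sm[0, i] * em.values[i] for i in range(em.values.shape[0]))
print(torch.allclose(retrieved_value[0], manual))  # expected: True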

🎯 Exercise 1

The above implementation of the softmax masks values below a certain threshold. Why is this necessary? What would happen if we didn’t mask values?

💡 Hint

In our toy example we have only a few memory slots. However, in most scenarios, we want to model a large number of episodic memory slots. In this case, the match_weights vector will contain many entries that are zero (or near zero). Think about why that is problematic.

✅ Solution 1

The “flattening” effect of the softmax function depends on the length of the vector. For example, try running the following code:

# Without masking (threshold below the 0.01 fillers), the fillers survive the softmax and
# dilute the weight on the matching first entry; the more entries, the stronger the dilution.
res = safe_softmax(torch.tensor([1] + [0.01] * 10), threshold=0.001)
res_2 = safe_softmax(torch.tensor([1] + [0.01] * 100), threshold=0.001)
print(res[0])
print(res_2[0])

# With a threshold that masks the 0.01 fillers, the weight stays concentrated on the
# matching entry, no matter how many memory slots there are.
res_safe_1 = safe_softmax(torch.tensor([1] + [0.01] * 10), threshold=0.01)
res_safe_2 = safe_softmax(torch.tensor([1] + [0.01] * 100), threshold=0.01)
print(res_safe_1[0])
print(res_safe_2[0])

🎯 Exercise 2

Try playing around with the threshold value in the safe_softmax function. Can you find a threshold value that results in memory retrieval that is closer to the stored value?
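One possible starting point (a sketch, not the intended solution; the threshold values below are arbitrary choices): recompute the retrieval for the same query at different thresholds and compare the results with the stored value [0, 1, 0, 0, 0]:

query = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float)
for th in [0.001, 0.005, 0.02, 0.1]:
    weights = safe_softmax(em.get_match_weights(query), threshold=th)
    retrieved = torch.einsum('a b, c a -> c b', em.values, weights)
    print(f"threshold={th}: {retrieved}")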

🎯 Exercise 3

Try adding more key-value pairs to the memory. How does the memory retrieval change with more key-value pairs?
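As a starting point (a sketch; the specific key-value pairs below are arbitrary one-hot choices), write a few more pairs and then query the memory with each stored key:

em.write(torch.tensor([0, 1, 0, 0, 0], dtype=torch.float), torch.tensor([0, 0, 1, 0, 0], dtype=torch.float))
em.write(torch.tensor([0, 0, 1, 0, 0], dtype=torch.float), torch.tensor([0, 0, 0, 1, 0], dtype=torch.float))
plot(em)

# Query with each of the stored keys and inspect how "clean" each retrieval is
for q in ([1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]):
    print(q, em.forward(torch.tensor([q], dtype=torch.float)))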