6.1 Episodic Memory#
In the previous chapters, learning was based on gradually adjusting weights in a neural network. However, humans have the ability to learn from a single experience without the need for repetition. One could argue that this “one-shot” learning can be achieved by a high learning rate. However, a high learning rate can lead to catastrophic forgetting, where the network forgets previously learned associations.
Consider, for example, the Rumelhart semantic network from the previous chapter. If we first train the network to associate various birds with flying, and then train it on a single penguin example with a very high learning rate, the network forgets the previous association between birds and flying. In general, we can mitigate this problem with interleaved training, where we mix examples from the bird category with the penguin example. However, this doesn’t reflect human learning, where we can learn from a single example without forgetting previous associations.
McClelland et al. (1995) proposed two complementary learning systems: a slow learning system that learns gradually from repetition (in the form of weight adjustments in the neocortex) and a fast learning system that learns from single experiences (in the form of episodic memory in the hippocampus). Here, we will explore how such an episodic memory system can be modeled in PsyNeuLink.
Installation and Setup
If the following code fails, you might have to restart the kernel/session and run it again. This is a known issue when installing PsyNeuLink in Google Colab.
%%capture
%pip install psyneulink
%pip install torch

import random

import psyneulink as pnl
import torch
from torch import nn
import matplotlib.pyplot as plt
Episodic Memory - Torch Implementation
Here, we implement a simple episodic memory module using PyTorch. The module stores a set of keys and values. When queried with a key, it returns a weighted sum of the values, where the weights are determined by the similarity between the query key and the stored key.
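In other words (a compact way to write what the code below implements), for a query key $q$, stored keys $k_1, \ldots, k_N$, and stored values $v_1, \ldots, v_N$, the retrieved value is

$$
v_{\text{out}} \;=\; \sum_{i=1}^{N} w_i \, v_i,
\qquad
w_i \;=\; \operatorname{softmax}\big(k_1 \cdot q,\; \ldots,\; k_N \cdot q\big)_i .
$$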
def safe_softmax(t, threshold=0.001, **kwargs):
    """
    Softmax function that masks out values below a threshold.
    (Extra keyword arguments such as `dim` are accepted for API compatibility,
    but the softmax is computed over the whole tensor.)
    """
    v = t
    # Apply mask: only include values whose magnitude exceeds the threshold
    v = torch.where(abs(v) > threshold, v, torch.tensor(-torch.inf, device=v.device))
    # Shift by the global max to avoid extreme values (numerical stability)
    v = v - torch.max(v)
    # Exponentiate
    masked_exp = torch.exp(v)
    # Normalize (to sum to 1)
    if not masked_exp.any():
        return masked_exp
    else:
        return masked_exp / torch.sum(masked_exp)
class EMModule(nn.Module):
    """
    The EM module is a key-value memory that stores a set of keys and values.
    When queried with a key, it returns a weighted sum of the values, where the weights
    are determined by the similarity between the query key and the stored keys.
    """

    def __init__(self) -> None:
        super().__init__()
        # Current index to store the next key-value pair to
        self.index = 0
        # Keys and values to store
        self.keys = None
        self.values = None

    def get_match_weights(self, key: torch.Tensor) -> torch.Tensor:
        """
        Get the match weights between the provided key and the stored keys, using the dot product.
        """
        return torch.einsum('b a, c a -> c b', self.keys, key)

    def forward(self, key: torch.Tensor) -> torch.Tensor:
        """
        Get the weighted sum of the stored values, using the provided key.
        """
        matched_weights = self.get_match_weights(key)
        return torch.einsum('a b, c a -> c b', self.values, safe_softmax(matched_weights, dim=-1))

    def prep(self, key, value):
        """
        Prepare the memory by appending the given keys and values.
        """
        if self.keys is None:
            self.keys = key
        else:
            self.keys = torch.cat((self.keys, key), dim=0)
        if self.values is None:
            self.values = value
        else:
            self.values = torch.cat((self.values, value), dim=0)

    def reset(self):
        """
        Reset the memory by setting the keys and values to None.
        """
        self.keys = None
        self.values = None
        self.index = 0

    def write(self, key, value):
        """
        Write a key-value pair to the memory at the current index.
        """
        if self.keys is None:
            self.keys = key
        else:
            self.keys[self.index] = key
        if self.values is None:
            self.values = value
        else:
            self.values[self.index] = value
        self.index += 1
Let’s inspect the EMModule class to understand how it works:
em = EMModule()


def plot(em_module):
    # Plot keys
    if em_module.keys is not None:
        plt.figure(figsize=(8, 4))
        plt.imshow(em_module.keys.detach().numpy(), cmap='viridis', aspect='auto')
        plt.colorbar(label="Key Value")
        plt.title("Keys Heatmap")
        plt.xlabel("Features")
        plt.ylabel("Memory Slots")
        plt.show()
    else:
        print("No keys found in the memory. Please add keys before plotting.")

    # Plot values
    if em_module.values is not None:
        plt.figure(figsize=(8, 4))
        plt.imshow(em_module.values.detach().numpy(), cmap='viridis', aspect='auto')
        plt.colorbar(label="Value")
        plt.title("Values Heatmap")
        plt.xlabel("Features")
        plt.ylabel("Memory Slots")
        plt.show()
    else:
        print("No values found in the memory. Please add values before plotting.")


plot(em)
Let’s start by adding some empty key and value pairs to the memory. Here, we assume both keys and values are one-hot encoded (categorical) vectors. We use 5-dimensional vectors and initialize the memory with 10 key-value pairs filled with 0.01:
memory_capacity = 10

em.reset()

for _ in range(memory_capacity):
    em.prep(torch.tensor([[.01] * 5], dtype=torch.float),
            torch.tensor([[.01] * 5], dtype=torch.float))

plot(em)
All the keys and values are initialized with 0.01. Let’s now write a key-value pair to the memory:
em.write(torch.tensor([1, 0, 0, 0, 0], dtype=torch.float), torch.tensor([0, 1, 0, 0, 0], dtype=torch.float))
plot(em)
We can see that, for both the keys and the values, the first memory slot has been updated. Let’s see what happens when we query the memory:
res = em.forward(torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float))
print(res)
Note that while the returned value is highest at the second entry, it is not the exact value we stored. Let’s explore what is going on. The memory first matches the query key with the stored keys. But how exactly does it do that? Let’s inspect the match weights:
memorized_keys = em.keys
query_key = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float)
match_weights = torch.einsum('b a, c a -> c b', memorized_keys, query_key)
print(match_weights)
torch.einsum('b a, c a -> c b', memorized_keys, query_key) calculates the dot product between the query key and each stored key. The result is a 1x10 vector that indicates the “overlap” between the query key and each stored key. Since the key in the first memory slot is identical to the query key (both [1, 0, 0, 0, 0]), its dot product is 1.0. The other entries have a dot product of 0.01, since those slots still contain the fill value the memory was initialized with.
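If the einsum notation is unfamiliar: for this shape of inputs, the same match weights can be obtained with an ordinary matrix product (a small equivalence check using the variables defined above):

match_weights_alt = query_key @ memorized_keys.T
print(torch.allclose(match_weights, match_weights_alt))  # True: same 1x10 result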
In the next step, we calculate the softmax of the match weights. This normalizes the match weights to sum to 1.0.
match_weights_sm = safe_softmax(match_weights, dim=-1)
print(match_weights_sm)
Here, we see that the softmax “flattens” the vector while keeping the highest value at the first entry.
To retrieve the memory, we use another einsum, multiplying the stored memory values (em.values) with the softmaxed match weights. This results in a weighted sum of the stored values, where each weight reflects the similarity between the query key and the key stored in the corresponding memory slot:
retrieved_value = torch.einsum('a b, c a -> c b', em.values, match_weights_sm)
print(retrieved_value)
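As before, this einsum is equivalent to a plain matrix product, here of the match weights with the stored values (an equivalence check using the variables above):

retrieved_value_alt = match_weights_sm @ em.values
print(torch.allclose(retrieved_value, retrieved_value_alt))  # True: same 1x5 result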
The above implementation of the softmax masks values below a certain threshold. Why is this necessary? What would happen if we didn’t mask values?
💡 Hint
In our toy example we have only a few memory slots. However, in most scenarios, we want to model a large number of episodic memory slots. In this case, the match_weights vector will have a large number of entries that are zero (or near zero). Think about why that is problematic.
✅ Solution 1
The “flattening” effect of the softmax function depends on the length of the vector: the more near-zero entries there are, the less probability mass ends up on the matching entry. For example, try running the following code:
res = safe_softmax(torch.tensor([1] + [0.01] * 10), threshold=0.001 )
res_2 = safe_softmax(torch.tensor([1] + [0.01] * 100), threshold=0.001)
print(res[0])
print(res_2[0])
res_safe_1 = safe_softmax(torch.tensor([1] + [0.01] * 10), threshold=0.01)
res_safe_2 = safe_softmax(torch.tensor([1] + [0.01] * 100), threshold=0.01)
print(res_safe_1[0])
print(res_safe_2[0])
Try playing around with the threshold value in the safe_softmax function. Can you find a threshold value that results in memory retrieval that is closer to the stored value?
Try adding more key-value pairs to the memory. How does the memory retrieval change with more key-value pairs? (A starting point is sketched below.)
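As a starting point for the second exercise, here is a small, hypothetical sketch (the specific keys and values are made up for illustration) that writes two more pairs and queries the memory again:

# Write two more (made-up) one-hot key-value pairs to the next free slots
em.write(torch.tensor([0, 1, 0, 0, 0], dtype=torch.float),
         torch.tensor([0, 0, 1, 0, 0], dtype=torch.float))
em.write(torch.tensor([0, 0, 1, 0, 0], dtype=torch.float),
         torch.tensor([0, 0, 0, 1, 0], dtype=torch.float))
plot(em)

# Query with the first key again and compare to the stored value [0, 1, 0, 0, 0]
print(em.forward(torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.float)))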
PsyNeuLink - EMComposition#
PsyNeuLink provides built-in support for episodic memory through the EMComposition. Here, we explain its most important parameters. The EMComposition can be used to easily build more complex memory structures, and we will use it in the next tutorial in the EGO model.
em = pnl.EMComposition(
    name='EM',                 # name
    memory_capacity=1000,      # number of memory entries
    memory_template=[[0, 0], [0, 0, 0, 0], [0, 0, 0]],
    # Template for a single memory entry. Note: here each entry has 3 fields
    # (instead of just a key-value pair, we can store as many fields as we want).
    fields={
        '1': {pnl.FIELD_WEIGHT: .33,          # weight of the field; determines how much this field influences retrieval
              pnl.LEARN_FIELD_WEIGHT: False,  # whether the weight can be learned via backpropagation
              pnl.TARGET_FIELD: False},       # if this is a target field, the error is calculated here and backpropagated
        '2': {pnl.FIELD_WEIGHT: .33,
              pnl.LEARN_FIELD_WEIGHT: False,
              pnl.TARGET_FIELD: False},
        '3': {pnl.FIELD_WEIGHT: .33,
              pnl.LEARN_FIELD_WEIGHT: False,
              pnl.TARGET_FIELD: False},
    },
    memory_fill=.001,          # fill the memory with this value
    normalize_memories=True,   # normalize the memories
    softmax_gain=1.,           # gain of the softmax function
    softmax_threshold=0.1,     # threshold of the softmax function
    memory_decay=0,            # memory can be decayed over time
)
em.show_graph(output_fmt='jupyter')
The above figure seems complicated at first, but it follows the same principle as the torch implementation. Reading it from bottom to top:
- The arrows from the 1, 2, and 3 query nodes to the “STORE” node represent that these values are stored in memory.
- All of them are also passed through a “MATCH” node, which calculates the similarity between the query and the stored values (just as described above for the keys).
- The “MATCH” results are then weighted and combined (here they are also softmaxed).
- The result is used to retrieve the memory by multiplying the combined match weights with the stored values.
In our implementation, we specified input node 1 as having 2 entries (a 1x2 vector), input node 2 as having 4 entries (a 1x4 vector), and input node 3 as having 3 entries (a 1x3 vector). Yet, in the explanation above, I talked about adding weighted vectors together. How can that be?
💡 Hint
We are not adding the query vectors together but the matched weights. What is the shape of these weights?
✅ Solution 4
The matched weights have the shape of the number of memory slots. Their entries don’t represent the query vectors themselves; the i-th entry signifies how similar the memory in slot i is to the query vector.
This is why a weighted sum makes sense: we are literally weighting how similar fields 1, 2, and 3 are, and then combining them. This way the retrieval searches for the highest combined (weighted) similarity.
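To make the weighted combination concrete, here is a small, hypothetical numeric sketch (the match vectors are made up, and this is not the exact internal computation of the EMComposition): three per-field match-weight vectors over, say, 4 memory slots are combined with field weights of 0.33 each before retrieval.

import torch

# Made-up softmaxed match weights for fields 1, 2, and 3 over 4 memory slots
match_1 = torch.tensor([0.70, 0.10, 0.10, 0.10])
match_2 = torch.tensor([0.25, 0.25, 0.25, 0.25])
match_3 = torch.tensor([0.10, 0.70, 0.10, 0.10])

# Weighted combination with a field weight of 0.33 each (as in the example above)
combined = 0.33 * (match_1 + match_2 + match_3)
print(combined)  # one weight per memory slot, used to mix the stored values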
Marking a Field as Value (non-query)#
If we don’t want specific fields to be taken into account during retrieval (for example, if they are the “target” fields that the model is supposed to predict), we can set their field weight (pnl.FIELD_WEIGHT) to None:
em.memory
em = pnl.EMComposition(
    name='EM with Target',     # name
    memory_capacity=1000,      # number of memory entries
    memory_template=[[0, 0], [0, 0, 0, 0], [0, 0, 0]],
    # Template for a single memory entry. Note: here each entry has 3 fields
    # (instead of just a key-value pair, we can store as many fields as we want).
    fields={
        '1': {pnl.FIELD_WEIGHT: .5,           # weight of the field; determines how much this field influences retrieval
              pnl.LEARN_FIELD_WEIGHT: False,  # whether the weight can be learned via backpropagation
              pnl.TARGET_FIELD: False},       # if this is a target field, the error is calculated here and backpropagated
        '2': {pnl.FIELD_WEIGHT: .5,
              pnl.LEARN_FIELD_WEIGHT: False,
              pnl.TARGET_FIELD: False},
        '3': {pnl.FIELD_WEIGHT: None,
              pnl.LEARN_FIELD_WEIGHT: False,
              pnl.TARGET_FIELD: True},
    },
    memory_fill=.001,          # fill the memory with this value
    normalize_memories=True,   # normalize the memories
    softmax_gain=1.,           # gain of the softmax function
    softmax_threshold=0.1,     # threshold of the softmax function
    memory_decay=0,            # memory can be decayed over time
)
em.show_graph(output_fmt='jupyter')
As you can see, this way field 3 is stored (and retrieved), but it is not taken into account when calculating the match; it is not used as a query to retrieve from memory.