6.2 Episodic Generalization and Optimization - EGO

Introduction

Human cognition is unique in its ability to perform a wide range of tasks and to learn new tasks quickly. Both abilities have long been associated with two things: acquiring knowledge that generalizes across tasks, and using that knowledge flexibly to execute goal-directed behavior. In this tutorial, we show how these abilities can emerge in a neural network by implementing the Episodic Generalization and Optimization (EGO) framework. The framework consists of three parts:

  • an episodic memory module, which rapidly learns relationships between stimuli;

  • a semantic pathway, which more slowly learns how stimuli map to responses;

  • a recurrent context module, which maintains a representation of task-relevant context information, integrates it over time, and uses it to recall context-relevant memories.

[Figure: The EGO framework]

The EGO framework consists of a control mechanism (the context module, upper middle) and an episodic memory mechanism (bottom left). At each time point (rows), episodic memory records conjunctions of stimuli (blue boxes), contexts (pink boxes), and observed responses (green boxes). Bidirectional arrows connect episodic memory to the stimulus, the context, and the output. These values can be stored in episodic memory or used to query it, and a value can be retrieved from memory when another field is queried. You can think of this as a more flexible dictionary. It stores triplets instead of key-value pairs, and any field (or any combination of fields) can act as the key. The context module integrates the previous context (recurrent connection) with information about the stimulus and the context retrieved from memory.
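The following plain-Python sketch makes the dictionary analogy concrete. It is not PsyNeuLink code, and the class name `TripletMemory` and its methods are hypothetical. The store holds (stimulus, context, response) triplets, and any subset of fields can serve as the key. The EMComposition used later retrieves by graded similarity (a softmax over matches) rather than by the exact matching shown here.

```python
# A toy "triplet dictionary": every entry has three fields, and any subset of
# fields can act as the key for a lookup. This illustrates the analogy only;
# it is not how PsyNeuLink stores or retrieves memories.
FIELDS = ("stimulus", "context", "response")


class TripletMemory:
    def __init__(self) -> None:
        self.entries: list[dict[str, str]] = []

    def store(self, stimulus: str, context: str, response: str) -> None:
        self.entries.append(
            {"stimulus": stimulus, "context": context, "response": response}
        )

    def query(self, **cues: str) -> list[dict[str, str]]:
        """Return all entries whose fields match every supplied cue."""
        unknown = set(cues) - set(FIELDS)
        if unknown:
            raise KeyError(f"unknown fields: {sorted(unknown)}")
        return [e for e in self.entries
                if all(e[field] == value for field, value in cues.items())]


memory = TripletMemory()
memory.store("order", "shop A", "pay")
memory.store("order", "shop B", "drink coffee")
memory.store("sit down", "shop B", "order")

# Stimulus alone is ambiguous: two entries match.
print(len(memory.query(stimulus="order")))
# Stimulus plus context picks out a single entry.
print(memory.query(stimulus="order", context="shop B")[0]["response"])
# The response field can also serve as the key.
print(memory.query(response="order"))
```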

Here we show that the EGO framework can emulate human behavior in a specific learning environment. In this environment, participants are trained on two sets of sequences that contain identical states presented in different orders for different contexts. Empirical findings show that participants perform better when trained in blocks of each context than when the contexts are interleaved.

Task: Coffee Shop World (CSW)

[Figure: The Suspicious Barista and Café Gratitude coffee shops]

Imagine you are in a city with two coffee shops, each with a different layout and a different way of ordering. In one coffee shop, called The Suspicious Barista, you order first, pay for the coffee, and then sit down and wait until the waiter brings your order. In the other coffee shop, called Café Gratitude, you sit down first and wait until the waiter comes to take your order, and you pay after finishing your coffee.

This example shows that many situations share similar stimuli but have different transition structures. Simple temporal integration helps the system learn the transition structure. However, because the situations are so similar, it provides only a weak cue about which one the system is in. In other words, the states (ordering, paying, and sitting down) are very similar across the two situations and are therefore hard to tell apart. Differentiating the context representations associated with each setting overcomes this problem (e.g., learning different context representations for coffee shops with paranoid vs. gullible baristas). Recent empirical work suggests that people can learn to do this very effectively. This ability depends on the temporal structure of the environment: people do better when trained in blocks of each situation than when the situations are interleaved (Beukers et al., 2023).

Before building the model, we set up PsyNeuLink and create a dataset for the CSW task.

Installation and Setup

If the following cell fails to execute, restart the kernel (or session) and run the cell again. This is a known issue when running in Google Colab.

%%capture
%pip install psyneulink

import psyneulink as pnl
import random

Generating data for the CSW task

We start by generating a dataset for the CSW task. The dataset consists of sequences of states, and the task is to predict the next state given the current state and the context. The context determines the transitions between states, and the first state in the sequence determines the context. The following figure illustrates the task structure:

[Figure: CSW task structure (left) and learning paradigms (right)]

On the left side of the figure, you can see the task structure:

The two colors represent different contexts: blue and orange.

  • If the first observed state in a sequence is 0, the participant is in the blue context.

    • The next state can be either 1 or 2.

    • From then on, transitions are deterministic:

      • 1 → 3 → 5 → 7

      • 2 → 4 → 6 → 8

  • If the first observed state is 9, the participant is in the orange context.

    • The next state is again either 1 or 2, but the transitions follow a different pattern:

      • 1 → 4 → 5 → 8

      • 2 → 3 → 6 → 7
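The list above can be written as two lookup tables. The short sketch below is plain Python and independent of PsyNeuLink (the names `BLUE` and `ORANGE` are ours). It makes the key property of the task explicit: every state that appears in both contexts has a different successor in each. The current state alone therefore never determines the next state.

```python
# Transition tables for the two CSW contexts, copied from the list above.
# The context cue (0 or 9) appears only at the start of a sequence.
BLUE = {1: 3, 3: 5, 5: 7, 2: 4, 4: 6, 6: 8}    # context 0
ORANGE = {1: 4, 4: 5, 5: 8, 2: 3, 3: 6, 6: 7}  # context 9

# States that have a successor in both contexts
shared = sorted(set(BLUE) & set(ORANGE))
# Shared states whose successor depends on the context
ambiguous = [s for s in shared if BLUE[s] != ORANGE[s]]

print(shared)
print(ambiguous)

# A predictor that sees only the current state must guess between two
# successors for each of these states. A predictor that also knows the
# context can be correct every time.
for s in shared:
    print(s, "->", {"blue": BLUE[s], "orange": ORANGE[s]})
```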

The right side of the figure shows the different learning paradigms:

In the blocked paradigm, participants are trained on blocks of sequences from the same context. In the interleaved paradigm, sequences from the two contexts are mixed during training. In the test phase, participants are tested on sequences from randomly chosen contexts.

We start by defining a function that generates a context-specific sequence:

def gen_context(
    context: int,
    start_state: int,
):
    """
    Generate a context-specific sequence.
    Args:
        context (int): The context to generate the sequence for. (0 or 9)
        start_state (int): The first state in the sequence. (1 or 2)
    """
    seq = [context, start_state]
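    if context == 0:
        # Blue context: each state advances by 2 (1 -> 3 -> 5 -> 7, 2 -> 4 -> 6 -> 8)
        for _ in range(3):
            seq.append(seq[-1] + 2)
    elif context == 9:
        # Orange context: even states advance by 1, odd states by 3
        # (1 -> 4 -> 5 -> 8, 2 -> 3 -> 6 -> 7)
        for _ in range(3):
            seq.append(seq[-1] + 1 if seq[-1] % 2 == 0 else seq[-1] + 3)
    return seq

# Test the function
assert gen_context(0, 1) == [0, 1, 3, 5, 7]
assert gen_context(9, 2) == [9, 2, 3, 6, 7]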

Next, we generate a full dataset for the CSW task. The following function returns the complete list of context sequences (training followed by test) for a given paradigm and number of sequences.

# Define the paradigms
BLOCKED = 'blocked'
INTERLEAVED = 'interleaved'


def gen_context_sequences(
        paradigm: str,
        train_contexts: int,
        test_contexts: int,
        block_size: int = 4,
):
    """
    Generate a list of context sequences for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for (BLOCKED or INTERLEAVED).
        train_contexts (int): The number of training sequences.
        test_contexts (int): The number of test sequences.
        block_size (int): In the blocked paradigm, the number of blocks (despite the name).
            Each block contains train_contexts // block_size sequences from the same context.
    Returns:
        list[list[int]]: The training sequences followed by the test sequences.
    """
    assert train_contexts % block_size == 0, "The number of training contexts must be a multiple of block_size."
    x = []
    if paradigm == INTERLEAVED:
        for idx in range(train_contexts):
            if idx % 2: # odd contexts -> context 0
                x += [gen_context(0, random.randint(1, 2))]
            else: # even contexts -> context 9
                x += [gen_context(9, random.randint(1, 2))]

    if paradigm == BLOCKED:
        for i in range(block_size): # block_size number of blocks
            if i % 2: # odd blocks -> context 0
                for _ in range(train_contexts // block_size):
                    x += [gen_context(0, random.randint(1, 2))]
            else: # even blocks -> context 9
                for _ in range(train_contexts // block_size):
                    x += [gen_context(9, random.randint(1, 2))]

    for _ in range(test_contexts):
        x += [gen_context(random.choice([0, 9]), random.randint(1, 2))]
    return x


context_sequences = gen_context_sequences(BLOCKED, 8, 4)
context_sequences
[[9, 2, 3, 6, 7],
 [9, 1, 4, 5, 8],
 [0, 2, 4, 6, 8],
 [0, 2, 4, 6, 8],
 [9, 2, 3, 6, 7],
 [9, 2, 3, 6, 7],
 [0, 1, 3, 5, 7],
 [0, 2, 4, 6, 8],
 [9, 1, 4, 5, 8],
 [9, 2, 3, 6, 7],
 [0, 1, 3, 5, 7],
 [0, 2, 4, 6, 8]]
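The training part of the blocked output above is made of runs of the same context (9, 9, 0, 0, ...). To make the difference between the paradigms concrete, the sketch below re-derives only the order of context labels that `gen_context_sequences` produces for 8 training sequences and 4 blocks. It then counts how often the context switches. The helper names `context_order` and `count_switches` are hypothetical and serve only as an illustration.

```python
def context_order(paradigm: str, train_contexts: int = 8, num_blocks: int = 4) -> list:
    """Context labels (0 or 9) of the training sequences, in the order that
    gen_context_sequences produces them."""
    if paradigm == "interleaved":
        # odd indices -> context 0, even indices -> context 9
        return [0 if idx % 2 else 9 for idx in range(train_contexts)]
    per_block = train_contexts // num_blocks
    # odd blocks -> context 0, even blocks -> context 9
    return [0 if (idx // per_block) % 2 else 9 for idx in range(train_contexts)]


def count_switches(labels: list) -> int:
    """Number of times the context changes between consecutive sequences."""
    return sum(a != b for a, b in zip(labels, labels[1:]))


for paradigm in ("blocked", "interleaved"):
    labels = context_order(paradigm)
    print(paradigm, labels, count_switches(labels))
```

With these settings, the blocked curriculum switches context far less often than the interleaved one. This is the temporal structure that the empirical findings depend on.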

The generated sequences are not yet “realistic”. A participant doesn’t see distinct contexts but a continuous stream of states, so we need to “flatten” the sequences. Also, instead of using integers to represent the states, we use one-hot encoding:

def one_hot_encode(
        label: int,
        num_classes: int):
    """
    One hot encode a label (integer)
    Args:
        label (int): The label to encode (between 0 and num_classes-1)
        num_classes (int): The number of classes
    """
    return [1 if i == label else 0 for i in range(num_classes)]


def state_sequence(
        paradigm: str,
        train_trials: int,
        test_trials: int,
        context_length: int = 5,
        block_size: int = 4,
):
    """
    Generate a flattened, one-hot encoded state sequence for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for (BLOCKED or INTERLEAVED).
        train_trials (int): The number of training trials (states).
        test_trials (int): The number of test trials (states).
        context_length (int): The number of states per context sequence (context cue + 4 states).
        block_size (int): Passed to gen_context_sequences (number of blocks in the blocked paradigm).
    Returns:
        list[list[int]]: One one-hot vector of length 11 per trial.
    """
    assert train_trials % context_length == 0, "The number of training trials must be a multiple of context_length."
    assert test_trials % context_length == 0, "The number of test trials must be a multiple of context_length."

    train_contexts = train_trials // context_length
    test_contexts = test_trials // context_length

    # Training sequences followed by test sequences
    sequences = gen_context_sequences(
        paradigm, train_contexts, test_contexts, block_size
    )

    # Flatten the sequences and one-hot encode each state (states 0-10 -> 11 classes)
    states = []
    for context_sequence in sequences:
        for state_int in context_sequence:
            states.append(one_hot_encode(state_int, 11))
    return states


state_sequences = state_sequence(BLOCKED, 20, 5)
state_sequences
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]

🎯 Exercise 1

Why do we encode the states using one-hot encoding?

✅ Solution 1

The states are categorical variables. They have no inherent order, and arithmetic on their integer labels is meaningless. For example, state 2 is not “closer” to state 1 than state 8 is. One-hot encoding represents each state as a vector that is orthogonal to all others, so every state is equally dissimilar to every other state.
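A small NumPy check illustrates the point. Integer labels impose a spurious distance between states, while one-hot vectors put every pair of distinct states at the same distance. The helper `one_hot` here is a NumPy variant of `one_hot_encode`, added only for this illustration.

```python
import numpy as np


def one_hot(label: int, num_classes: int = 11) -> np.ndarray:
    v = np.zeros(num_classes)
    v[label] = 1.0
    return v


states = [1, 2, 8]

# Integer labels: state 1 looks much "closer" to state 2 than to state 8.
print([abs(states[0] - s) for s in states[1:]])

# One-hot vectors: distinct states are orthogonal (dot product 0) and all
# pairs are the same Euclidean distance apart (sqrt(2)).
vecs = [one_hot(s) for s in states]
print([float(vecs[0] @ v) for v in vecs[1:]])
print([float(np.linalg.norm(vecs[0] - v)) for v in vecs[1:]])
```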

🎯 Exercise 2

We want to train the EGO model in a supervised manner, but the generated dataset doesn’t allow us to do so. Why is this the case, and what do we need to do to be able to train the model?

💡 Hint 1

For supervised training, we need to provide a target for each input. Think about what the target should be in this case.

💡 Hint 2

The task in this case is to predict the next state given the current state.

✅ Solution 2

The target in this case is just the next state in the sequence:

x = state_sequence(BLOCKED, 20, 5)
y = x[1:] + [one_hot_encode(0, 11)]  # the last state has no next state, so its target is arbitrary (the next sequence would start with 0 or 9; we use 0)
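To check the alignment, the sketch below decodes a few input/target pairs from one hand-written blue sequence followed by one orange sequence. The names `encode`, `decode`, `demo_x`, and `demo_y` are hypothetical and chosen so they don't overwrite the notebook's variables. Note the pair 7 -> 9: at a sequence boundary the target is the next sequence's context cue, which cannot be predicted from the current state.

```python
NUM_STATES = 11


def encode(label: int, num_classes: int = NUM_STATES) -> list:
    """One-hot encode an integer state (same output as one_hot_encode)."""
    return [1 if i == label else 0 for i in range(num_classes)]


def decode(vec: list) -> int:
    """Recover the integer state from a one-hot vector."""
    return vec.index(1)


labels = [0, 1, 3, 5, 7, 9, 2, 3, 6, 7]     # blue sequence, then orange sequence
demo_x = [encode(s) for s in labels]
demo_y = demo_x[1:] + [encode(0)]           # targets: shift by one, pad the end

for inp, tgt in zip(demo_x, demo_y):
    print(decode(inp), "->", decode(tgt))
```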

The EGO model

As mentioned earlier, the EGO model consists of three main components: an episodic memory module, a semantic pathway, and a recurrent context module. PsyNeuLink provides an EMComposition class for creating the episodic memory module. EMComposition is a subclass of the Composition class. A strength of the PsyNeuLink framework is that complex compositions can themselves be used as nodes in other compositions. Here, we first look at the EMComposition class in isolation and then integrate it into the EGO model.

Episodic Memory Module - EMComposition

[Figure: The episodic memory module (EMComposition)]

Here, we initialize the EMComposition for the episodic memory shown above. The EMComposition lets us specify the structure of the episodic memory. Remember, the task is to predict the current state from the previous state and the context. Therefore, each entry in memory consists of three fields:

  • The current state (green box)

  • The previous state (blue box)

  • The context (pink box)

Each field is represented as a vector with 11 elements, the size of the one-hot state encoding.

We also specify the properties of each field. Each field is configured with a dictionary of three main parameters:

  • FIELD_WEIGHT: The weight of the field when retrieving from memory

  • LEARN_FIELD_WEIGHT: Whether the field weight used for retrieval should be learned (here, we set the weights by hand instead of learning them)

  • TARGET_FIELD: Whether the field is a target field (meaning its “error” is calculated during learning)

🎯 Exercise 3

Before looking at the code below, think about what to set for the FIELD_WEIGHT and the TARGET_FIELD for the three different fields (current state, previous state, and context).

💡 Hint

The FIELD_WEIGHT specifies whether a field should be used during retrieval, and how strongly. It is a scalar value between 0 and 1. The TARGET_FIELD specifies whether a field is a target field.

✅ Solution

The FIELD_WEIGHT for the current state should be None, since it is the target field and shouldn’t be used in retrieval. The FIELD_WEIGHTs for the previous state and the context should be set to equal values. Here we set both to .5; because normalize_field_weights=True, only their relative size matters. The TARGET_FIELD should be set to True for the current state and False for the previous state and the context.
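To build intuition for how these settings affect retrieval, here is a simplified NumPy sketch of the kind of computation the EMComposition performs. It rests on several assumptions: dot-product matching of each query field against the stored entries (as with normalize_memories=False in the call below), a field-weighted sum of those matches, and a softmax with gain 10 (as with softmax_gain=10. below). The retrieved STATE is the softmax-weighted average of the stored STATE fields. This is not PsyNeuLink's implementation. It ignores softmax_threshold, storage, and learning, and it uses a one-hot context cue instead of the learned, integrated context.

```python
import numpy as np


def one_hot(label: int, n: int = 11) -> np.ndarray:
    v = np.zeros(n)
    v[label] = 1.0
    return v


def softmax(x: np.ndarray, gain: float) -> np.ndarray:
    z = gain * x
    z = z - z.max()                      # for numerical stability
    e = np.exp(z)
    return e / e.sum()


STATE, PREV, CONTEXT = 0, 1, 2
# Two stored (state, previous state, context) entries
memory = [
    (one_hot(3), one_hot(1), one_hot(0)),   # blue:   1 -> 3
    (one_hot(4), one_hot(1), one_hot(9)),   # orange: 1 -> 4
]
# STATE has no field weight: it is never used as a key, only retrieved
field_weights = {PREV: 0.5, CONTEXT: 0.5}


def retrieve(prev: np.ndarray, context: np.ndarray, gain: float = 10.0) -> np.ndarray:
    query = {PREV: prev, CONTEXT: context}
    # Field-weighted sum of dot-product matches for each memory entry
    scores = np.array([sum(w * (entry[f] @ query[f]) for f, w in field_weights.items())
                       for entry in memory])
    weights = softmax(scores, gain)
    # Retrieved STATE: softmax-weighted average of the stored STATE fields
    return sum(w * entry[STATE] for w, entry in zip(weights, memory))


print(int(np.argmax(retrieve(one_hot(1), one_hot(0)))))  # blue context
print(int(np.argmax(retrieve(one_hot(1), one_hot(9)))))  # orange context
```

Both entries match the previous state equally well, so the context field alone decides which successor is retrieved.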

name = 'EM'  # a name for the EMComposition

# Memory parameters
state_size = 11  # the size of the state vector
memory_capacity = 1000  # here we set the maximum number of entries in the memory (we want to be able to store all 1000 trials)

# Fields

# State field
state_name = 'STATE'
state_retrieval_weight = None  # This entry is not used when retrieving from memory (remember, we want to predict the state)
state_is_target = True

# Previous state field
previous_state_name = 'PREVIOUS STATE'
previous_state_retrieval_weight = .5  # This entry is used when retrieving from memory
previous_state_is_target = False

# Context field
context_name = 'CONTEXT'
context_retrieval_weight = .5  # This entry is used when retrieving from memory
context_is_target = False

em = pnl.EMComposition(name=name,
                       memory_template=[[0] * state_size,  # state
                                        [0] * state_size,  # previous state
                                        [0] * state_size],  # context
                       memory_fill=.001,
                       memory_capacity=memory_capacity,
                       normalize_memories=False,
                       memory_decay_rate=0,  # no decay of memory
                       softmax_gain=10.,
                       softmax_threshold=.001,
                       fields={state_name: {pnl.FIELD_WEIGHT: state_retrieval_weight,
                                            pnl.LEARN_FIELD_WEIGHT: False,
                                            pnl.TARGET_FIELD: state_is_target},
                               previous_state_name: {pnl.FIELD_WEIGHT: previous_state_retrieval_weight,
                                                     pnl.LEARN_FIELD_WEIGHT: False,
                                                     pnl.TARGET_FIELD: previous_state_is_target},
                               context_name: {pnl.FIELD_WEIGHT: context_retrieval_weight,
                                              pnl.LEARN_FIELD_WEIGHT: False,
                                              pnl.TARGET_FIELD: context_is_target}},

                       normalize_field_weights=True,

                       concatenate_queries=False,
                       enable_learning=True,
                       learning_rate=.5,
                       device=pnl.CPU
                       )

Let’s see what the EMComposition looks like:

em.show_graph(output_fmt='jupyter')

Input, Context, and Output Layers

Next, we hook the EMComposition up to the input, context, and output layers.

[Figure: The EGO model]

We start by defining the layers.

🎯 Exercise 4

Before defining the layers, make sure you understand the inputs and outputs of the model:

  • The episodic memory composition has three fields (“memory slots”), but our training set consists of only a single stream of states. How can we use this single state to provide the inputs for all three fields?

state_input_layer = pnl.ProcessingMechanism(name=state_name, input_shapes=state_size)

previous_state_layer = pnl.ProcessingMechanism(name=previous_state_name, input_shapes=state_size)

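# The context layer integrates its input over time (integrator_mode=True)
# and passes the integrated value through a tanh nonlinearity
context_layer = pnl.TransferMechanism(name=context_name,
                                      input_shapes=state_size,
                                      function=pnl.Tanh,
                                      integrator_mode=True,
                                      integration_rate=.69)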

# The output layer:
prediction_layer = pnl.ProcessingMechanism(name='PREDICTION', input_shapes=state_size)
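To see what the context layer does in isolation, here is a NumPy sketch of its dynamics. It assumes that integrator_mode uses PsyNeuLink's default AdaptiveIntegrator update, u_t = (1 - rate) * u_{t-1} + rate * input_t, starting from zero, with tanh applied to u_t to give the layer's output. It also assumes the layer receives only the current state as input. Treat it as an approximation of the mechanism, not its implementation. The helper `run_context` is hypothetical.

```python
import numpy as np

RATE = 0.69  # integration_rate used for the context layer above


def one_hot(label: int, n: int = 11) -> np.ndarray:
    v = np.zeros(n)
    v[label] = 1.0
    return v


def run_context(labels: list, rate: float = RATE) -> list:
    """Leaky integration of one-hot inputs followed by tanh (assumed update rule)."""
    u = np.zeros(11)                     # integrator state, assumed to start at 0
    outputs = []
    for s in labels:
        u = (1 - rate) * u + rate * one_hot(s)
        outputs.append(np.tanh(u))
    return outputs


blue = run_context([0, 1, 3, 5, 7])
orange = run_context([9, 1, 4, 5, 8])

# After the last state, the context still carries a small trace of the
# sequence's first state (0 or 9), the cue that separates the two contexts.
print(float(blue[-1][0]), float(orange[-1][9]))
```

With a rate this high, the trace of the first state decays quickly, so the integrated context alone is only a weak cue. This is the problem the Introduction describes.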

After defining the layers, we need to specify the pathways between the layers. Before looking at the code below, think about which pathways (if any) are learned and which ones are fixed.

# Names of the EMComposition's nodes have the form <field_name> + ' [QUERY]' or <field_name> + ' [VALUE]' (input nodes)
# and <field_name> + ' [RETRIEVED]' (output nodes); see the graph above

QUERY = ' [QUERY]'
VALUE = ' [VALUE]'
RETRIEVED = ' [RETRIEVED]'

# Pathways (fixed identity projections, except for the learnable context -> EM projection)

# Current state -> previous-state layer (read by the EM on the next trial)
state_to_previous_state_pathway = [state_input_layer,
                                   pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                         learnable=False),
                                   previous_state_layer]
# Current state -> context layer (integrated over time)
state_to_context_pathway = [state_input_layer,
                            pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            context_layer]
# Current state -> STATE field of the EM (the value that is stored)
state_to_em_pathway = [state_input_layer,
                       pnl.MappingProjection(sender=state_input_layer,
                                             receiver=em.nodes[state_name + VALUE],
                                             matrix=pnl.IDENTITY_MATRIX,
                                             learnable=False),
                       em]
# Previous state -> PREVIOUS STATE field of the EM (used as a query)
previous_state_to_em_pathway = [previous_state_layer,
                                pnl.MappingProjection(sender=previous_state_layer,
                                                      receiver=em.nodes[previous_state_name + QUERY],
                                                      matrix=pnl.IDENTITY_MATRIX,
                                                      learnable=False),
                                em]
# Context -> CONTEXT field of the EM (learnable), then retrieved STATE -> prediction layer
context_learning_pathway = [context_layer,
                            pnl.MappingProjection(sender=context_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  receiver=em.nodes[context_name + QUERY],
                                                  learnable=True),
                            em,
                            pnl.MappingProjection(sender=em.nodes[state_name + RETRIEVED],
                                                  receiver=prediction_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            prediction_layer]

Now we can create the composition:

learning_rate = .5
loss_spec = pnl.Loss.BINARY_CROSS_ENTROPY
model_name = 'EGO'
device = pnl.CPU

# Use the settings defined above instead of repeating the literal values
ego_model = pnl.AutodiffComposition([state_to_previous_state_pathway,
                                     state_to_context_pathway,
                                     state_to_em_pathway,
                                     previous_state_to_em_pathway,
                                     context_learning_pathway],
                                    learning_rate=learning_rate,
                                    loss_spec=loss_spec,
                                    name=model_name,
                                    device=device)


ego_model.show_graph(output_fmt='jupyter')

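We also need to specify the learning pathway. It can be inferred from the parameters we have set: the target field of the EMComposition and the learnable context-to-EM projection. We then connect the state input layer to the first of the returned learning components (the target node), so that the current state serves as the target for the model's prediction: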

learning_components = ego_model.infer_backpropagation_learning_pathways(pnl.ExecutionMode.PyTorch)

ego_model.add_projection(pnl.MappingProjection(sender=state_input_layer,
                                              receiver=learning_components[0],
                                              learnable=False))

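We also have to make sure that the EMComposition executes before the previous-state and context layers are updated on each trial. The EM is then queried with (and stores) the previous state and context from the preceding trial, together with the current state: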

ego_model.scheduler.add_condition(em, pnl.BeforeNodes(previous_state_layer, context_layer))
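The effect of this ordering can be sketched in plain Python. The loop below is an illustration, not PsyNeuLink code. It assumes the execution order just specified, and it replaces the context layer with a tuple of all states seen so far, a hypothetical simplification of the integrated context. Each stored entry pairs the current state with the state and context from the preceding trial. This is what allows retrieval to predict a state from its predecessor and the context.

```python
# Plain-Python stand-in for the trial loop: on every trial the EM runs before
# the PREVIOUS STATE and CONTEXT layers update, so it sees their values from
# the preceding trial.
labels = [0, 1, 3, 5, 7]   # one blue-context sequence

previous_state = None      # value currently held by the PREVIOUS STATE layer
context = ()               # crude stand-in for CONTEXT: the states seen so far
stored = []                # (state, previous state, context) entries written to the EM

for state in labels:
    # 1. The STATE input layer receives the current state.
    # 2. The EM is queried with, and stores, the old previous state and context,
    #    together with the current state as the value to be predicted.
    stored.append((state, previous_state, context))
    # 3. Only afterwards do the PREVIOUS STATE and CONTEXT layers update.
    previous_state = state
    context = context + (state,)  # the real CONTEXT layer is a leaky tanh integrator

for entry in stored:
    print(entry)
```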

Now we are ready to run the model:

trials = state_sequence(BLOCKED, 800, 200)

ego_model.learn(inputs={state_name: trials},
                learning_rate=.5,
                execution_mode=pnl.ExecutionMode.PyTorch)

import matplotlib.pyplot as plt
import numpy as np

TOTAL_NUM_STIMS = len(trials)
# TARGETS[t] is the state that follows trial t (the last entry is arbitrary padding)
TARGETS = np.array(trials[1:] + [one_hot_encode(0, 11)])
curriculum_type = BLOCKED

fig, axes = plt.subplots(1, 1, figsize=(12, 5))
# L1 distance between the model's output and the target on each trial (summed over state units)
axes.plot((np.abs(ego_model.results[1:TOTAL_NUM_STIMS, 2] - TARGETS[:TOTAL_NUM_STIMS - 1])).sum(-1))
axes.set_xlabel('Stimuli')
axes.set_ylabel('Loss (L1)')

plt.suptitle(f"{curriculum_type} Training")
plt.show()

🎯 Exercise 5

Run the model for the interleaved paradigm. What do you expect? Compare the two results and explain the differences.