6.2 Episodic Generalization and Optimization (EGO)

Introduction

Human cognition is unique in its ability to perform a wide range of tasks and to learn new tasks quickly. Both abilities have long been associated with the acquisition of knowledge that can generalize across tasks and the flexible use of that knowledge to execute goal-directed behavior. In this tutorial, we show how these abilities can emerge in a neural network by implementing the Episodic Generalization and Optimization (EGO) framework. The framework consists of an episodic memory module, which rapidly learns relationships between stimuli; a semantic pathway, which more slowly learns how stimuli map to responses; and a recurrent context module, which maintains a representation of task-relevant context information, integrates it over time, and uses it to recall context-relevant memories.

[Figure: The EGO framework]

The EGO framework consists of a control mechanism (context module; upper middle) and an episodic memory mechanism (bottom left). Episodic memory records conjunctions of stimuli (blue boxes), contexts (pink boxes), and observed responses (green boxes) at each time point (rows). Bidirectional arrows connect episodic memory to the stimulus, context, and output, indicating that these values can be stored in episodic memory, used to query it, or retrieved from it when another field is queried. You can think of this as a more flexible dictionary that stores triplets instead of key-value pairs and allows any field (or any combination of fields) to act as the key. The context module integrates the previous context (recurrent connection) with information about the current stimulus and the context retrieved from memory.
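To make the dictionary analogy concrete, here is a minimal pure-Python sketch. The class `TripletMemory`, its field names, and the coffee-shop entries are hypothetical illustrations, not part of PsyNeuLink. It uses exact matching, whereas EGO retrieves memories by similarity, so treat it only as an intuition for "any field can act as the key".

```python
from typing import Dict, List

# Hypothetical field names for the three parts of each episode
FIELDS = ("stimulus", "context", "response")


class TripletMemory:
    """Toy episodic memory: stores (stimulus, context, response) triplets
    and lets any subset of fields serve as the lookup key."""

    def __init__(self) -> None:
        self.entries: List[Dict[str, str]] = []

    def store(self, stimulus: str, context: str, response: str) -> None:
        self.entries.append(
            {"stimulus": stimulus, "context": context, "response": response}
        )

    def query(self, **key: str) -> List[Dict[str, str]]:
        """Return every stored triplet whose fields match all given key fields."""
        unknown = set(key) - set(FIELDS)
        if unknown:
            raise KeyError(f"unknown field(s): {sorted(unknown)}")
        return [e for e in self.entries
                if all(e[f] == v for f, v in key.items())]


memory = TripletMemory()
memory.store("order", "Suspicious Barista", "pay")
memory.store("sit down", "Cafe Gratitude", "order")
memory.store("order", "Cafe Gratitude", "drink")

# Two fields as the key: stimulus + context
print(memory.query(stimulus="order", context="Cafe Gratitude"))
# A single field as the key: all episodes in one context
print(len(memory.query(context="Cafe Gratitude")))  # 2
# Even the response can act as the key ("what led to ordering?")
print(memory.query(response="order"))
```

Querying with `stimulus="order"` alone returns both coffee-shop episodes; adding the context field disambiguates them, which is the role the context plays in EGO.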

Here, we show that the EGO framework can emulate human behavior in a specific learning environment where participants are trained on two sets of sequences that contain identical states presented in different orders in different contexts. Empirical findings show that participants perform better when trained in blocks of each context than when the contexts are interleaved.

Task: Coffee Shop World (CSW)

[Figure: The Suspicious Barista and Café Gratitude]

Imagine you are in a city with two coffee shops, each with a different layout and a different way of ordering. In one coffee shop, called The Suspicious Barista, you order first, pay for the coffee, and then sit down and wait until the waiter brings your order. In the other coffee shop, called Café Gratitude, you sit down first and wait for the waiter to come and take your order; you pay after finishing your coffee.

This example illustrates that many situations share similar stimuli but have different transition structures. Simply integrating stimuli over time helps a system learn the transition structure, but because the situations are so similar, it provides only a weak cue about which situation you are in. In other words, the states (ordering, paying, and sitting down) are the same in both situations and are therefore hard to tell apart. This can be overcome by differentiating the context representations associated with each setting (e.g., learning different context representations for coffee shops with paranoid vs. gullible baristas). Recent empirical work suggests that people can learn to do this very effectively, but that this depends on the temporal structure of the environment: people do better when trained in blocks of each situation than when trained on interleaved situations (Beukers et al., 2023).
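The "weak cue" can be made concrete with a small worked example. The sketch below assumes that context is a simple leaky integrator of one-hot state vectors, c <- (1 - r) * c + r * x, with a hypothetical rate r = 0.5; EGO's context module is learned and more elaborate. It compares the integrated contexts of two sequences that differ only in their first state.

```python
import math
from typing import List

NUM_STATES = 11  # matches the 11-element one-hot vectors used later in this tutorial


def one_hot(i: int, n: int = NUM_STATES) -> List[float]:
    return [1.0 if k == i else 0.0 for k in range(n)]


def integrate(states: List[int], rate: float = 0.5) -> List[float]:
    """Leaky integration (an assumption): c <- (1 - rate) * c + rate * x per state."""
    c = [0.0] * NUM_STATES
    for s in states:
        x = one_hot(s)
        c = [(1 - rate) * ck + rate * xk for ck, xk in zip(c, x)]
    return c


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Two hypothetical situations: different first state, then the same state(s)
blue = integrate([0, 1])
orange = integrate([9, 1])
print(round(cosine(blue, orange), 3))  # 0.8

# One more shared state: the distinguishing first state fades further
print(round(cosine(integrate([0, 1, 5]), integrate([9, 1, 5])), 3))  # 0.952
```

With r = 0.5, the two contexts already have a cosine similarity of 0.8 after a single shared state, and each further shared state pushes the similarity closer to 1, so plain integration separates the situations only weakly.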

We start by creating a dataset for the CSW task.

Installation and Setup

If the following cell fails to execute, please restart the kernel (or session) and run the cell again. This is a known issue when running in Google Colab.

%%capture
%pip install psyneulink

import psyneulink as pnl
import random

Generating data for the CSW task

Here, we generate the dataset for the CSW task. The dataset consists of sequences of states. The task is to predict the next state given the current state and the context. The transitions between states are determined by the context, which in turn is signaled by the first state in each sequence. The following figure illustrates the task structure:

[Figure: CSW task structure (left) and learning paradigms (right)]

On the left side of the figure, you can see the task structure:

The two colors represent different contexts: blue and orange.

  • If the first observed state in a sequence is 0, the participant is in the blue context.

    • The next state can be either 1 or 2.

    • From then on, transitions are deterministic:

      • 1 → 3 → 5 → 7

      • 2 → 4 → 6 → 8

  • If the first observed state is 9, the participant is in the orange context.

    • The next state is again either 1 or 2, but the transitions follow a different pattern:

      • 1 → 4 → 5 → 8

      • 2 → 3 → 6 → 7
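The transition rules above can be written as lookup tables. This sketch is a hand transcription of the list; `TRANSITIONS` and `next_state` are illustrative names, not used elsewhere in the tutorial. It also checks which states have a context-dependent successor.

```python
from typing import Dict

# Successor of each state within a context, transcribed from the rules above
TRANSITIONS: Dict[int, Dict[int, int]] = {
    0: {1: 3, 3: 5, 5: 7, 2: 4, 4: 6, 6: 8},  # blue context
    9: {1: 4, 4: 5, 5: 8, 2: 3, 3: 6, 6: 7},  # orange context
}


def next_state(context: int, state: int) -> int:
    """Deterministic successor of `state` in the given context."""
    return TRANSITIONS[context][state]


# States that occur in both contexts but lead to different successors
ambiguous = sorted(
    s for s in TRANSITIONS[0]
    if s in TRANSITIONS[9] and TRANSITIONS[0][s] != TRANSITIONS[9][s]
)
print(ambiguous)  # [1, 2, 3, 4, 5, 6]
```

Every intermediate state (1 through 6) has a different successor in the two contexts, so the current state alone never determines the next one; the model has to use the context.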

The right side of the figure shows the different learning paradigms:

In the blocked paradigm, participants are trained on consecutive sequences from the same context. In the interleaved paradigm, participants are trained on a mix of both contexts (in the code below, the contexts alternate from one sequence to the next). In the test phase, participants are tested on sequences whose contexts are chosen at random.

We begin by defining a function that generates a context-specific sequence:

def gen_context(
    context: int,
    start_state: int,
):
    """
    Generate a context-specific sequence.
    Args:
        context (int): The context to generate the sequence for. (0 or 9)
        start_state (int): The first state in the sequence. (1 or 2)
    """
    seq = [context, start_state]
    if context == 0:
        for _ in range(3):
            seq.append(seq[-1] + 2)
    elif context == 9:
        for _ in range(3):
            seq.append(seq[-1] + 1 if seq[-1] % 2 == 0 else seq[-1] + 3)
    return seq

"""Test the function"""
assert gen_context(0, 1) == [0, 1, 3, 5, 7]
assert gen_context(9, 2) == [9, 2, 3, 6, 7]

Next, let’s create a function that generates the full dataset for the CSW task: the training sequences for a given paradigm, followed by the test sequences.

# Define the paradigms
BLOCKED = 'blocked'
INTERLEAVED = 'interleaved'


def gen_context_sequences(
        paradigm: str,
        train_contexts: int,
        test_contexts: int,
        block_size: int = 4,
):
    """
    Generate a dataset for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for. (blocked or interleaved)
        train_contexts (int): The number of training sequences (one context each).
        test_contexts (int): The number of test sequences (one context each).
        block_size (int): In the blocked paradigm, the number of blocks; each block
            contains train_contexts // block_size sequences from the same context.
    Returns:
        list: The training sequences followed by the test sequences.
    """
    assert train_contexts % block_size == 0, "The number of training contexts must be a multiple of block_size."
    x = []
    if paradigm == INTERLEAVED:
        for idx in range(train_contexts):
            if idx % 2:  # odd contexts -> context 0
                x += [gen_context(0, random.randint(1, 2))]
            else:  # even contexts -> context 9
                x += [gen_context(9, random.randint(1, 2))]
    elif paradigm == BLOCKED:
        for i in range(block_size):  # note: block_size is used as the number of blocks
            if i % 2:  # odd blocks -> context 0
                for _ in range(train_contexts // block_size):
                    x += [gen_context(0, random.randint(1, 2))]
            else:  # even blocks -> context 9
                for _ in range(train_contexts // block_size):
                    x += [gen_context(9, random.randint(1, 2))]
    else:
        raise ValueError(f"Unknown paradigm: {paradigm!r}")

    # Test sequences: contexts are drawn at random
    for _ in range(test_contexts):
        x += [gen_context(random.choice([0, 9]), random.randint(1, 2))]
    return x


context_sequences = gen_context_sequences(BLOCKED, 8, 4)
context_sequences
[[9, 1, 4, 5, 8],
 [9, 1, 4, 5, 8],
 [0, 1, 3, 5, 7],
 [0, 2, 4, 6, 8],
 [9, 1, 4, 5, 8],
 [9, 2, 3, 6, 7],
 [0, 1, 3, 5, 7],
 [0, 1, 3, 5, 7],
 [9, 2, 3, 6, 7],
 [0, 1, 3, 5, 7],
 [0, 1, 3, 5, 7],
 [0, 1, 3, 5, 7]]

The structure of the generated data is not yet realistic: participants don’t see separate sequences, but a continuous stream of states. We therefore need to flatten the sequences. Also, instead of using integers to represent the states, we use one-hot encoding:

def one_hot_encode(
        label: int,
        num_classes: int):
    """
    One hot encode a label (integer)
    Args:
        label (int): The label to encode (between 0 and num_classes-1)
        num_classes (int): The number of classes
    """
    return [1 if i == label else 0 for i in range(num_classes)]


def state_sequence(
        paradigm: str,
        train_trials: int,
        test_trials: int,
        context_length: int = 5,
        block_size: int = 4,
):
    """
    Generate a flattened, one-hot encoded state sequence for the CSW task.
    Args:
        paradigm (str): The paradigm to generate the dataset for. (blocked or interleaved)
        train_trials (int): The number of training trials (states).
        test_trials (int): The number of test trials (states).
        context_length (int): The number of states in each context sequence.
        block_size (int): The number of blocks in the blocked paradigm (see gen_context_sequences).
    Returns:
        list: One 11-element one-hot vector per trial.
    """
    assert train_trials % context_length == 0, "The number of training samples must be a multiple of context_length."
    assert test_trials % context_length == 0, "The number of test samples must be a multiple of context_length."

    train_contexts = train_trials // context_length
    test_contexts = test_trials // context_length

    # Training sequences followed by test sequences
    all_sequences = gen_context_sequences(
        paradigm, train_contexts, test_contexts, block_size
    )

    # Flatten the sequences into a single stream of one-hot encoded states
    states = []
    for context_sequence in all_sequences:
        for state_int in context_sequence:
            states.append(one_hot_encode(state_int, 11))
    return states


state_sequences = state_sequence(BLOCKED, 20, 5)
state_sequences
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]

🎯 Exercise 1

Why do we encode the states using one-hot encoding?

✅ Solution 1

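One-hot encoding is used for categorical variables. States are categorical: they have no inherent order, and they cannot be meaningfully compared using arithmetic operations (state 2 is not “between” states 1 and 3). One-hot encoding represents each state as a vector that is orthogonal to the vectors of all other states, so every pair of distinct states is equally dissimilar.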
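A quick check of this argument in plain Python, using a local copy of the tutorial’s `one_hot_encode` (the `dot` helper is an illustrative addition):

```python
from typing import List


def one_hot_encode(label: int, num_classes: int) -> List[int]:
    """Same encoding as in the tutorial: a 1 at index `label`, 0 elsewhere."""
    return [1 if i == label else 0 for i in range(num_classes)]


def dot(a: List[int], b: List[int]) -> int:
    return sum(x * y for x, y in zip(a, b))


# Integer labels impose a spurious geometry: state 1 looks "closer" to 2 than to 9
print(abs(1 - 2), abs(1 - 9))  # 1 8

# One-hot vectors are mutually orthogonal: distinct states have dot product 0,
# and each state has dot product 1 with itself
pairs = [(1, 2), (1, 9), (3, 3)]
print([dot(one_hot_encode(a, 11), one_hot_encode(b, 11)) for a, b in pairs])  # [0, 0, 1]
```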

🎯 Exercise 2

We want to train the EGO model in a supervised manner, but the generated dataset doesn’t allow us to do so. Why is this the case, and what do we need to do to be able to train the model?

💡 Hint 1

For supervised training, we need to provide a target for each input. Think about what the target should be in this case.

💡 Hint 2

The task in this case is to predict the next state given the current state.

✅ Solution 2

The target in this case is just the next state in the sequence:

x = state_sequence(BLOCKED, 20, 5)
# Target = next state. The last state has no successor, so its target is arbitrary
# (the next sequence would start with 0 or 9); here we use state 0.
y = x[1:] + [one_hot_encode(0, 11)]
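To sanity-check the shift, it helps to decode the one-hot vectors back to integers and look at the (input, target) pairs. The sketch below is self-contained: it uses a short hand-written stream instead of `state_sequence`, and `decode` is a hypothetical helper.

```python
from typing import List, Tuple


def one_hot_encode(label: int, num_classes: int) -> List[int]:
    return [1 if i == label else 0 for i in range(num_classes)]


def decode(vec: List[int]) -> int:
    """Inverse of one-hot encoding: the index of the single 1."""
    return vec.index(1)


# A short flattened stream: one orange-context sequence, then one blue-context sequence
x = [one_hot_encode(s, 11) for s in [9, 1, 4, 5, 8, 0, 2, 4, 6, 8]]
y = x[1:] + [one_hot_encode(0, 11)]  # same shift-by-one as above

pairs: List[Tuple[int, int]] = [(decode(a), decode(b)) for a, b in zip(x, y)]
print(pairs)
```

Most pairs follow the transition rules, but pairs that straddle a sequence boundary, such as (8, 0), map the last state of one sequence to the first state of the next, which the transition rules do not determine.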

The EGO model

As mentioned earlier, the EGO model consists of three main components: an episodic memory module, a semantic pathway, and a recurrent context module. PsyNeuLink provides an EMComposition class that we can use to create the episodic memory module. EMComposition is a subclass of the Composition class. A strength of PsyNeuLink is that it allows complex compositions to be used as nodes within other compositions. Here, we first look at the EMComposition class in isolation and then integrate it into the EGO model.

Episodic Memory Module - EMComposition

[Figure: The EMComposition episodic memory]

Here, we initialize the EMComposition for the episodic memory shown in the figure above. The EMComposition lets us specify the structure of the episodic memory. Remember, the task is to predict the current state from the previous state and the context. Therefore, in our case, each entry in memory consists of a triplet of state vectors:

  • The current state (green box)

  • The previous state (blue box)

  • The context (pink box)

Each state is represented as a vector with 11 elements (one-hot encoding).

We also configure each field. Each field is specified by a dictionary with three main parameters:

  • FIELD_WEIGHT: The weight of the field when retrieving from memory

  • LEARN_FIELD_WEIGHT: Whether the retrieval field weight should be learned (here, we don’t learn these weights but set them by hand)

  • TARGET_FIELD: Whether the field is a target field (meaning its error is computed during learning)

🎯 Exercise 3

Before looking at the code below, think about what to set for the FIELD_WEIGHT and the TARGET_FIELD for the three different fields (current state, previous state, and context).

💡 Hint

The FIELD_WEIGHT specifies whether a field should be used during retrieval (and how much it should contribute). It is a scalar value between 0 and 1. The TARGET_FIELD specifies whether a field is a target field, i.e., a field whose value the model should predict.

✅ Solution

The FIELD_WEIGHT for the current state should be None, since it is the target field and shouldn’t be used for retrieval. The FIELD_WEIGHTs for the previous state and the context should be set to equal values (here, we set both to 0.5). TARGET_FIELD should be True for the current state and False for the previous state and the context.
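To see how these weights shape retrieval, here is a simplified pure-Python sketch of weighted, softmax-based retrieval. It is loosely modeled on the description above, not on EMComposition’s actual implementation: scores are weighted dot products over the key fields, a softmax with gain 10 (mirroring `softmax_gain` below) turns them into retrieval weights, and `softmax_threshold` is ignored. For simplicity, the context field here is just the one-hot first state of a sequence. The names `retrieve`, `one_hot`, and `dot` are hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vector = List[float]


def one_hot(i: int, n: int = 11) -> Vector:
    return [1.0 if k == i else 0.0 for k in range(n)]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def retrieve(memory: List[List[Vector]], query: List[Optional[Vector]],
             field_weights: List[Optional[float]], target: int,
             gain: float = 10.0) -> Vector:
    """Soft retrieval of the `target` field (simplified sketch).

    Each key field (weight not None) contributes weight * dot(query, stored);
    the summed scores pass through a softmax with the given gain, and the
    target field is read out as the softmax-weighted average over entries."""
    keys = [f for f, w in enumerate(field_weights) if w is not None]
    total = sum(field_weights[f] for f in keys)  # normalize key weights to sum to 1
    scores = [sum(field_weights[f] / total * dot(query[f], entry[f]) for f in keys)
              for entry in memory]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(gain * (s - m)) for s in scores]
    z = sum(exps)
    probs = [e / z for e in exps]
    width = len(memory[0][target])
    return [sum(p * entry[target][k] for p, entry in zip(probs, memory))
            for k in range(width)]


# Entries are [state, previous state, context], as in the EMComposition below
memory = [
    [one_hot(3), one_hot(1), one_hot(0)],  # blue context: 1 -> 3
    [one_hot(4), one_hot(1), one_hot(9)],  # orange context: 1 -> 4
]
weights = [None, 0.5, 0.5]  # the state field is the target, not a key

pred = retrieve(memory, [None, one_hot(1), one_hot(9)], weights, target=0)
print(max(range(11), key=lambda k: pred[k]))  # 4
```

Querying with previous state 1 in context 9 retrieves state 4 with a weight of about 0.99. If the context is dropped as a key (field weights `[None, 1.0, None]`), both memories match equally and the prediction splits 0.5/0.5 between states 3 and 4, which is why the context field needs a nonzero weight.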

name = 'EM'  # a name for the EMComposition

# Memory parameters
state_size = 11  # the size of the state vector
memory_capacity = 1000  # the maximum number of entries in memory (enough to store all 1000 trials)

# Fields

# State field
state_name = 'STATE'
state_retrieval_weight = None  # not used when retrieving from memory (this is the field we want to predict)
state_is_target = True

# Previous state field
previous_state_name = 'PREVIOUS STATE'
previous_state_retrieval_weight = .5  # used when retrieving from memory
previous_state_is_target = False

# Context field
context_name = 'CONTEXT'
context_retrieval_weight = .5  # used when retrieving from memory
context_is_target = False

em = pnl.EMComposition(name=name,
                       memory_template=[[0] * state_size,  # state
                                        [0] * state_size,  # previous state
                                        [0] * state_size],  # context
                       memory_fill=.001,
                       memory_capacity=memory_capacity,
                       normalize_memories=False,
                       memory_decay_rate=0,  # no decay of memory
                       softmax_gain=10.,
                       softmax_threshold=.001,
                       fields={state_name: {pnl.FIELD_WEIGHT: state_retrieval_weight,
                                            pnl.LEARN_FIELD_WEIGHT: False,
                                            pnl.TARGET_FIELD: state_is_target},
                               previous_state_name: {pnl.FIELD_WEIGHT: previous_state_retrieval_weight,
                                                     pnl.LEARN_FIELD_WEIGHT: False,
                                                     pnl.TARGET_FIELD: previous_state_is_target},
                               context_name: {pnl.FIELD_WEIGHT: context_retrieval_weight,
                                              pnl.LEARN_FIELD_WEIGHT: False,
                                              pnl.TARGET_FIELD: context_is_target}},
                       normalize_field_weights=True,
                       concatenate_queries=False,
                       enable_learning=True,
                       learning_rate=.5,
                       device=pnl.CPU
                       )

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/components/functions/stateful/statefulfunction.py:251, in StatefulFunction.__init__(self, default_variable, rate, noise, initializer, params, owner, prefs, context, **kwargs)
    248 if not hasattr(self, "stateful_attributes"):
    249     self.stateful_attributes = ["previous_value"]
--> 251 super().__init__(
    252     default_variable=default_variable,
    253     rate=rate,
    254     initializer=initializer,
    255     noise=noise,
    256     params=params,
    257     owner=owner,
    258     prefs=prefs,
    259     context=context,
    260     **kwargs
    261 )

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/globals/parameters.py:506, in check_user_specified.<locals>.check_user_specified_wrapper(self, *args, **kwargs)
    503     self._prev_constructor = constructor
    505 self._prev_kwargs = kwargs
--> 506 return func(self, *args, **orig_kwargs)

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/components/functions/function.py:690, in Function_Base.__init__(self, default_variable, params, owner, name, prefs, context, **kwargs)
    683 register_category(entry=self,
    684                   base_class=Function_Base,
    685                   registry=FunctionRegistry,
    686                   name=name,
    687                   )
    688 self.owner = owner
--> 690 super().__init__(
    691     default_variable=default_variable,
    692     param_defaults=params,
    693     name=name,
    694     prefs=prefs,
    695     **kwargs
    696 )

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/globals/parameters.py:506, in check_user_specified.<locals>.check_user_specified_wrapper(self, *args, **kwargs)
    503     self._prev_constructor = constructor
    505 self._prev_kwargs = kwargs
--> 506 return func(self, *args, **orig_kwargs)

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/components/component.py:1150, in Component.__init__(self, default_variable, param_defaults, input_shapes, function, name, reset_stateful_function_when, prefs, function_params, **kwargs)
   1144     self.reset_stateful_function_when = Never()
   1146 parameter_values, function_params = self._parse_arguments(
   1147     default_variable, param_defaults, input_shapes, function, function_params, kwargs
   1148 )
-> 1150 self._initialize_parameters(
   1151     context=context,
   1152     **parameter_values
   1153 )
   1155 var = call_with_pruned_args(
   1156     self._handle_default_variable,
   1157     **parameter_values
   1158 )
   1159 if var is None:

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/components/component.py:2387, in Component._initialize_parameters(self, context, **param_defaults)
   2385         if val.owner is not None:
   2386             val = copy.deepcopy(val)
-> 2387     elif not contains_type(val, Function):
   2388         val = copy_parameter_value(val, shared_types=shared_types)
   2389 else:

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/globals/utilities.py:2064, in contains_type(arr, typ)
   2062 recurse = not isinstance(arr, np.matrix)
   2063 for a in arr_items:
-> 2064     if isinstance(a, typ) or (a is not arr and recurse and contains_type(a, typ)):
   2065         return True
   2067 return False

File /opt/hostedtoolcache/Python/3.11.11/x64/lib/python3.11/site-packages/psyneulink/core/globals/utilities.py:2064, in contains_type(arr, typ)
   2062 recurse = not isinstance(arr, np.matrix)
   2063 for a in arr_items:
-> 2064     if isinstance(a, typ) or (a is not arr and recurse and contains_type(a, typ)):
   2065         return True
   2067 return False

File <frozen abc>:119, in __instancecheck__(cls, instance)

KeyboardInterrupt: 

Let’s see what the EMComposition looks like:

em.show_graph(output_fmt='jupyter')

Input, Context, and Output Layers

Next, we “hook up” the EMComposition to the input, output, and context layers.

EGO

We start by defining the layers.

🎯 Exercise 4

Before defining the layers, make sure you understand the inputs and outputs of the model:

  • Although the episodic memory composition has three “memory slots”, our training set consists only of a stream of single states. How can we use this single state to fill all three slots?

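# Input layer: receives the current state on each trial
state_input_layer = pnl.ProcessingMechanism(name=state_name, input_shapes=state_size)

# Previous-state layer: holds the state presented on the preceding trial
previous_state_layer = pnl.ProcessingMechanism(name=previous_state_name, input_shapes=state_size)

# Context layer: integrates its input over trials (integration_rate=.69)
# and passes the result through a tanh nonlinearity
context_layer = pnl.TransferMechanism(name=context_name,
                                      input_shapes=state_size,
                                      function=pnl.Tanh,
                                      integrator_mode=True,
                                      integration_rate=.69)

# Output layer: the model's prediction of the current state
prediction_layer = pnl.ProcessingMechanism(name='PREDICTION', input_shapes=state_size)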
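The context layer's integrator is what turns a single stream of states into a slowly drifting context. As a rough sketch of what it computes, assuming that the TransferMechanism's integrator mode performs the update `(1 - rate) * previous + rate * input` and that `pnl.Tanh` with default parameters reduces to `np.tanh`, here is a plain NumPy version (`context_trace` and `one_hot` are hypothetical helpers, not part of PsyNeuLink):

```python
import numpy as np


def one_hot(i, n):
    """Hypothetical helper: one-hot vector of length n with a 1 at index i."""
    v = np.zeros(n)
    v[i] = 1.0
    return v


def context_trace(states, rate=0.69):
    """Leaky-integrate a sequence of state vectors, then squash with tanh.

    Assumed to mirror a TransferMechanism with integrator_mode=True:
    integrated = (1 - rate) * previous + rate * input; output = tanh(integrated).
    """
    integrated = np.zeros_like(states[0], dtype=float)  # integrator starts at 0
    outputs = []
    for s in states:
        integrated = (1 - rate) * integrated + rate * s
        outputs.append(np.tanh(integrated))
    return np.array(outputs)


ctx = context_trace([one_hot(i, 4) for i in [0, 1, 2]])
```

Because the trace decays by a factor of 1 - 0.69 = 0.31 per trial, the most recent states dominate the context, while earlier ones still leave a smaller trace.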

After defining the layers, we need to specify the pathways between the layers. Before looking at the code below, think about which pathways (if any) are learned and which ones are fixed.

# Names of the EMComposition's nodes have the form <node_name> + ' [QUERY]' or <node_name> + ' [VALUE]' (input nodes)
# or <node_name> + ' [RETRIEVED]' (output nodes); see above

QUERY = ' [QUERY]'
VALUE = ' [VALUE]'
RETRIEVED = ' [RETRIEVED]'

# Pathways
# Fixed: copy the current state into the previous-state layer
state_to_previous_state_pathway = [state_input_layer,
                                   pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                         learnable=False),
                                   previous_state_layer]
# Fixed: feed the current state into the context layer, which integrates it over time
state_to_context_pathway = [state_input_layer,
                            pnl.MappingProjection(matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            context_layer]
# Fixed: store the current state in the EM's state (VALUE) field
state_to_em_pathway = [state_input_layer,
                       pnl.MappingProjection(sender=state_input_layer,
                                             receiver=em.nodes[state_name + VALUE],
                                             matrix=pnl.IDENTITY_MATRIX,
                                             learnable=False),
                       em]
# Fixed: query the EM with the previous state
previous_state_to_em_pathway = [previous_state_layer,
                                pnl.MappingProjection(sender=previous_state_layer,
                                                      receiver=em.nodes[previous_state_name + QUERY],
                                                      matrix=pnl.IDENTITY_MATRIX,
                                                      learnable=False),
                                em]
# Learned: context -> EM context query (learnable=True);
# the retrieved state is then passed (fixed) to the prediction layer
context_learning_pathway = [context_layer,
                            pnl.MappingProjection(sender=context_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  receiver=em.nodes[context_name + QUERY],
                                                  learnable=True),
                            em,
                            pnl.MappingProjection(sender=em.nodes[state_name + RETRIEVED],
                                                  receiver=prediction_layer,
                                                  matrix=pnl.IDENTITY_MATRIX,
                                                  learnable=False),
                            prediction_layer]
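These pathways route three fields into memory: the previous state and the context act as queries, and the current state is stored as the value, whose retrieval is passed on as the prediction. The toy class below is a hypothetical stand-in for the EMComposition, not its implementation: it stores triplets and retrieves a softmax-weighted blend of stored states, using dot-product similarity summed over the two query fields. The `gain` value is an assumption chosen for illustration.

```python
import numpy as np


class ToyEpisodicMemory:
    """Hypothetical stand-in for an EMComposition with two query fields
    (previous state, context) and one value field (state)."""

    def __init__(self, gain=10.0):
        self.gain = gain  # softmax gain (assumed value, for illustration only)
        self.prev_keys, self.context_keys, self.values = [], [], []

    def store(self, prev_state, context, state):
        """Record one episode as a (previous state, context, state) triplet."""
        self.prev_keys.append(np.asarray(prev_state, dtype=float))
        self.context_keys.append(np.asarray(context, dtype=float))
        self.values.append(np.asarray(state, dtype=float))

    def retrieve(self, prev_state, context):
        """Return a softmax-weighted average of the stored states."""
        prev_state = np.asarray(prev_state, dtype=float)
        context = np.asarray(context, dtype=float)
        if not self.values:
            return np.zeros_like(prev_state)  # empty memory: nothing to retrieve
        # Dot-product similarity with each stored episode, summed over both query fields
        sims = np.array(self.prev_keys) @ prev_state + np.array(self.context_keys) @ context
        # Softmax over episodes (subtracting the max for numerical stability)
        weights = np.exp(self.gain * (sims - sims.max()))
        weights /= weights.sum()
        return weights @ np.array(self.values)


A, B, C = np.eye(3)       # three one-hot states
ctx1, ctx2 = np.eye(2)    # two distinct contexts
mem = ToyEpisodicMemory()
mem.store(A, ctx1, B)     # in context 1, A is followed by B
mem.store(A, ctx2, C)     # in context 2, A is followed by C
```

Both episodes share the same previous state, so only the context field tells them apart: `mem.retrieve(A, ctx1)` returns (almost exactly) `B`, and `mem.retrieve(A, ctx2)` returns `C`. This mirrors the task, in which identical states appear in different orders in different contexts.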

Now we can create the composition:

learning_rate = .5
loss_spec = pnl.Loss.BINARY_CROSS_ENTROPY
model_name = 'EGO'
device = pnl.CPU

ego_model = pnl.AutodiffComposition([state_to_previous_state_pathway,
                                     state_to_context_pathway,
                                     state_to_em_pathway,
                                     previous_state_to_em_pathway,
                                     context_learning_pathway],
                                    learning_rate=learning_rate,
                                    loss_spec=loss_spec,
                                    name=model_name,
                                    device=device)


ego_model.show_graph(output_fmt='jupyter')

We also need to specify the learning pathway. It can be inferred from the parameters we have set (specifying the target field in the EMComposition and making the context-to-EM projection learnable):

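learning_components = ego_model.infer_backpropagation_learning_pathways(pnl.ExecutionMode.PyTorch)

# Use the current state as the target for learning: project it (fixed) into
# the first inferred learning component, which serves as the target node
ego_model.add_projection(pnl.MappingProjection(sender=state_input_layer,
                                              receiver=learning_components[0],
                                              learnable=False))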
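The composition was created with `loss_spec=pnl.Loss.BINARY_CROSS_ENTROPY`, so learning compares the prediction with this target using binary cross-entropy. For reference, here is a minimal NumPy sketch of the mean binary cross-entropy; the clipping constant `eps` is an assumption to avoid `log(0)`, and PsyNeuLink's PyTorch backend may reduce the loss differently.

```python
import numpy as np


def binary_cross_entropy(prediction, target, eps=1e-7):
    """Mean binary cross-entropy between predictions in (0, 1) and 0/1 targets."""
    p = np.clip(np.asarray(prediction, dtype=float), eps, 1 - eps)  # keep log() finite
    t = np.asarray(target, dtype=float)
    return float(-np.mean(t * np.log(p) + (1 - t) * np.log(1 - p)))


target = [0.0, 1.0, 0.0]                                   # one-hot "next state"
good = binary_cross_entropy([0.05, 0.9, 0.05], target)     # confident, correct prediction
bad = binary_cross_entropy([0.4, 0.2, 0.4], target)        # mostly wrong prediction
```

A prediction concentrated on the correct state yields a much smaller loss than one spread over the wrong states.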

We also have to make sure that the EMComposition executes before the previous-state and context layers, so that memory is queried with their values from the preceding trial:

ego_model.scheduler.add_condition(em, pnl.BeforeNodes(previous_state_layer, context_layer))
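A plain-Python sketch of why this ordering matters follows. It models only the ordering, not PsyNeuLink's scheduler; the function name and the use of `None` for the layer's initial value are assumptions. If memory runs before the previous-state layer updates, each episode links the state seen on the preceding trial to the current one, which is exactly the transition the model must learn to predict. The context layer follows the same pattern.

```python
def episodes_in_trial_order(states, initial=None):
    """Hypothetical walk-through of the within-trial order: memory runs before
    the previous-state layer updates, so each episode pairs the *old* previous
    state with the current state."""
    previous_state = initial  # value held by the previous-state layer before the first trial
    episodes = []
    for state in states:
        # 1. Memory executes first: it sees the old previous state and the current state
        episodes.append((previous_state, state))
        # 2. Only afterwards does the previous-state layer take on the current state
        previous_state = state
    return episodes


print(episodes_in_trial_order(['A', 'B', 'C']))
# [(None, 'A'), ('A', 'B'), ('B', 'C')]
```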

Now we are ready to run the model:

trials = state_sequence(BLOCKED, 800, 200)

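ego_model.learn(inputs={state_name: trials},
                learning_rate=learning_rate,
                execution_mode=pnl.ExecutionMode.PyTorch)

After training, we plot the per-trial L1 loss between the model's output and the target state:
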
import matplotlib.pyplot as plt
import numpy as np

TOTAL_NUM_STIMS = len(trials)
TARGETS = np.array(trials[1:] + [one_hot_encode(0, 11)])
curriculum_type = BLOCKED

fig, axes = plt.subplots(1, 1, figsize=(12, 5))
# Per-trial L1 distance between the model's output and the target state
axes.plot((np.abs(ego_model.results[1:TOTAL_NUM_STIMS, 2] - TARGETS[:TOTAL_NUM_STIMS - 1])).sum(-1))
axes.set_xlabel('Stimuli')
axes.set_ylabel('Loss')

plt.suptitle(f"{curriculum_type} Training")
plt.show()
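The per-trial loss is noisy, which can make blocked and interleaved runs hard to compare by eye. A minimal smoothing helper is sketched below; `moving_average` and the choice of window size are assumptions for illustration, not part of the tutorial's code.

```python
import numpy as np


def moving_average(values, window):
    """Smooth a 1-D loss curve with a simple moving average.

    Uses mode='valid', so the result is shorter than the input by window - 1.
    """
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window  # equal weight for each point in the window
    return np.convolve(values, kernel, mode='valid')
```

For example, you could pass the L1 array computed above through `moving_average(..., 20)` before calling `axes.plot`, and do the same for an interleaved run, so that both curves are smoothed identically.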

🎯 Exercise 4

Run the model for the interleaved paradigm. What do you expect? Compare the two results and explain the differences.