Experiment configuration with Pydantic and Exca#

This example shows how to use the pydantic and exca libraries to configure and run EEG experiments with Braindecode.

Pydantic is a library for data validation and settings management using Python type annotations. It allows defining structured configurations that can be validated and serialized easily.
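
For instance, here is a minimal sketch of a Pydantic model, unrelated to Braindecode (the OptimizerConfig class below is made up purely for illustration):

import pydantic


class OptimizerConfig(pydantic.BaseModel):
    lr: float = 0.001
    momentum: float = 0.9


# The string is validated and coerced into a float by Pydantic:
cfg = OptimizerConfig(lr="0.01")
print(cfg.model_dump())  # {'lr': 0.01, 'momentum': 0.9}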

Exca builds on top of Pydantic, and allows you to seamlessly EXecute experiments and CAche their results.
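
Concretely, Exca attaches a TaskInfra object to a Pydantic config and caches the output of the method it decorates. Here is a minimal sketch of this pattern (SlowComputation is a made-up config, unrelated to the rest of this example):

import exca
import pydantic


class SlowComputation(pydantic.BaseModel):
    x: int = 3
    infra: exca.TaskInfra = exca.TaskInfra(folder=None, keep_in_ram=True)

    @infra.apply
    def process(self) -> int:
        print("computing...")  # only printed on the first call
        return self.x**2


cfg = SlowComputation(x=4)
print(cfg.process())  # computes and caches the result: 16
print(cfg.process())  # returns the cached value without recomputing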

Braindecode implements a Pydantic configuration for each of its models in braindecode.models.config. In this example, we will use these configurations to define an experiment that trains and evaluates different models on a motor-imagery dataset using Exca.

# Authors: Pierre Guetschel
#
# License: BSD (3-clause)

Creating the experiment configurations#

We will start by defining the configurations needed for our experiment using Pydantic and Exca.

Dataset configs#

Our first configuration class is related to the data. It will allow us to load and prepare the dataset.

import warnings
from typing import Annotated, Literal

import exca
import pydantic
from moabb.datasets.utils import dataset_list

from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import create_windows_from_events

warnings.simplefilter("ignore")

# The list of available MOABB datasets:
DATASET_NAMES = tuple(ds.__name__ for ds in dataset_list)


class WindowedMOABBDatasetConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    dataset_type: Literal["moabb"] = "moabb"
    infra: exca.TaskInfra = exca.TaskInfra(
        folder=None,  # no disk caching
        cluster=None,  # local execution
        keep_in_ram=True,
    )
    dataset_name: Literal[DATASET_NAMES] = "BNCI2014_001"
    subject_id: list[int] | int | None = None
    window_size_seconds: float = 4.0
    overlap_seconds: float = 0.0

    @infra.apply
    def create_instance(self) -> MOABBDataset:
        # We don't apply any preprocessing here for simplicity, but in a real experiment,
        # you would typically want to filter the data, resample it, etc.
        # Instead, our config directly extracts windows from the raw data.
        dataset = MOABBDataset(
            dataset_name=self.dataset_name, subject_ids=self.subject_id
        )
        windows_dataset = create_windows_from_events(dataset, preload=True)

        return windows_dataset

We can see that the config has an infra: exca.TaskInfra attribute and a method decorated with @infra.apply. This means that Exca will cache the result of this method: if it is called again with the same configuration, the cached result is returned instead of re-running the computation. Here, the cache is kept in RAM for simplicity (folder=None), but in a real experiment you would typically want to cache the results on disk, as shown in the training config below. This allows for easy and efficient experimentation.
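
Because infra is a regular Pydantic field, it can also be overridden when instantiating the config. As a sketch (the folder name is arbitrary), switching this dataset config to disk caching could look like:

# Hypothetical override: cache the windowed dataset on disk instead of in RAM.
disk_cached_dataset_cfg = WindowedMOABBDatasetConfig(
    subject_id=1,
    infra={"folder": ".cache/", "keep_in_ram": False},
)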

Additionally, we define a small wrapper config to split the dataset into training and testing sets. Here, no caching is applied since the split operation is fast.

class DatasetSplitConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    dataset_type: Literal["split"] = "split"
    dataset: WindowedMOABBDatasetConfig
    key: str
    by: str = "session"

    def create_instance(self):
        dataset = self.dataset.create_instance()
        splits = dataset.split(self.by)
        return splits[self.key]

Finally, we define a union type for dataset configurations, which can be either a WindowedMOABBDatasetConfig or a DatasetSplitConfig.

DatasetConfig = Annotated[
    WindowedMOABBDatasetConfig | DatasetSplitConfig,
    pydantic.Field(discriminator="dataset_type"),
]

Training config#

Now that our data configs are ready, we can define our training config. It requires both a dataset and a model configuration; when run, it loads the data, instantiates the model, and trains it on the data.

from skorch.callbacks import EarlyStopping
from skorch.dataset import ValidSplit
from torch.optim import Adam

from braindecode.models.config import BraindecodeModelConfig


class TrainingConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    infra: exca.TaskInfra = exca.TaskInfra(
        folder=".cache/",
        cluster=None,  # local execution
    )
    model: BraindecodeModelConfig
    train_dataset: DatasetConfig
    max_epochs: int = 50
    batch_size: int = 32
    lr: float = 0.001
    seed: int = 12

    @infra.apply
    def train(self) -> dict:
        # Load training data
        train_set = self.train_dataset.create_instance()
        train_y = train_set.get_metadata()["target"].to_numpy()

        # Instantiate the model
        model = self.model.create_instance()
        clf = EEGClassifier(
            model,
            max_epochs=self.max_epochs,
            batch_size=self.batch_size,
            lr=self.lr,
            train_split=ValidSplit(0.2, random_state=self.seed, stratified=True),
            callbacks=["accuracy", EarlyStopping(patience=3)],
            optimizer=Adam,
        )

        # Train the model and return its weights
        # (we return the state dict so that the trained weights can be cached and reloaded later)
        clf.fit(train_set, train_y)
        return clf.module_.state_dict()

Note that the model field has type braindecode.models.config.BraindecodeModelConfig. This type matches any of the model configurations defined in braindecode.models.config.
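
To illustrate the dispatch, here is a small sketch with a hypothetical container class (not part of the example itself): a field annotated with BraindecodeModelConfig accepts a plain dictionary and selects the concrete config class from the "model_name_" key.

class _AnyModelContainer(pydantic.BaseModel):
    model: BraindecodeModelConfig


container = _AnyModelContainer(
    model={"model_name_": "EEGNet", "n_chans": 26, "n_outputs": 4, "n_times": 1000}
)
print(type(container.model).__name__)  # EEGNetConfig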

We also see that there is now a cache folder specified (.cache/ here). This means that the results of the train() method will be cached on disk in this folder, instead of only in RAM.

Evaluation config#

Finally, we define an evaluation config that will load the validation data, load the trained model from the training config, and evaluate it on the validation data.

class EvaluationConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")
    infra: exca.TaskInfra = exca.TaskInfra(
        folder=".cache/",
        cluster=None,  # local execution
    )
    test_dataset: DatasetConfig
    trainer: TrainingConfig

    @infra.apply
    def evaluate(self) -> float:
        # Load validation data
        valid_set = self.test_dataset.create_instance()
        test_y = valid_set.get_metadata()["target"].to_numpy()

        # Load trained model
        state_dict = self.trainer.train()
        model = self.trainer.model.create_instance()
        model.load_state_dict(state_dict)
        clf = EEGClassifier(model)
        clf.initialize()

        # Evaluate the model
        score = clf.score(valid_set, test_y)
        return score

Note

SLURM execution. Exca also offers the possibility to run experiments remotely on a SLURM-managed cluster. In this example, we run everything locally by setting cluster=None, but you can find more information about how to set up cluster execution in the Exca documentation: https://facebookresearch.github.io/exca/infra/introduction.html.
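
As a rough sketch (treat the values as placeholders and check the Exca documentation for the exact options available in your version), a SLURM-ready infra could look like the following; it could then be passed as the infra field of TrainingConfig or EvaluationConfig:

# Hypothetical SLURM infra: jobs are submitted to the cluster, results cached on disk.
slurm_infra = exca.TaskInfra(
    folder=".cache/",  # remote execution also needs a cache folder
    cluster="slurm",
)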

Instantiating the configurations#

Instantiation option 1: from class constructors#

Now that our configuration classes are defined, we can instantiate them.

We will start with the model configuration. Here, we use the braindecode.models.EEGNet model. Like any other braindecode model, it has a corresponding configuration class in braindecode.models.config, called braindecode.models.config.EEGNetConfig. We instantiate it using the signal properties of our dataset (number of time samples, channels, and output classes).

from braindecode.models.config import EEGConformerConfig, EEGNetConfig

signal_kwargs = {"n_times": 1000, "n_chans": 26, "n_outputs": 4}
model_cfg = EEGNetConfig(**signal_kwargs)

The config object can easily be serialized to a JSON format:

print(model_cfg.model_dump(mode="json"))
{'model_name_': 'EEGNet', 'n_chans': 26, 'n_outputs': 4, 'n_times': 1000, 'final_conv_length': 'auto', 'pool_mode': 'mean', 'F1': 8, 'D': 2, 'F2': None, 'kernel_length': 64, 'depthwise_kernel_length': 16, 'pool1_kernel_size': 4, 'pool2_kernel_size': 8, 'conv_spatial_max_norm': 1, 'activation': 'torch.nn.modules.activation.ELU', 'batch_norm_momentum': 0.01, 'batch_norm_affine': True, 'batch_norm_eps': 0.001, 'drop_prob': 0.25, 'final_layer_with_constraint': False, 'norm_rate': 0.25, 'chs_info': None, 'input_window_seconds': None, 'sfreq': None}

Alternatively, if you only want the non-default keys:

print(model_cfg.model_dump(exclude_defaults=True))
{'n_chans': 26, 'n_outputs': 4, 'n_times': 1000}

The config class checks the argument types and values, and raises an error if something is wrong. For example, if we try to instantiate it with an incorrect type for n_times, we get an error:

# kept for restoration later:
true_n_times = signal_kwargs["n_times"]

# float instead of int:
signal_kwargs["n_times"] = 22.5

try:
    EEGNetConfig(**signal_kwargs)
except pydantic.ValidationError as e:
    print(f"Validation error raised as expected:\n{e}")
Validation error raised as expected:
1 validation error for EEGNetConfig
n_times
  Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=22.5, input_type=float]
    For further information visit https://errors.pydantic.dev/2.12/v/int_from_float

Similarly, if a mandatory argument is missing, we get an error:

del signal_kwargs["n_times"]
try:
    EEGNetConfig(**signal_kwargs)
except pydantic.ValidationError as e:
    print(f"Validation error raised as expected:\n{e}")

# We restore the correct value for ``n_times`` for the rest of the example:
signal_kwargs["n_times"] = true_n_times
Validation error raised as expected:
1 validation error for EEGNetConfig
  Value error, n_times is required and could not be inferred.Either specify n_times or input_window_seconds and sfreq. [type=value_error, input_value={'n_chans': 26, 'n_outputs': 4}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.12/v/value_error

We have now instantiated the model configuration. Creating the dataset, training, and evaluation configurations is just as straightforward using the classes we defined earlier.

dataset_cfg = WindowedMOABBDatasetConfig(subject_id=1)

train_dataset_cfg = DatasetSplitConfig(dataset=dataset_cfg, key="0train")
test_dataset_cfg = DatasetSplitConfig(dataset=dataset_cfg, key="1test")

train_cfg = TrainingConfig(model=model_cfg, train_dataset=train_dataset_cfg)

eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)

Instantiation option 2: from nested dictionaries or JSON files#

Alternatively, we can also instantiate the configurations from nested dictionaries or JSON files. This can be useful when loading configurations from external sources. Suppose we have the following JSON configuration for our evaluation. We can load it as a nested dictionary using the json module:

import json

JSON_CFG = """{
    "trainer": {
        "model": {
            "model_name_": "EEGNet",
            "n_times": 1000,
            "n_chans": 26,
            "n_outputs": 4
        },
        "train_dataset": {
            "dataset_type": "split",
            "dataset": {"subject_id": 1},
            "key": "0train"
        }
    },
    "test_dataset": {
        "dataset_type": "split",
        "dataset": {"subject_id": 1},
        "key": "1test"
    }
}"""
NESTED_DICT_CFG = json.loads(JSON_CFG)
print(NESTED_DICT_CFG)
{'trainer': {'model': {'model_name_': 'EEGNet', 'n_times': 1000, 'n_chans': 26, 'n_outputs': 4}, 'train_dataset': {'dataset_type': 'split', 'dataset': {'subject_id': 1}, 'key': '0train'}}, 'test_dataset': {'dataset_type': 'split', 'dataset': {'subject_id': 1}, 'key': '1test'}}

We can instantiate the evaluation configuration from the nested dictionary using the model_validate() method of Pydantic, and check that it is identical to the one we created using the class constructors:

eval_cfg_from_dict = EvaluationConfig.model_validate(NESTED_DICT_CFG)
assert eval_cfg_from_dict == eval_cfg

Serializing the experiment configuration#

To serialize the experiment’s configuration, we can take advantage of Exca’s config() method, which is similar to Pydantic’s model_dump() method but will ensure that an experiment has a unique identifier (UID). In particular, it will also include the "model_name_" field, which will allow us to distinguish between different model configurations later on.

print(eval_cfg.infra.config(uid=True, exclude_defaults=True))
{'test_dataset': {'dataset': {'subject_id': 1}, 'key': '1test', 'dataset_type': 'split'}, 'trainer': {'model': {'n_chans': 26, 'n_outputs': 4, 'n_times': 1000, 'model_name_': 'EEGNet'}, 'train_dataset': {'dataset': {'subject_id': 1}, 'key': '0train', 'dataset_type': 'split'}}}
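
If you also need the exact string Exca uses to key the cache, the infra exposes it as well. This is a sketch assuming an infra.uid() method; check the Exca documentation for the exact API:

# Hypothetical: print the cache identifier of this experiment (method name assumed).
print(eval_cfg.infra.uid())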

Running the experiment#

Intermediate results are cached thanks to Exca#

We can now run the training using the configurations we defined. For this, we simply have to call the train() method of the training configuration. We will time the execution to see the benefits of caching.

import time

t0 = time.time()
train_cfg.train()
t1 = time.time()

print(f"Training took {t1 - t0:0.2f} seconds")
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2522        1.3885       0.2414            0.2414        1.3864  0.7054
      2            0.2522        1.3822       0.2414            0.2414        1.3864  0.6864
      3            0.2522        1.3804       0.2414            0.2414        1.3864  0.6839
Stopping since valid_loss has not improved in the last 3 epochs.
Training took 4.95 seconds

If we call the train() method again with the same configuration parameters, the results will be loaded from the cache, even if the config is a new instance:

train_cfg = TrainingConfig(
    model=EEGNetConfig(**signal_kwargs), train_dataset=train_dataset_cfg
)

t0 = time.time()
train_cfg.train()
t1 = time.time()

print(f"Rerunning training using cached results took {t1 - t0:0.4f} seconds")
Rerunning training using cached results took 0.1795 seconds

We can run the evaluation in the same way, by calling the evaluate() method of the evaluation configuration. Internally, this method calls the train() method of the training configuration, which will also use the cache if available.

t0 = time.time()
score = eval_cfg.evaluate()
t1 = time.time()

print(f"Evaluation score: {score}")
print(f"Evaluation took {t1 - t0:0.2f} seconds")
Evaluation score: 0.25
Evaluation took 0.42 seconds

Scaling up: comparing multiple model configurations#

Now that we have seen how to define and run an experiment using Pydantic and Exca, we can easily scale up to compare multiple model configurations.

First, let’s define a small utility function to flatten nested dictionaries. This will help us later when we want to log results from different configurations. As shown in the example below, the keys of the different levels are concatenated with a dot (".") separator.

def flatten_nested_dict(d, leaf_types=(int, float, str, bool), sep="."):
    def aux(d, parent_key):
        out = {}
        for k, v in d.items():
            if isinstance(v, dict):
                out.update(aux(v, parent_key + k + sep))
            elif isinstance(v, leaf_types):
                out[parent_key + k] = v
        return out

    return aux(d, "")


flatten_nested_dict({"a": 1, "b": {"x": 1, "y": {"z": 2}}, "c": [4, 5]})
{'a': 1, 'b.x': 1, 'b.y.z': 2}

In a real experiment, we would launch all runs in parallel on different nodes of a compute cluster. Please refer to the Exca documentation for more details on how to set up cluster execution. Here, for simplicity, we will just run them locally and sequentially.

In this mini-example, we will compare the EEGNet and EEGConformer models on the same dataset, with multiple random seeds.

model_cfg_list = [
    EEGNetConfig(**signal_kwargs),
    EEGConformerConfig(**signal_kwargs),
]

results = []
for model_cfg in model_cfg_list:
    for seed in [1, 2, 3]:
        train_cfg = TrainingConfig(
            model=model_cfg,
            train_dataset=train_dataset_cfg,
            max_epochs=10,
            lr=0.1,
            seed=seed,
        )
        eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)

        # log configuration
        row = flatten_nested_dict(
            eval_cfg.infra.config(uid=True, exclude_defaults=True)
        )
        # evaluate and log accuracy:
        row["accuracy"] = eval_cfg.evaluate()
        results.append(row)
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2478        3.1830       0.2586            0.2586        2.7699  0.8317
      2            0.2522        3.1265       0.2414            0.2414        5.9990  0.6849
      3            0.2522        2.5580       0.2414            0.2414        3.3327  0.6863
Stopping since valid_loss has not improved in the last 3 epochs.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2478        3.4241       0.2586            0.2586       10.4719  0.6863
      2            0.2522        3.2196       0.2414            0.2414       43.7188  0.6871
      3            0.2478        2.2397       0.2586            0.2586       26.4258  0.6866
Stopping since valid_loss has not improved in the last 3 epochs.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2478        3.4435       0.2586            0.2586        4.2782  0.6886
      2            0.2522        2.5342       0.2414            0.2414        1.9056  0.6866
      3            0.2478        1.9421       0.2586            0.2586        1.6714  0.6880
      4            0.2478        1.6871       0.2586            0.2586        2.0010  0.6816
      5            0.2478        1.3774       0.2586            0.2586        2.9307  0.6895
Stopping since valid_loss has not improved in the last 3 epochs.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2522     7837.7184       0.2414            0.2414    33044.5386  3.1137
      2            0.2522     4397.0301       0.2414            0.2414   156564.7209  3.0734
      3            0.2478      163.2298       0.2586            0.2586        2.7673  3.0684
      4            0.2478      143.6231       0.2586            0.2586        2.1943  3.0576
      5            0.2478        2.5060       0.2586            0.2586        1.7240  3.1515
      6            0.2522        2.1955       0.2414            0.2414        1.4974  3.1057
      7            0.2478      458.7910       0.2586            0.2586        1.4635  3.1380
      8            0.2478     2742.4738       0.2586            0.2586        1.4702  3.0953
      9            0.2478     7917.7216       0.2586            0.2586        1.5098  3.1435
Stopping since valid_loss has not improved in the last 3 epochs.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2478    12232.1578       0.2586            0.2586    13802.9802  3.0060
      2            0.2522     2139.0917       0.2414            0.2414        4.9605  3.0351
      3            0.2522        4.9978       0.2414            0.2414        4.8974  2.9969
      4            0.2522        4.9193       0.2414            0.2414        4.5607  3.0109
      5            0.2522        4.4134       0.2414            0.2414        4.1670  3.0273
      6            0.2522        4.5307       0.2414            0.2414        3.8268  3.0282
      7            0.2522        4.3506       0.2414            0.2414        3.5503  3.0150
      8            0.2522        3.8486       0.2414            0.2414        3.2875  3.0109
      9            0.2522        3.8191       0.2414            0.2414        3.0386  3.0120
     10            0.2522      135.1870       0.2414            0.2414        2.7984  3.0064
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.2522     6004.3131       0.2414            0.2414    16271.5028  2.9908
      2            0.2478     7747.1095       0.2586            0.2586        3.6436  3.0032
      3            0.2478        3.7025       0.2586            0.2586        3.1256  3.0013
      4            0.2522        3.7168       0.2414            0.2414        2.9819  3.0362
      5            0.2522        3.7040       0.2414            0.2414        2.8630  3.0200
      6            0.2478        3.4228       0.2586            0.2586        2.7126  3.0163
      7            0.2478        3.3238       0.2586            0.2586        2.4859  3.0389
      8            0.2478        3.0128       0.2586            0.2586        2.2308  3.0360
      9            0.2478        3.0694       0.2586            0.2586        1.9870  3.0158
     10            0.2478        2.7422       0.2586            0.2586        1.7605  3.0184

Gathering and displaying the results#

Loading results from cache#

If the experiments were run on a cluster, a likely scenario would be to first launch all of them, and only later load and analyze the results.

Loading the results from cache is straightforward using Exca. We simply need to re-instantiate the configurations with the same parameters, and call the evaluate() method again. The cached results will be loaded in a few seconds instead of re-running the experiments:

del results  # oops, we forgot the results...

t0 = time.time()
results = []
for model_cfg in model_cfg_list:
    for seed in [1, 2, 3]:
        train_cfg = TrainingConfig(
            model=model_cfg,
            train_dataset=train_dataset_cfg,
            max_epochs=10,
            lr=0.1,
            seed=seed,
        )
        eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)

        # log configuration
        row = flatten_nested_dict(
            eval_cfg.infra.config(uid=True, exclude_defaults=True)
        )
        # evaluate and log accuracy:
        row["accuracy"] = eval_cfg.evaluate()
        results.append(row)
t1 = time.time()

print(f"Loading all results from cache took {t1 - t0:0.2f} seconds")
Loading all results from cache took 1.46 seconds

Displaying the results#

Finally, we can gather the results into a pandas DataFrame and display them:

import pandas as pd

results_df = pd.DataFrame(results)
print(results_df)
   test_dataset.dataset.subject_id test_dataset.key  ... trainer.seed  accuracy
0                                1            1test  ...            1      0.25
1                                1            1test  ...            2      0.25
2                                1            1test  ...            3      0.25
3                                1            1test  ...            1      0.25
4                                1            1test  ...            2      0.25
5                                1            1test  ...            3      0.25

[6 rows x 14 columns]

Or first aggregated over seeds:

agg_results_df = results_df.groupby("trainer.model.model_name_").agg(
    {"accuracy": ["mean", "std"]}
)
print(agg_results_df)
                          accuracy
                              mean  std
trainer.model.model_name_
EEGConformer                  0.25  0.0
EEGNet                        0.25  0.0

Total running time of the script: (2 minutes 25.374 seconds)

Estimated memory usage: 2389 MB
