Experiment configuration with Pydantic and Exca#
This example shows how to use the pydantic and exca libraries
to configure and run EEG experiments with Braindecode.
Pydantic is a library for data validation and settings management using Python type annotations. It allows defining structured configurations that can be validated and serialized easily.
Exca builds on top of Pydantic, and allows you to seamlessly EXecute experiments and CAche their results.
Braindecode implements a Pydantic configuration for each of its models in
braindecode.models.config.
In this example, we will use these configurations to define an experiment that
trains and evaluates different models on a motor-imagery dataset using Exca.
# Authors: Pierre Guetschel
#
# License: BSD (3-clause)
Creating the experiment configurations#
We will start by defining the configurations needed for our experiment using Pydantic and Exca.
Dataset configs#
Our first configuration class is related to the data. It will allow us to load and prepare the dataset.
import warnings
from typing import Annotated, Literal
import exca
import pydantic
from moabb.datasets.utils import dataset_list
from braindecode import EEGClassifier
from braindecode.datasets import BaseConcatDataset, MOABBDataset
from braindecode.preprocessing import create_windows_from_events
warnings.simplefilter("ignore")
# The list of available MOABB datasets:
DATASET_NAMES = tuple(ds.__name__ for ds in dataset_list)
class WindowedMOABBDatasetConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")

    dataset_type: Literal["moabb"] = "moabb"
    infra: exca.TaskInfra = exca.TaskInfra(
        folder=None,  # no disk caching
        cluster=None,  # local execution
        keep_in_ram=True,
    )
    dataset_name: Literal[DATASET_NAMES] = "BNCI2014_001"
    subject_id: list[int] | int | None = None
    window_size_seconds: float = 4.0
    overlap_seconds: float = 0.0

    @infra.apply
    def create_instance(self) -> BaseConcatDataset:
        # We don't apply any preprocessing here for simplicity, but in a real
        # experiment, you would typically want to filter the data, resample it,
        # etc. Instead, our config directly extracts windows from the raw data.
        dataset = MOABBDataset(
            dataset_name=self.dataset_name, subject_ids=self.subject_id
        )
        windows_dataset = create_windows_from_events(dataset, preload=True)
        return windows_dataset
We can see that the config has an infra: exca.TaskInfra attribute and a method decorated with @infra.apply. This means that Exca will cache the result of this method when it is called. Here, the cache is only kept in RAM for simplicity (folder=None disables disk caching), but in a real experiment you would typically want to cache the results on disk, as shown in the training config below. If the method is called again with the same configuration, the cached result is returned instead of re-running the method. This allows for easy and efficient experimentation.
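For instance, creating the dataset twice from the same config only does the work once; a minimal sketch (the first call downloads the MOABB data):
cfg = WindowedMOABBDatasetConfig(subject_id=1)
ds_first = cfg.create_instance()  # runs the method and stores the result in RAM
ds_second = cfg.create_instance()  # returned from the RAM cache, near-instantly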
Additionally, we define a small wrapper config to split the dataset into training and testing sets. Here, no caching is applied since the split operation is fast.
class DatasetSplitConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")

    dataset_type: Literal["split"] = "split"
    dataset: WindowedMOABBDatasetConfig
    key: str
    by: str = "session"

    def create_instance(self):
        dataset = self.dataset.create_instance()
        splits = dataset.split(self.by)
        return splits[self.key]
Finally, we define a union type for dataset configurations,
which can be either a WindowedMOABBDatasetConfig or a DatasetSplitConfig.
DatasetConfig = Annotated[
    WindowedMOABBDatasetConfig | DatasetSplitConfig,
    pydantic.Field(discriminator="dataset_type"),
]
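As a quick sketch of what the discriminator does, validating a plain dictionary against this union resolves to the right class based on its dataset_type key (here via pydantic's TypeAdapter):
adapter = pydantic.TypeAdapter(DatasetConfig)
cfg = adapter.validate_python({"dataset_type": "moabb", "subject_id": 1})
print(type(cfg).__name__)  # WindowedMOABBDatasetConfig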
Training config#
Now that our data configs are ready, we can define our training config. It requires both the dataset and model configurations: it loads the data, instantiates the model, and trains it on the data.
from skorch.callbacks import EarlyStopping
from skorch.dataset import ValidSplit
from torch.optim import Adam
from braindecode.models.config import BraindecodeModelConfig
class TrainingConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")

    infra: exca.TaskInfra = exca.TaskInfra(
        folder=".cache/",
        cluster=None,  # local execution
    )
    model: BraindecodeModelConfig
    train_dataset: DatasetConfig
    max_epochs: int = 50
    batch_size: int = 32
    lr: float = 0.001
    seed: int = 12

    @infra.apply
    def train(self) -> dict:
        # Load training data
        train_set = self.train_dataset.create_instance()
        train_y = train_set.get_metadata()["target"].to_numpy()
        # Instantiate the model
        model = self.model.create_instance()
        clf = EEGClassifier(
            model,
            max_epochs=self.max_epochs,
            batch_size=self.batch_size,
            lr=self.lr,
            train_split=ValidSplit(0.2, random_state=self.seed, stratified=True),
            callbacks=["accuracy", EarlyStopping(patience=3)],
            optimizer=Adam,
        )
        # Train the model
        clf.fit(train_set, train_y)
        # Return the trained weights only, as they are lighter to cache than
        # the full classifier object.
        return clf.module_.state_dict()
We note that the model field has type braindecode.models.config.BraindecodeModelConfig. This type matches any of the braindecode model configurations defined in braindecode.models.config.
We also see that there is now a cache folder specified (.cache/ here). This means that the results of the train() method will be cached on disk in this folder, instead of only in RAM.
Evaluation config#
Finally, we define an evaluation config that loads the test data, retrieves the trained model weights from the training config, and evaluates the model on the test data.
class EvaluationConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="forbid")

    infra: exca.TaskInfra = exca.TaskInfra(
        folder=".cache/",
        cluster=None,  # local execution
    )
    test_dataset: DatasetConfig
    trainer: TrainingConfig

    @infra.apply
    def evaluate(self) -> float:
        # Load test data
        test_set = self.test_dataset.create_instance()
        test_y = test_set.get_metadata()["target"].to_numpy()
        # Load the trained weights (from cache if available) and rebuild the model
        state_dict = self.trainer.train()
        model = self.trainer.model.create_instance()
        model.load_state_dict(state_dict)
        clf = EEGClassifier(model)
        clf.initialize()
        # Evaluate the model
        score = clf.score(test_set, test_y)
        return score
Note
SLURM execution.
Exca also offers the possibility to run experiments remotely on a SLURM-managed cluster.
In this example, we run everything locally by setting cluster=None
but you can find more information about how to set up cluster execution
in the Exca documentation: https://facebookresearch.github.io/exca/infra/introduction.html.
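For reference, a hedged sketch of what switching the execution backend might look like (untested here; the SLURM resource options such as partition, GPUs, and timeout are described in the Exca documentation):
slurm_infra = exca.TaskInfra(
    folder=".cache/",  # results are cached on disk and shared across jobs
    cluster="slurm",  # submit each task as a SLURM job instead of running locally
)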
Instantiating the configurations#
Instantiation option 1: from class constructors#
Now that our configuration classes are defined, we can instantiate them.
We will start with the model configuration.
Here, we use the braindecode.models.EEGNet model.
Like any other braindecode model, it has a corresponding configuration class in braindecode.models.config, called braindecode.models.config.EEGNetConfig.
We instantiate it using the signal properties of our dataset: the number of time samples, channels, and output classes.
from braindecode.models.config import EEGConformerConfig, EEGNetConfig
signal_kwargs = {"n_times": 1000, "n_chans": 26, "n_outputs": 4}
model_cfg = EEGNetConfig(**signal_kwargs)
The config object can easily be serialized to a JSON format:
print(model_cfg.model_dump(mode="json"))
{'model_name_': 'EEGNet', 'n_chans': 26, 'n_outputs': 4, 'n_times': 1000, 'final_conv_length': 'auto', 'pool_mode': 'mean', 'F1': 8, 'D': 2, 'F2': None, 'kernel_length': 64, 'depthwise_kernel_length': 16, 'pool1_kernel_size': 4, 'pool2_kernel_size': 8, 'conv_spatial_max_norm': 1, 'activation': 'torch.nn.modules.activation.ELU', 'batch_norm_momentum': 0.01, 'batch_norm_affine': True, 'batch_norm_eps': 0.001, 'drop_prob': 0.25, 'final_layer_with_constraint': False, 'norm_rate': 0.25, 'chs_info': None, 'input_window_seconds': None, 'sfreq': None}
Alternatively, if you only want the non-default keys:
print(model_cfg.model_dump(exclude_defaults=True))
{'n_chans': 26, 'n_outputs': 4, 'n_times': 1000}
The config class checks the argument types and values and raises an error if something is wrong. For example, if we try to instantiate it with an incorrect type for n_times, we get an error:
# kept for restoration later:
true_n_times = signal_kwargs["n_times"]
# float instead of int:
signal_kwargs["n_times"] = 22.5
try:
    EEGNetConfig(**signal_kwargs)
except pydantic.ValidationError as e:
    print(f"Validation error raised as expected:\n{e}")
Validation error raised as expected:
1 validation error for EEGNetConfig
n_times
Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=22.5, input_type=float]
For further information visit https://errors.pydantic.dev/2.12/v/int_from_float
Similarly, if a mandatory argument is missing, we get an error:
del signal_kwargs["n_times"]
try:
    EEGNetConfig(**signal_kwargs)
except pydantic.ValidationError as e:
    print(f"Validation error raised as expected:\n{e}")

# We restore the correct value for ``n_times`` for the rest of the example:
signal_kwargs["n_times"] = true_n_times
Validation error raised as expected:
1 validation error for EEGNetConfig
Value error, n_times is required and could not be inferred.Either specify n_times or input_window_seconds and sfreq. [type=value_error, input_value={'n_chans': 26, 'n_outputs': 4}, input_type=dict]
For further information visit https://errors.pydantic.dev/2.12/v/value_error
We have now instantiated the model configuration. Creating the dataset, training, and evaluation configurations is equally straightforward using the classes we defined earlier.
dataset_cfg = WindowedMOABBDatasetConfig(subject_id=1)
train_dataset_cfg = DatasetSplitConfig(dataset=dataset_cfg, key="0train")
test_dataset_cfg = DatasetSplitConfig(dataset=dataset_cfg, key="1test")
train_cfg = TrainingConfig(model=model_cfg, train_dataset=train_dataset_cfg)
eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)
Instantiation option 2: from nested dictionaries or JSON files#
Alternatively, we can also instantiate the configurations from nested dictionaries or JSON files.
This can be useful when loading configurations from external sources.
Suppose we have the following JSON configuration for our evaluation.
We can load it as a nested dictionary using the json module:
import json
JSON_CFG = """{
    "trainer": {
        "model": {
            "model_name_": "EEGNet",
            "n_times": 1000,
            "n_chans": 26,
            "n_outputs": 4
        },
        "train_dataset": {
            "dataset_type": "split",
            "dataset": {"subject_id": 1},
            "key": "0train"
        }
    },
    "test_dataset": {
        "dataset_type": "split",
        "dataset": {"subject_id": 1},
        "key": "1test"
    }
}"""
NESTED_DICT_CFG = json.loads(JSON_CFG)
print(NESTED_DICT_CFG)
{'trainer': {'model': {'model_name_': 'EEGNet', 'n_times': 1000, 'n_chans': 26, 'n_outputs': 4}, 'train_dataset': {'dataset_type': 'split', 'dataset': {'subject_id': 1}, 'key': '0train'}}, 'test_dataset': {'dataset_type': 'split', 'dataset': {'subject_id': 1}, 'key': '1test'}}
We can instantiate the evaluation configuration from the nested dictionary
using the model_validate() method of Pydantic,
and check that it is identical to the one we created using the class constructors:
eval_cfg_from_dict = EvaluationConfig.model_validate(NESTED_DICT_CFG)
assert eval_cfg_from_dict == eval_cfg
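Incidentally, the intermediate json.loads step is optional; Pydantic can parse the JSON string directly:
eval_cfg_from_json = EvaluationConfig.model_validate_json(JSON_CFG)
assert eval_cfg_from_json == eval_cfg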
Serializing the experiment configuration#
To serialize the experiment’s configuration, we can take advantage of Exca’s config() method, which is similar to Pydantic’s model_dump() method but will ensure that an experiment has a unique identifier (UID).
In particular, it will also include the "model_name_" field, which will allow us to distinguish between different model configurations later on.
print(eval_cfg.infra.config(uid=True, exclude_defaults=True))
{'test_dataset': {'dataset': {'subject_id': 1}, 'key': '1test', 'dataset_type': 'split'}, 'trainer': {'model': {'n_chans': 26, 'n_outputs': 4, 'n_times': 1000, 'model_name_': 'EEGNet'}, 'train_dataset': {'dataset': {'subject_id': 1}, 'key': '0train', 'dataset_type': 'split'}}}
Running the experiment#
Intermediate results are cached thanks to Exca#
We can now run the training using the configurations we defined.
For this, we simply have to call the train() method of the configuration.
We will time the execution to see the benefits of caching.
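A minimal sketch of this timed call:
import time

t0 = time.time()
state_dict = train_cfg.train()
t1 = time.time()
print(f"Training took {t1 - t0:0.2f} seconds")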
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2522 1.3885 0.2414 0.2414 1.3864 0.7054
2 0.2522 1.3822 0.2414 0.2414 1.3864 0.6864
3 0.2522 1.3804 0.2414 0.2414 1.3864 0.6839
Stopping since valid_loss has not improved in the last 3 epochs.
Training took 4.95 seconds
If we call the train() method again with the same configuration parameters, even on a brand-new config instance, the results are loaded from the cache:
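A sketch of this second call, using a freshly created instance (new_train_cfg is an illustrative name):
t0 = time.time()
new_train_cfg = TrainingConfig(model=model_cfg, train_dataset=train_dataset_cfg)
state_dict = new_train_cfg.train()
t1 = time.time()
print(f"Rerunning training using cached results took {t1 - t0:0.4f} seconds")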
Rerunning training using cached results took 0.1795 seconds
We can run the evaluation in the same way, by calling the evaluate() method of the evaluation configuration.
Internally, this method calls the train() method of the training configuration, which will also use the cache if available.
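A sketch of the evaluation call:
t0 = time.time()
score = eval_cfg.evaluate()
t1 = time.time()
print(f"Evaluation score: {score}")
print(f"Evaluation took {t1 - t0:0.2f} seconds")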
Evaluation score: 0.25
Evaluation took 0.42 seconds
Scaling up: comparing multiple model configurations#
Now that we have seen how to define and run an experiment using Pydantic and Exca, we can easily scale up to compare multiple model configurations.
First, let’s define a small utility function to flatten nested dictionaries. This will help us later when we want to log results from different configurations. As shown in the example below, keys from different levels are concatenated with a dot (".") separator.
def flatten_nested_dict(d, leaf_types=(int, float, str, bool), sep="."):
    def aux(d, parent_key):
        out = {}
        for k, v in d.items():
            if isinstance(v, dict):
                out.update(aux(v, parent_key + k + sep))
            elif isinstance(v, leaf_types):
                out[parent_key + k] = v
        return out

    return aux(d, "")
flatten_nested_dict({"a": 1, "b": {"x": 1, "y": {"z": 2}}, "c": [4, 5]})
{'a': 1, 'b.x': 1, 'b.y.z': 2}
In a real experiment, we would launch all runs in parallel on different nodes of a compute cluster. Please refer to the Exca documentation for more details on how to set up cluster execution. Here, for simplicity, we will just run them locally and sequentially.
In this mini-example, we will compare the EEGNet and EEGConformer models on the same dataset, with multiple random seeds.
model_cfg_list = [
    EEGNetConfig(**signal_kwargs),
    EEGConformerConfig(**signal_kwargs),
]

results = []
for model_cfg in model_cfg_list:
    for seed in [1, 2, 3]:
        train_cfg = TrainingConfig(
            model=model_cfg,
            train_dataset=train_dataset_cfg,
            max_epochs=10,
            lr=0.1,
            seed=seed,
        )
        eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)
        # log configuration
        row = flatten_nested_dict(
            eval_cfg.infra.config(uid=True, exclude_defaults=True)
        )
        # evaluate and log accuracy:
        row["accuracy"] = eval_cfg.evaluate()
        results.append(row)
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2478 3.1830 0.2586 0.2586 2.7699 0.8317
2 0.2522 3.1265 0.2414 0.2414 5.9990 0.6849
3 0.2522 2.5580 0.2414 0.2414 3.3327 0.6863
Stopping since valid_loss has not improved in the last 3 epochs.
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2478 3.4241 0.2586 0.2586 10.4719 0.6863
2 0.2522 3.2196 0.2414 0.2414 43.7188 0.6871
3 0.2478 2.2397 0.2586 0.2586 26.4258 0.6866
Stopping since valid_loss has not improved in the last 3 epochs.
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2478 3.4435 0.2586 0.2586 4.2782 0.6886
2 0.2522 2.5342 0.2414 0.2414 1.9056 0.6866
3 0.2478 1.9421 0.2586 0.2586 1.6714 0.6880
4 0.2478 1.6871 0.2586 0.2586 2.0010 0.6816
5 0.2478 1.3774 0.2586 0.2586 2.9307 0.6895
Stopping since valid_loss has not improved in the last 3 epochs.
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2522 7837.7184 0.2414 0.2414 33044.5386 3.1137
2 0.2522 4397.0301 0.2414 0.2414 156564.7209 3.0734
3 0.2478 163.2298 0.2586 0.2586 2.7673 3.0684
4 0.2478 143.6231 0.2586 0.2586 2.1943 3.0576
5 0.2478 2.5060 0.2586 0.2586 1.7240 3.1515
6 0.2522 2.1955 0.2414 0.2414 1.4974 3.1057
7 0.2478 458.7910 0.2586 0.2586 1.4635 3.1380
8 0.2478 2742.4738 0.2586 0.2586 1.4702 3.0953
9 0.2478 7917.7216 0.2586 0.2586 1.5098 3.1435
Stopping since valid_loss has not improved in the last 3 epochs.
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2478 12232.1578 0.2586 0.2586 13802.9802 3.0060
2 0.2522 2139.0917 0.2414 0.2414 4.9605 3.0351
3 0.2522 4.9978 0.2414 0.2414 4.8974 2.9969
4 0.2522 4.9193 0.2414 0.2414 4.5607 3.0109
5 0.2522 4.4134 0.2414 0.2414 4.1670 3.0273
6 0.2522 4.5307 0.2414 0.2414 3.8268 3.0282
7 0.2522 4.3506 0.2414 0.2414 3.5503 3.0150
8 0.2522 3.8486 0.2414 0.2414 3.2875 3.0109
9 0.2522 3.8191 0.2414 0.2414 3.0386 3.0120
10 0.2522 135.1870 0.2414 0.2414 2.7984 3.0064
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.2522 6004.3131 0.2414 0.2414 16271.5028 2.9908
2 0.2478 7747.1095 0.2586 0.2586 3.6436 3.0032
3 0.2478 3.7025 0.2586 0.2586 3.1256 3.0013
4 0.2522 3.7168 0.2414 0.2414 2.9819 3.0362
5 0.2522 3.7040 0.2414 0.2414 2.8630 3.0200
6 0.2478 3.4228 0.2586 0.2586 2.7126 3.0163
7 0.2478 3.3238 0.2586 0.2586 2.4859 3.0389
8 0.2478 3.0128 0.2586 0.2586 2.2308 3.0360
9 0.2478 3.0694 0.2586 0.2586 1.9870 3.0158
10 0.2478 2.7422 0.2586 0.2586 1.7605 3.0184
Gathering and displaying the results#
Loading results from cache#
If experiments were done on a cluster, a likely scenario would be to first run all experiments, and then later load and analyze the results.
Loading the results from cache is straightforward using Exca.
We simply need to re-instantiate the configurations with the same parameters,
and call the evaluate() method again.
The cached results will be loaded in a few seconds instead of re-running the experiments:
del results  # oops, we forgot the results...

t0 = time.time()
results = []
for model_cfg in model_cfg_list:
    for seed in [1, 2, 3]:
        train_cfg = TrainingConfig(
            model=model_cfg,
            train_dataset=train_dataset_cfg,
            max_epochs=10,
            lr=0.1,
            seed=seed,
        )
        eval_cfg = EvaluationConfig(trainer=train_cfg, test_dataset=test_dataset_cfg)
        # log configuration
        row = flatten_nested_dict(
            eval_cfg.infra.config(uid=True, exclude_defaults=True)
        )
        # evaluate and log accuracy:
        row["accuracy"] = eval_cfg.evaluate()
        results.append(row)
t1 = time.time()
print(f"Loading all results from cache took {t1 - t0:0.2f} seconds")
Loading all results from cache took 1.46 seconds
Displaying the results#
Finally, we can concatenate and display the results using pandas:
import pandas as pd
results_df = pd.DataFrame(results)
print(results_df)
test_dataset.dataset.subject_id test_dataset.key ... trainer.seed accuracy
0 1 1test ... 1 0.25
1 1 1test ... 2 0.25
2 1 1test ... 3 0.25
3 1 1test ... 1 0.25
4 1 1test ... 2 0.25
5 1 1test ... 3 0.25
[6 rows x 14 columns]
Or first aggregated over seeds:
agg_results_df = results_df.groupby("trainer.model.model_name_").agg(
    {"accuracy": ["mean", "std"]}
)
print(agg_results_df)
                          accuracy
                              mean  std
trainer.model.model_name_
EEGConformer                  0.25  0.0
EEGNet                        0.25  0.0
Total running time of the script: (2 minutes 25.374 seconds)
Estimated memory usage: 2389 MB