Cleaning EEG Data with EEGPrep for Trialwise Decoding#

This is a variant of the basic Trialwise Decoding tutorial that additionally inserts an EEGPrep stage into the preprocessing pipeline, as a minimal demonstration of how to use EEGPrep with Braindecode.

Loading and preparing the data#

Loading the dataset#

First, we load the data. In this tutorial, we load the BCI Competition IV 2a data [1] using braindecode's wrapper around the MOABB library [2].

Note

To load your own datasets either via mne or from preprocessed X/y numpy arrays, see MNE Dataset Example and Custom Dataset Example.
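If you want a feel for the array-based route before reading those examples, the expected trial layout is simple. The following sketch only illustrates shape conventions (it is not a braindecode API call; all names here are local variables):

```python
import numpy as np

# Illustrative shape conventions for array-based loading: trials are laid out
# as (n_trials, n_channels, n_times), with one integer class label per trial.
n_trials, n_channels, n_times = 48, 22, 1000
rng = np.random.default_rng(0)
X = rng.standard_normal((n_trials, n_channels, n_times)).astype("float32")
y = rng.integers(0, 4, size=n_trials)  # e.g. four motor-imagery classes
```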

from braindecode.datasets import MOABBDataset

subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

Preprocessing#

Now we apply a series of preprocessing steps to our dataset.

The conventional approach in deep learning is to keep preprocessing minimal and leave it to the model to learn relevant features, as done in the seminal early use of deep learning on EEG in [3] and many subsequent works.

However, EEG can contain dramatic artifacts that easily dwarf the signal of interest and may harm learning or throw off predictions, so additional artifact removal steps can be beneficial in conjunction with deep models. The following code starts from the minimal preprocessing pipeline of the Trialwise Decoding tutorial and inserts the EEGPrep Preprocessor into it. This is an integration with the eegprep preprocessing library, which implements a series of automated artifact removal steps first proposed in [4] and later refined as part of the (now-default) raw-data preprocessing approach in EEGLAB [5].

The EEGPrep class represents the default end-to-end preprocessing pipeline, which has only a few primary parameters worth tuning for a given dataset; the most important of these are shown in the code below.

Besides using the end-to-end pipeline as a whole, users can also separately invoke the individual preprocessing steps implemented in EEGPrep as needed; for additional details see the documentation for EEGPrep.

Note

EEGPrep is best used early in the preprocessing pipeline, when you are still acting on continuous (raw) data. The nature of the data after processing is essentially the same as the input (minus many of the artifacts), so you can typically retain most other processing steps that your pipeline would otherwise use, as below.

from numpy import multiply

from braindecode.preprocessing import (
    EEGPrep,
    Preprocessor,
    exponential_moving_standardize,
    preprocess,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
    # If you have non-EEG channels in the data that you do not want to keep,
    # it is best to remove them early on, which is more memory-efficient.
    # EEGPrep generally only acts on the EEG channels.
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),
    # This particular dataset requires a conversion from V to uV; this
    # could also be done later in the pipeline since EEGPrep does not
    # care about absolute scaling
    Preprocessor(lambda data: multiply(data, factor)),
    # Here we insert the EEGPrep preprocessing step; experiment with commenting
    # this out to see how it affects results. You can also disable additional
    # processing steps in the pipeline by setting select parameters to None.
    EEGPrep(
        resample_to=128,
        # This is best disabled for single-trial classification (see EEGPrep docs)
        bad_window_max_bad_channels=None,
        # The following examples show some other frequently used non-default values:
        # burst_removal_cutoff=15.0,       # 15.0 -> less aggressive burst removal
        # bad_channel_corr_threshold=0.75, # 0.75 -> less aggressive channel removal
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, preprocessors, n_jobs=-1)
BaseConcatDataset of RawDataset
    Recordings:     12
    Total samples:  594348
    Sfreq*:         128.0 Hz
    Channels*:      22 (22 EEG)
    Ch. names*:     Fz, FC3, FC1, FCz, FC2, FC4, C5, C3, C1, Cz, ... (+12 more)
    Montage*:       head
    Duration*:      386.9 s
    Description:    12 recordings × 3 columns [subject, session, run]
    * from first recording
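The exponential moving standardization applied above can be sketched in plain NumPy. This is a simplified illustration of the recursive mean/variance update and the initial-block handling, not braindecode's exact implementation:

```python
import numpy as np


def ema_standardize(x, factor_new=1e-3, init_block_size=1000, eps=1e-4):
    """Simplified per-channel exponential moving standardization.

    x has shape (n_channels, n_times). Mirrors the idea behind braindecode's
    exponential_moving_standardize, but is not its exact code.
    """
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    mean = x[:, 0].copy()
    var = np.zeros(x.shape[0])
    for t in range(x.shape[1]):
        # Recursive running estimates: recent samples dominate via factor_new.
        mean = factor_new * x[:, t] + (1 - factor_new) * mean
        var = factor_new * (x[:, t] - mean) ** 2 + (1 - factor_new) * var
        out[:, t] = (x[:, t] - mean) / np.maximum(np.sqrt(var), eps)
    # Re-standardize the initial block with its own statistics, since the
    # running estimates are unreliable at the start of the recording.
    block = x[:, :init_block_size]
    if block.shape[1] > 1:
        mu = block.mean(axis=1, keepdims=True)
        sd = np.maximum(block.std(axis=1, keepdims=True), eps)
        out[:, :init_block_size] = (block - mu) / sd
    return out


rng = np.random.default_rng(0)
sig = rng.normal(loc=5.0, scale=2.0, size=(2, 4000))
std_sig = ema_standardize(sig, init_block_size=1000)
```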



Note

When using individual artifact removal steps, make sure they are applied in the intended order, since otherwise you may get suboptimal results.

Extracting Compute Windows#

Now we extract compute windows from the signals; these will be the inputs to the deep networks during training. In the case of trialwise decoding, we just have to decide whether we want to include some part before and/or after the trial. For our work with this dataset, it was often beneficial to also include the 500 ms before the trial.

from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
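As a quick sanity check on the offset arithmetic, assuming the nominal 4 s imagery period of this dataset and the 128 Hz rate set by EEGPrep's resample_to above, the resulting window length comes out to 576 samples, matching the model input shape shown later:

```python
sfreq = 128.0  # sampling rate after EEGPrep's resample_to=128
trial_len_s = 4.0  # nominal imagery period in BCI Competition IV 2a
trial_start_offset_samples = int(-0.5 * sfreq)  # -64: start 500 ms early
window_len = int(trial_len_s * sfreq) - trial_start_offset_samples  # 576
```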

Splitting the dataset into training and validation sets#

We can easily split the dataset using additional info stored in the description attribute, in this case the session column. We select 0train for training and 1test for validation.

splitted = windows_dataset.split("session")
train_set = splitted["0train"]  # Session train
valid_set = splitted["1test"]  # Session evaluation

Creating a model#

Now we create the deep learning model! Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the EEGNeX model from [6]. These models are pure PyTorch deep learning models, so to use your own model, it just has to be a normal PyTorch torch.nn.Module.

import torch

from braindecode.models import EEGNeX
from braindecode.util import set_random_seeds

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
n_times = train_set[0][0].shape[1]

# EEGNeX is a compact convolutional architecture with a deeper temporal stack
# than EEGNet, but the optimizer settings below are still just a simple
# starting point for this tutorial.
model = EEGNeX(
    n_chans=n_chans,
    n_outputs=n_classes,
    n_times=n_times,
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    model = model.cuda()
================================================================================================================================================================
Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
================================================================================================================================================================
EEGNeX (EEGNeX)                                              [1, 22, 576]              [1, 4]                    --                        --
├─Sequential (block_1): 1-1                                  [1, 22, 576]              [1, 8, 22, 576]           --                        --
│    └─Rearrange (0): 2-1                                    [1, 22, 576]              [1, 1, 22, 576]           --                        --
│    └─Conv2d (1): 2-2                                       [1, 1, 22, 576]           [1, 8, 22, 576]           512                       [1, 64]
│    └─BatchNorm2d (2): 2-3                                  [1, 8, 22, 576]           [1, 8, 22, 576]           16                        --
├─Sequential (block_2): 1-2                                  [1, 8, 22, 576]           [1, 32, 22, 576]          --                        --
│    └─Conv2d (0): 2-4                                       [1, 8, 22, 576]           [1, 32, 22, 576]          16,384                    [1, 64]
│    └─BatchNorm2d (1): 2-5                                  [1, 32, 22, 576]          [1, 32, 22, 576]          64                        --
├─Sequential (block_3): 1-3                                  [1, 32, 22, 576]          [1, 64, 1, 144]           --                        --
│    └─ParametrizedConv2dWithConstraint (0): 2-6             [1, 32, 22, 576]          [1, 64, 1, 576]           --                        [22, 1]
│    │    └─ModuleDict (parametrizations): 3-1               --                        --                        1,408                     --
│    └─BatchNorm2d (1): 2-7                                  [1, 64, 1, 576]           [1, 64, 1, 576]           128                       --
│    └─ELU (2): 2-8                                          [1, 64, 1, 576]           [1, 64, 1, 576]           --                        --
│    └─AvgPool2d (3): 2-9                                    [1, 64, 1, 576]           [1, 64, 1, 144]           --                        [1, 4]
│    └─Dropout (4): 2-10                                     [1, 64, 1, 144]           [1, 64, 1, 144]           --                        --
├─Sequential (block_4): 1-4                                  [1, 64, 1, 144]           [1, 32, 1, 144]           --                        --
│    └─Conv2d (0): 2-11                                      [1, 64, 1, 144]           [1, 32, 1, 144]           32,768                    [1, 16]
│    └─BatchNorm2d (1): 2-12                                 [1, 32, 1, 144]           [1, 32, 1, 144]           64                        --
├─Sequential (block_5): 1-5                                  [1, 32, 1, 144]           [1, 144]                  --                        --
│    └─Conv2d (0): 2-13                                      [1, 32, 1, 144]           [1, 8, 1, 144]            4,096                     [1, 16]
│    └─BatchNorm2d (1): 2-14                                 [1, 8, 1, 144]            [1, 8, 1, 144]            16                        --
│    └─ELU (2): 2-15                                         [1, 8, 1, 144]            [1, 8, 1, 144]            --                        --
│    └─AvgPool2d (3): 2-16                                   [1, 8, 1, 144]            [1, 8, 1, 18]             --                        [1, 8]
│    └─Dropout (4): 2-17                                     [1, 8, 1, 18]             [1, 8, 1, 18]             --                        --
│    └─Flatten (5): 2-18                                     [1, 8, 1, 18]             [1, 144]                  --                        --
├─ParametrizedLinearWithConstraint (final_layer): 1-6        [1, 144]                  [1, 4]                    4                         --
│    └─ModuleDict (parametrizations): 2-19                   --                        --                        --                        --
│    │    └─ParametrizationList (weight): 3-2                --                        [4, 144]                  576                       --
================================================================================================================================================================
Total params: 56,036
Trainable params: 56,036
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 219.41
================================================================================================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 8.50
Params size (MB): 0.22
Estimated Total Size (MB): 8.76
================================================================================================================================================================

Model Training#

Now we will train the network! EEGClassifier is a Braindecode object responsible for managing the training of neural networks. It inherits from skorch.classifier.NeuralNetClassifier, so the training logic is the same as in skorch.

Note

In this tutorial, we use some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data.
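The kind of cross-validated grid search the note above recommends can be sketched with NumPy. Here `score_candidate` is a hypothetical placeholder returning a deterministic dummy score; in practice it would fit the classifier on the training folds and return validation accuracy:

```python
import numpy as np


def kfold_indices(n_samples, n_folds=5, seed=0):
    """Shuffle sample indices and split them into n_folds disjoint folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_samples), n_folds)


def score_candidate(lr, weight_decay, train_idx, valid_idx):
    # Placeholder: in a real search, fit the model on the training windows
    # and return validation accuracy. This dummy score peaks at lr=1e-3 so
    # that the sketch runs end to end.
    return 1.0 / (1.0 + abs(np.log10(lr) + 3))


candidates = [(lr, wd) for lr in (1e-4, 1e-3, 1e-2) for wd in (0.0, 1e-4)]
folds = kfold_indices(n_samples=288, n_folds=5)

best, best_score = None, -np.inf
for lr, wd in candidates:
    scores = []
    for k in range(len(folds)):
        valid_idx = folds[k]
        train_idx = np.concatenate([f for i, f in enumerate(folds) if i != k])
        scores.append(score_candidate(lr, wd, train_idx, valid_idx))
    mean_score = float(np.mean(scores))
    if mean_score > best_score:
        best, best_score = (lr, wd), mean_score
```

The winning candidate would then be refit on the full training data before evaluation on the held-out session.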

from skorch.callbacks import EarlyStopping, LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier

# A moderate AdamW learning rate is a reasonable starting point for EEGNeX:
lr = 1e-3
weight_decay = 0

batch_size = 64
n_epochs = 4

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=max(1, n_epochs - 1))),
        ("early_stopping", EarlyStopping(patience=10, load_best=True)),
    ],
    device=device,
    classes=classes,
)
# Model training for the specified number of epochs. ``y`` is ``None`` as it is
# already supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2500        1.3873       0.2500            0.2500        1.3868  0.0010  7.1591
      2            0.2500        1.3542       0.2500            0.2500        1.3861  0.0008  7.1164
      3            0.2812        1.3272       0.2674            0.2674        1.3859  0.0003  7.1249
      4            0.3576        1.3164       0.3090            0.3090        1.3856  0.0000  7.0985
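The lr column printed above follows the closed-form cosine annealing schedule. A minimal sketch of that formula (ignoring skorch/PyTorch implementation details, with eta_min left at its default of 0):

```python
import math

lr_max, T_max = 1e-3, 3  # matches T_max=max(1, n_epochs - 1) with n_epochs=4
# Cosine annealing: lr decays from lr_max to 0 over T_max steps.
lrs = [0.5 * lr_max * (1 + math.cos(math.pi * t / T_max)) for t in range(T_max + 1)]
# roughly [1e-3, 7.5e-4, 2.5e-4, 0.0] -- compare the rounded lr column above
```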


Training for longer#

The gallery build above uses only n_epochs = 4. When trained offline for up to 100 epochs with early stopping, the model reaches 81.9 % accuracy on the held-out session (chance = 25 %).

We can load the pretrained checkpoint from the Hugging Face Hub and inspect the full training curves:

import warnings

repo_id = "braindecode/plot_bcic_iv_2a_eegprep_cleaning"
try:
    from huggingface_hub import hf_hub_download

    clf.initialize()
    clf.load_params(
        f_params=hf_hub_download(repo_id, "params.safetensors"),
        f_history=hf_hub_download(repo_id, "history.json"),
        use_safetensors=True,
    )
except Exception as exc:
    warnings.warn(
        f"Could not load pretrained checkpoint from {repo_id} ({exc}); "
        "continuing with the locally trained short-run model.",
        stacklevel=2,
    )
Re-initializing module.
Re-initializing criterion.
Re-initializing optimizer.

Plot training curves#

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Extract loss and accuracy values for plotting from history object
results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"]
df = pd.DataFrame(
    clf.history[:, results_columns],
    columns=results_columns,
    index=clf.history[:, "epoch"],
)

# get percent of misclass for better visual comparison to loss
df = df.assign(
    train_misclass=100 - 100 * df.train_accuracy,
    valid_misclass=100 - 100 * df.valid_accuracy,
)

fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ["train_loss", "valid_loss"]].plot(
    ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False, fontsize=14
)

ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=14)
ax1.set_ylabel("Loss", color="tab:blue", fontsize=14)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

df.loc[:, ["train_misclass", "valid_misclass"]].plot(
    ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False
)
ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color="tab:red", fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)

# where some data has already been plotted to ax
handles = []
handles.append(
    Line2D([0], [0], color="black", linewidth=1, linestyle="-", label="Train")
)
handles.append(
    Line2D([0], [0], color="black", linewidth=1, linestyle=":", label="Valid")
)
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
[Figure: training and validation loss (left axis, blue) and misclassification rate (right axis, red) over epochs]

Plotting a Confusion Matrix#

Here we generate a confusion matrix as in [3].

from sklearn.metrics import ConfusionMatrixDisplay

y_true = valid_set.get_metadata().target
y_pred = clf.predict(valid_set)

label_dict = windows_dataset.datasets[0].window_kwargs[0][1]["mapping"]
sorted_items = sorted(label_dict.items(), key=lambda kv: kv[1])
labels = [k for k, _ in sorted_items]
class_ids = [v for _, v in sorted_items]

ConfusionMatrixDisplay.from_predictions(
    y_true, y_pred, labels=class_ids, display_labels=labels
)
[Figure: confusion matrix for the held-out validation session]
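What the confusion matrix display computes can be reproduced by hand, which also makes the row/column convention explicit. A small sketch with toy labels standing in for y_true and y_pred:

```python
import numpy as np

# Toy labels standing in for the tutorial's y_true / y_pred (4 classes).
y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
y_pred = np.array([0, 1, 1, 1, 2, 0, 3, 2])

n_classes = 4
cm = np.zeros((n_classes, n_classes), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1  # rows: true class, columns: predicted class

# Overall accuracy is the trace (correct predictions) over the total count.
accuracy = np.trace(cm) / cm.sum()  # 5 of 8 correct -> 0.625
```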

References#

Total running time of the script: (1 minute 17.522 seconds)

Estimated memory usage: 1516 MB
