Data Augmentation on BCIC IV 2a Dataset
This tutorial shows how to train EEG deep models with data augmentation. It follows the trial-wise decoding example and also illustrates the effect of a transform on the input signals.
# Authors: Simon Brandt <simonbrandt@protonmail.com>
# Cédric Rommel <cedric.rommel@inria.fr>
#
# License: BSD (3-clause)
Loading and preprocessing the dataset
Loading
from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler
from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
BNCI2014001 has been renamed to BNCI2014_001. BNCI2014001 will be removed in version 1.1.
The dataset class name 'BNCI2014001' must be an abbreviation of its code 'BNCI2014-001'. See moabb.datasets.base.is_abbrev for more information.
Preprocessing
from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)
from numpy import multiply
low_cut_hz = 4.0 # low cut frequency for filtering
high_cut_hz = 38.0 # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6
preprocessors = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda data: multiply(data, factor)),  # Convert from V to uV
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]
preprocess(dataset, preprocessors, n_jobs=-1)
/home/runner/work/braindecode/braindecode/braindecode/preprocessing/preprocess.py:69: UserWarning: Preprocessing choices with lambda functions cannot be saved.
warn("Preprocessing choices with lambda functions cannot be saved.")
<braindecode.datasets.moabb.MOABBDataset object at 0x7fe7ea6df100>
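Exponential moving standardization rescales each channel with a running mean and variance that are updated sample by sample, so slow drifts are removed without using future samples. The sketch below is a simplified, dependency-free version for a single channel, assuming the recursive update m_t = (1 - factor_new) m_{t-1} + factor_new x_t. The function name `ems_one_channel` and the `eps` default are assumptions for illustration; braindecode's `exponential_moving_standardize` processes all channels at once and, as far as we know, relies on pandas' exponentially weighted means, so its values can differ slightly from this recursive form.

```python
import math


def ems_one_channel(signal, factor_new=1e-3, init_block_size=1000, eps=1e-4):
    """Hypothetical sketch of exponential moving standardization for one channel.

    Uses the recursive exponentially weighted mean/variance update; this is an
    illustration, not braindecode's implementation.
    """
    out = []
    mean = signal[0]  # start the running mean at the first sample
    var = 0.0  # running (exponentially weighted) variance
    for x in signal:
        mean = (1.0 - factor_new) * mean + factor_new * x
        demeaned = x - mean
        var = (1.0 - factor_new) * var + factor_new * demeaned ** 2
        # eps guards against dividing by a (near-)zero standard deviation
        out.append(demeaned / max(eps, math.sqrt(var)))

    # The first block is z-scored with its own plain mean and standard deviation,
    # because the running estimates are still unreliable at the start.
    block = signal[:init_block_size]
    if block:
        block_mean = sum(block) / len(block)
        block_std = math.sqrt(sum((v - block_mean) ** 2 for v in block) / len(block))
        for i, v in enumerate(block):
            out[i] = (v - block_mean) / max(eps, block_std)
    return out


# Usage: a 10 Hz sine at 250 Hz riding on a constant offset of 50 (e.g. a DC drift).
sig = [50.0 + math.sin(2 * math.pi * 10 * t / 250) for t in range(3000)]
std_sig = ems_one_channel(sig)
```

The offset of 50 disappears from the output, and the first `init_block_size` samples end up with zero mean and unit standard deviation.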
Extracting windows
from braindecode.preprocessing import create_windows_from_events
trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
# Create windows using braindecode function for this. It needs parameters to
# define how trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
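Because `create_windows_from_events` works in samples, the offsets are converted with the sampling frequency. As a worked example, assuming the BCIC IV 2a sampling rate of 250 Hz and 4 s motor-imagery trials (values consistent with the 1125-sample model input in the layer summary further down), a -0.5 s start offset and a zero stop offset give 4.5 s windows. The helper `window_bounds` is hypothetical and only illustrates the index arithmetic, not braindecode's implementation.

```python
def window_bounds(onset, trial_samples, start_offset, stop_offset):
    """Hypothetical helper: [start, stop) sample indices of one trial window."""
    start = onset + start_offset
    stop = onset + trial_samples + stop_offset
    return start, stop


sfreq = 250.0  # assumed BCIC IV 2a sampling frequency (Hz)
trial_duration_s = 4.0  # assumed motor-imagery trial length (s)
trial_start_offset_samples = int(-0.5 * sfreq)  # -125 samples
trial_samples = int(trial_duration_s * sfreq)  # 1000 samples

start, stop = window_bounds(
    onset=10_000,  # hypothetical event onset, in samples
    trial_samples=trial_samples,
    start_offset=trial_start_offset_samples,
    stop_offset=0,
)
window_length = stop - start  # 1125 samples, i.e. 4.5 s
```

Note that `int` truncates toward zero, so with a sampling rate where the product is not an integer (e.g. -0.5 × 251 = -125.5) the offset becomes -125 rather than -126.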
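Split dataset into train and valid
The windows are split by session: the first (training) session is used for training and the second (evaluation) session for validation.
splitted = windows_dataset.split("session")
train_set = splitted["0train"]  # Session train
valid_set = splitted["1test"]  # Session evaluation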
Defining a Transform
Data can be manipulated by transforms, which are callable objects. A transform is usually handled by a custom data loader, but can also be called directly on input data, as demonstrated below for illustrative purposes.
First, we need to define a Transform. Here we choose FrequencyShift, which randomly translates all frequencies within a given range.
from braindecode.augmentation import FrequencyShift
transform = FrequencyShift(
    probability=1.0,  # defines the probability of actually modifying the input
    sfreq=sfreq,
    max_delta_freq=2.0,  # the frequency shifts are sampled now between -2 and 2 Hz
)
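A frequency shift moves every frequency component of a signal by the same amount. One standard way to do this, and to our understanding the idea behind braindecode's FrequencyShift, is to build the analytic signal with an FFT-based Hilbert transform, multiply it by the complex exponential exp(2πiΔf t), and keep the real part. The pure-Python sketch below uses a naive DFT so it has no dependencies. All function names are hypothetical, and it omits details such as any zero-padding the library may apply.

```python
import cmath
import math


def dft(x):
    """Naive O(n^2) discrete Fourier transform (fine for short demo signals)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spec):
    """Naive inverse DFT."""
    n = len(spec)
    return [sum(spec[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]


def analytic_signal(x):
    """Analytic signal: zero the negative frequencies, double the positive ones."""
    n = len(x)
    h = [0.0] * n
    h[0] = 1.0
    if n % 2 == 0:
        h[n // 2] = 1.0
        for k in range(1, n // 2):
            h[k] = 2.0
    else:
        for k in range(1, (n + 1) // 2):
            h[k] = 2.0
    return idft([c * w for c, w in zip(dft(x), h)])


def frequency_shift_sketch(x, sfreq, delta_freq):
    """Shift all frequencies of the real signal x by delta_freq Hz (hypothetical)."""
    z = analytic_signal(x)
    return [(z[t] * cmath.exp(2j * math.pi * delta_freq * t / sfreq)).real
            for t in range(len(x))]


# Usage: a 10 Hz cosine sampled at 250 Hz for 1 s, shifted by +2 Hz.
fs = 250.0
signal = [math.cos(2 * math.pi * 10 * t / fs) for t in range(250)]
shifted = frequency_shift_sketch(signal, fs, 2.0)
spectrum = [abs(c) for c in dft(shifted)[: len(shifted) // 2 + 1]]
peak_hz = spectrum.index(max(spectrum)) * fs / len(shifted)  # 12.0 Hz
```

Multiplying by the complex exponential only works on the analytic signal: applied directly to the real signal, it would move the positive and negative frequency components in opposite directions instead of shifting the spectrum as a whole.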
Manipulating one session and visualizing the transformed data
Next, let us augment the windows of one recording from the training set to show the resulting frequency shift. The windows are stacked into a NumPy array so that MNE functions can be applied to them.
import torch
import numpy as np
# Stack the windows of the first recording of the training set into one array
X = np.stack([x for x, _, _ in train_set.datasets[0]])
# Apply the transform with a fixed shift (10 Hz) for visualization,
# instead of sampling the shift randomly between -2 and 2 Hz
X_tr, _ = transform.operation(torch.as_tensor(X).float(), None, 10.0, sfreq)  # type: ignore[has-type]
The PSD of the transformed data has now been shifted by 10 Hz, as can be seen in the PSD plot.
import mne
import matplotlib.pyplot as plt
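def plot_psd(data, axis, label, color):
    # Multitaper PSD of every window and channel, restricted to 0.1-100 Hz
    psds, freqs = mne.time_frequency.psd_array_multitaper(
        data, sfreq=sfreq, fmin=0.1, fmax=100
    )
    psds = 10.0 * np.log10(psds)  # convert power to dB
    psds_mean = psds.mean(0).mean(0)  # average over windows, then over channels
    axis.plot(freqs, psds_mean, color=color, label=label)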
_, ax = plt.subplots()
plot_psd(X, ax, "original", "k")
plot_psd(X_tr.numpy(), ax, "shifted", "r")
ax.set(
    title="Multitaper PSD (EEG)",
    xlabel="Frequency (Hz)",
    ylabel="Power Spectral Density (dB)",
)
ax.legend()
plt.show()
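In normal use, rather than the fixed 10 Hz shift used for the plot above, the transform draws its parameters at random: each example of a batch is modified only with the given `probability`, and its shift is drawn from the range set by `max_delta_freq`. The sketch below illustrates this two-stage sampling with Python's `random` module. The uniform distribution and the function name `sample_shifts` are assumptions for illustration; check the braindecode source for the exact sampling behaviour.

```python
import random


def sample_shifts(batch_size, probability, max_delta_freq, seed=None):
    """Hypothetical sketch: decide which examples to augment and by how much.

    Returns one entry per example: None if the example is left unchanged,
    otherwise a frequency shift in Hz drawn uniformly from
    [-max_delta_freq, max_delta_freq].
    """
    rng = random.Random(seed)  # dedicated, seedable generator
    shifts = []
    for _ in range(batch_size):
        if rng.random() < probability:  # Bernoulli(probability) gate
            shifts.append(rng.uniform(-max_delta_freq, max_delta_freq))
        else:
            shifts.append(None)
    return shifts


# With probability=1.0 every example gets a shift in [-2, 2] Hz;
# with probability=0.5 roughly half of the examples are left untouched.
always = sample_shifts(64, probability=1.0, max_delta_freq=2.0, seed=0)
sometimes = sample_shifts(64, probability=0.5, max_delta_freq=2.0, seed=0)
```

Using a dedicated seeded generator makes the drawn augmentations reproducible, which is also why the tutorial sets random seeds before training.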
Training a model with data augmentation
Now that we know how to instantiate Transforms, it is time to learn how to use them to train a model and try to improve its generalization power.
Let’s first create a model.
Create model
The model to be trained is defined as usual.
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet
cuda = torch.cuda.is_available() # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)
n_classes = 4
classes = list(range(n_classes))
# Extract number of chans and time steps from dataset
n_channels = train_set[0][0].shape[0]
n_times = train_set[0][0].shape[1]
model = ShallowFBCSPNet(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=n_times,
    final_conv_length="auto",
)
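As a sanity check, the layer shapes and parameter counts in the model summary printed after training can be reproduced by hand. The sketch assumes the hyperparameters implied by that summary: 40 temporal filters of length 25, 40 spatial filters spanning all 22 channels with no bias (batch normalization follows), and average pooling of length 75. The pooling stride of 15 is not printed in the summary and is an assumption that reproduces the 69 pooled time steps.

```python
n_chans, n_times, n_outputs = 22, 1125, 4

# Assumed ShallowFBCSPNet hyperparameters (consistent with the printed summary).
n_filters_time = 40
filter_time_length = 25
n_filters_spat = 40
pool_time_length = 75
pool_time_stride = 15

# Temporal convolution without padding: 1125 - 25 + 1 time steps.
n_times_conv = n_times - filter_time_length + 1

# Average pooling without padding.
n_times_pool = (n_times_conv - pool_time_length) // pool_time_stride + 1

# Combined temporal + spatial convolution.
params_conv_time = n_filters_time * filter_time_length + n_filters_time  # weights + bias
params_conv_spat = n_filters_spat * n_filters_time * n_chans  # no bias
params_conv = params_conv_time + params_conv_spat

# Batch norm: one scale and one shift per filter.
params_bnorm = 2 * n_filters_spat

# Final convolution spans all pooled time steps ("auto" length): weights + bias.
params_classifier = n_outputs * n_filters_spat * n_times_pool + n_outputs

total_params = params_conv + params_bnorm + params_classifier
```

This gives 1101 and 69 time steps and 36,240 + 80 + 11,044 = 47,364 parameters, matching the summary.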
Create an EEGClassifier with the desired augmentation
In order to train with data augmentation, a custom data loader can be used for training. Multiple transforms can be passed to it and will be applied sequentially to the batched data within the AugmentedDataLoader object.
from braindecode.augmentation import AugmentedDataLoader, SignFlip
freq_shift = FrequencyShift(
    probability=0.5,
    sfreq=sfreq,
    max_delta_freq=2.0,  # the frequency shifts are sampled now between -2 and 2 Hz
)
sign_flip = SignFlip(probability=0.1)
transforms = [freq_shift, sign_flip]
# Send model to GPU
if cuda:
    model.cuda()
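The idea of applying several randomized transforms one after the other to each batch can be sketched in plain Python. `sign_flip` and `compose` below are hypothetical stand-ins, not braindecode's `SignFlip` or `Compose`: each transform makes its own per-example random decision (so with `probability=0.1` roughly one example in ten would have its sign inverted), and each transform in the list receives the output of the previous one.

```python
import random
from functools import partial


def sign_flip(batch, probability, rng):
    """Hypothetical transform: negate each example with the given probability."""
    return [[-v for v in x] if rng.random() < probability else list(x) for x in batch]


def compose(transforms):
    """Hypothetical composition: apply the transforms sequentially to a batch."""
    def apply(batch):
        for transform in transforms:
            batch = transform(batch)
        return batch
    return apply


rng = random.Random(20200220)
batch = [[1.0, -2.0, 3.0], [0.5, 0.0, -0.5]]

# Two certain sign flips in a row cancel out, showing the sequential application.
double_flip = compose([partial(sign_flip, probability=1.0, rng=rng)] * 2)
restored = double_flip(batch)

single_flip = compose([partial(sign_flip, probability=1.0, rng=rng)])
flipped = single_flip(batch)
```

Because the order matters for transforms that do not commute, the list passed to the loader should be written in the order in which the augmentations are meant to be applied.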
The model is now trained as in the trial-wise example. The AugmentedDataLoader is used as the train iterator, and the list of transforms is passed to it as an argument.
lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 4
clf = EEGClassifier(
    model,
    iterator_train=AugmentedDataLoader,  # This tells EEGClassifier to use a custom DataLoader
    iterator_train__transforms=transforms,  # This sets the augmentations to use
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
    ],
    device=device,
    classes=classes,
)
# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss lr dur
------- ---------------- ------------ ----------- ---------------- ------------ ------ ------
1 0.2639 1.4655 0.2639 0.2639 1.5266 0.0006 1.7495
2 0.3299 1.3119 0.3194 0.3194 1.3948 0.0005 1.5773
3 0.4757 1.1941 0.2986 0.2986 1.3259 0.0002 1.5847
4 0.5625 1.1671 0.3333 0.3333 1.3025 0.0000 1.5876
<class 'braindecode.classifier.EEGClassifier'>[initialized](
module_=============================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet) [1, 22, 1125] [1, 4] -- --
├─SafeLog (pool_nonlin_exp): 1-1 [1, 22, 1125] [1, 22, 1125] -- --
├─Ensure4d (ensuredims): 1-2 [1, 22, 1125] [1, 22, 1125, 1] -- --
├─Rearrange (dimshuffle): 1-3 [1, 22, 1125, 1] [1, 1, 1125, 22] -- --
├─CombinedConv (conv_time_spat): 1-4 [1, 1, 1125, 22] [1, 40, 1101, 1] 36,240 --
├─BatchNorm2d (bnorm): 1-5 [1, 40, 1101, 1] [1, 40, 1101, 1] 80 --
├─Expression (conv_nonlin_exp): 1-6 [1, 40, 1101, 1] [1, 40, 1101, 1] -- --
├─AvgPool2d (pool): 1-7 [1, 40, 1101, 1] [1, 40, 69, 1] -- [75, 1]
├─SafeLog (pool_nonlin_exp): 1-8 [1, 40, 69, 1] [1, 40, 69, 1] -- --
├─Dropout (drop): 1-9 [1, 40, 69, 1] [1, 40, 69, 1] -- --
├─Sequential (final_layer): 1-10 [1, 40, 69, 1] [1, 4] -- --
│ └─Conv2d (conv_classifier): 2-1 [1, 40, 69, 1] [1, 4, 1, 1] 11,044 [69, 1]
│ └─Expression (squeeze): 2-2 [1, 4, 1, 1] [1, 4] -- --
============================================================================================================================================
Total params: 47,364
Trainable params: 47,364
Non-trainable params: 0
Total mult-adds (M): 0.01
============================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 0.35
Params size (MB): 0.04
Estimated Total Size (MB): 0.50
============================================================================================================================================,
)
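The `lr` column of the training log follows from the cosine annealing schedule. Assuming PyTorch's `CosineAnnealingLR` closed form with its default minimum learning rate of 0, the rate at epoch index t is lr_t = lr_0 (1 + cos(π t / T_max)) / 2, here with lr_0 = 0.0625 × 0.01 = 6.25e-4 and T_max = n_epochs - 1 = 3. The sketch reproduces only the formula, not skorch's callback machinery.

```python
import math


def cosine_annealing_lr(lr0, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate at a given epoch index."""
    return eta_min + (lr0 - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


lr0 = 0.0625 * 0.01  # 6.25e-4, as in the training code above
n_epochs = 4
schedule = [cosine_annealing_lr(lr0, t, n_epochs - 1) for t in range(n_epochs)]
rounded = [round(lr, 4) for lr in schedule]  # as displayed in the log
```

The exact values 6.25e-4, 4.6875e-4, 1.5625e-4 and 0 round to the 0.0006, 0.0005, 0.0002 and 0.0000 shown in the log; setting `T_max = n_epochs - 1` is what makes the rate reach zero in the last epoch.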
Manually composing Transforms
It would be equivalent (although more verbose) to pass a composition of the same transforms to EEGClassifier:
from braindecode.augmentation import Compose
composed_transforms = Compose(transforms=transforms)
Setting the data augmentation at the Dataset level
Also note that for most transforms it is possible to pass them directly to the WindowsDataset object through the transform argument, as is commonly done in other libraries. However, we advise using the AugmentedDataLoader as above, as it is compatible with all transforms and can be more efficient.
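One way to see why augmenting per batch can be cheaper: a dataset-level transform is called once per example, whereas a loader-level transform is called once per batch and can then process the whole batch in a vectorized way. The sketch below only counts calls. `CountingTransform` and `TransformedDataset` are hypothetical classes, and the 288 trials (one BCIC IV 2a session of 6 runs × 48 trials) and the batch size of 64 are assumed sizes for illustration.

```python
import math


class CountingTransform:
    """Hypothetical identity transform that records how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x


class TransformedDataset:
    """Hypothetical dataset applying a transform to every example it returns."""

    def __init__(self, examples, transform):
        self.examples = examples
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.transform(self.examples[idx])


n_trials, batch_size = 288, 64
examples = list(range(n_trials))  # placeholders for EEG windows

# Dataset level: one call per example.
per_example = CountingTransform()
dataset = TransformedDataset(examples, per_example)
_ = [dataset[i] for i in range(len(dataset))]

# Loader level: one call per batch.
per_batch = CountingTransform()
for start in range(0, n_trials, batch_size):
    per_batch(examples[start:start + batch_size])

n_batches = math.ceil(n_trials / batch_size)
```

Here the dataset-level transform runs 288 times per epoch against 5 times at the loader level; the actual speed-up depends on how well each transform vectorizes over a batch.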
Total running time of the script: (0 minutes 16.960 seconds)
Estimated memory usage: 1306 MB