Data Augmentation on BCIC IV 2a Dataset

This tutorial shows how to train EEG deep models with data augmentation. It follows the trial-wise decoding example and also illustrates the effect of a transform on the input signals.

# Authors: Simon Brandt
#          Cédric Rommel
# License: BSD (3-clause)

Loading and preprocessing the dataset


from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler

from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset

subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])


from braindecode.preprocessing import (
    exponential_moving_standardize, preprocess, Preprocessor, scale)

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(scale, factor=1e6, apply_on_array=True),  # Convert from V to uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]

preprocess(dataset, preprocessors)


/usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/utils/ FutureWarning: Function scale is deprecated; will be removed in 0.7.0. Use numpy.multiply instead.
  warnings.warn(msg, category=FutureWarning)

<braindecode.datasets.moabb.MOABBDataset object at 0x7f7490e7ced0>
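
The exponential moving standardization step above can be pictured with a simplified, single-channel sketch. It assumes the running mean and variance are updated with weight `factor_new` at every sample. The names `ema_standardize` and `eps` are illustrative, and the sketch leaves out the `init_block_size` handling and the multi-channel support of braindecode's `exponential_moving_standardize`, so it should not be read as the library's implementation.

```python
import math


def ema_standardize(signal, factor_new=1e-3, eps=1e-4):
    """Standardize a 1-D sequence with exponentially weighted running statistics."""
    mean = float(signal[0])  # start the running mean at the first sample
    var = 0.0
    standardized = []
    for x in signal:
        # Blend the new sample into the running mean and variance.
        mean = factor_new * x + (1.0 - factor_new) * mean
        var = factor_new * (x - mean) ** 2 + (1.0 - factor_new) * var
        # eps guards against division by zero while the variance is still tiny.
        standardized.append((x - mean) / max(math.sqrt(var), eps))
    return standardized


# A large offset with a slow drift and a small oscillation on top.
drifting = [100.0 + 0.01 * t + math.sin(0.5 * t) for t in range(2000)]
standardized = ema_standardize(drifting, factor_new=0.05)
```

The raw values sit around 100 and drift upwards, whereas the standardized values stay within a few units of zero because the running mean follows the drift.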

Extracting windows

from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to
# define how trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
)
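
To make the offset arithmetic concrete, here is a small sketch of how a start offset expressed in seconds turns into window boundaries in samples. The helper `window_bounds`, the 250 Hz sampling rate, the trial onset and the 4 s trial length are hypothetical values chosen for illustration; they are not read from the dataset.

```python
def window_bounds(onset_sample, trial_len_samples,
                  start_offset_samples, stop_offset_samples=0):
    """Return the [start, stop) sample range of a window cut around one trial."""
    start = onset_sample + start_offset_samples
    stop = onset_sample + trial_len_samples + stop_offset_samples
    return start, stop


sfreq = 250.0  # assumed sampling rate in Hz, for illustration only
trial_start_offset_samples = int(-0.5 * sfreq)  # half a second before the trial onset

# Hypothetical trial: starts at sample 1000 and lasts 4 s.
start, stop = window_bounds(1000, int(4 * sfreq), trial_start_offset_samples)
window_len = stop - start
```

Under these assumptions the negative offset moves the window start 125 samples before the trial onset, so the window is 1125 samples long instead of 1000.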

Split dataset into train and valid

splitted = windows_dataset.split('session')
train_set = splitted['session_T']
valid_set = splitted['session_E']

Defining a Transform

Data can be manipulated by transforms, which are callable objects. A transform is usually handled by a custom data loader, but can also be called directly on input data, as demonstrated below for illustrative purposes.

First, we need to define a Transform. Here we chose the FrequencyShift, which randomly translates all frequencies within a given range.

from braindecode.augmentation import FrequencyShift

transform = FrequencyShift(
    probability=1.,  # defines the probability of actually modifying the input
    sfreq=sfreq,
    max_delta_freq=2.  # the frequency shifts are sampled now between -2 and 2 Hz
)
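
The role of the `probability` argument can be pictured with a toy transform written in plain Python. `ToyTransform` and `negate` are hypothetical stand-ins, not braindecode classes: the sketch only assumes that a transform wraps an operation and decides, for each example in a batch, whether to apply it.

```python
import random


class ToyTransform:
    """Apply `operation` to each example of a batch with a given probability."""

    def __init__(self, operation, probability, seed=None):
        self.operation = operation
        self.probability = probability
        self.rng = random.Random(seed)

    def __call__(self, batch):
        # Each example is modified independently with probability `probability`.
        return [
            self.operation(x) if self.rng.random() < self.probability else x
            for x in batch
        ]


def negate(x):
    """Toy operation: flip the sign of every value."""
    return [-v for v in x]


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
always = ToyTransform(negate, probability=1.0)(batch)  # every example is modified
never = ToyTransform(negate, probability=0.0)(batch)   # the batch is left untouched
```

With `probability=1.`, as in the tutorial's `FrequencyShift` instance, every input is modified, which is convenient for visualization.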

Manipulating one session and visualizing the transformed data

Next, let us augment one session to show the resulting frequency shift. The data are taken from an MNE Epochs object so that MNE functions can be used on them.

import torch

epochs = train_set.datasets[0].windows  # original epochs
X = epochs.get_data()
# Call the underlying operation directly to apply a fixed shift (10 Hz) for
# visualization, instead of sampling the shift randomly between -2 and 2 Hz
X_tr, _ = transform.operation(torch.as_tensor(X).float(), None, 10., sfreq)

The PSD of the transformed session is now shifted by 10 Hz, as the PSD plot shows.

import mne
import matplotlib.pyplot as plt
import numpy as np

def plot_psd(data, axis, label, color):
    psds, freqs = mne.time_frequency.psd_array_multitaper(data, sfreq=sfreq,
                                                          fmin=0.1, fmax=100)
    psds = 10. * np.log10(psds)
    psds_mean = psds.mean(0).mean(0)
    axis.plot(freqs, psds_mean, color=color, label=label)

_, ax = plt.subplots()
plot_psd(X, ax, 'original', 'k')
plot_psd(X_tr.numpy(), ax, 'shifted', 'r')

ax.set(title='Multitaper PSD (EEG)', xlabel='Frequency (Hz)',
       ylabel='Power Spectral Density (dB)')
ax.legend()
plt.show()
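
One common way to shift every frequency of a real signal by a fixed amount is to build its analytic signal, multiply it by a complex exponential, and keep the real part. The sketch below illustrates that idea on a pure tone, using a naive DFT from the standard library so that it stays self-contained. The helpers `dft`, `idft`, `analytic_signal`, `frequency_shift` and `peak_frequency` are illustrative names, and braindecode's own implementation may differ in its details (FFT-based computation, padding, batching).

```python
import cmath
import math


def dft(x):
    """Naive discrete Fourier transform (O(n^2), fine for short signals)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spectrum):
    """Naive inverse discrete Fourier transform."""
    n = len(spectrum)
    return [sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]


def analytic_signal(x):
    """Keep DC (and Nyquist), double positive frequencies, zero negative ones."""
    n = len(x)
    gains = [0.0] * n
    gains[0] = 1.0
    if n % 2 == 0:
        gains[n // 2] = 1.0
    for k in range(1, (n + 1) // 2):
        gains[k] = 2.0
    return idft([g * s for g, s in zip(gains, dft(x))])


def frequency_shift(x, delta_freq, sfreq):
    """Shift all frequencies of a real signal by `delta_freq` Hz."""
    return [(a * cmath.exp(2j * math.pi * delta_freq * t / sfreq)).real
            for t, a in enumerate(analytic_signal(x))]


def peak_frequency(x, sfreq):
    """Frequency (Hz) of the largest DFT magnitude below Nyquist."""
    n = len(x)
    magnitudes = [abs(c) for c in dft(x)[: n // 2]]
    return magnitudes.index(max(magnitudes)) * sfreq / n


sfreq = 64.0  # toy sampling rate, chosen so that DFT bins are 1 Hz apart
tone = [math.sin(2 * math.pi * 8.0 * t / sfreq) for t in range(64)]  # 8 Hz tone
shifted = frequency_shift(tone, delta_freq=10.0, sfreq=sfreq)
```

In this toy setting the spectral peak moves from 8 Hz to 18 Hz, which mirrors the 10 Hz displacement of the PSD in the tutorial's plot.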

Training a model with data augmentation

Now that we know how to instantiate Transforms, it is time to learn how to use them to train a model and try to improve its generalization power. Let’s first create a model.

Create model

The model to be trained is defined as usual.

from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True

# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4

# Extract number of chans and time steps from dataset
n_channels = train_set[0][0].shape[0]
input_window_samples = train_set[0][0].shape[1]

model = ShallowFBCSPNet(
    n_channels,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)

Create an EEGClassifier with the desired augmentation

To train with data augmentation, a custom data loader can be used during training. Multiple transforms can be passed to it and will be applied sequentially to the batched data within the AugmentedDataLoader object.

from braindecode.augmentation import AugmentedDataLoader, SignFlip

freq_shift = FrequencyShift(
    probability=.5,  # assumed value: the original argument is missing here
    sfreq=sfreq,
    max_delta_freq=2.  # the frequency shifts are sampled now between -2 and 2 Hz
)

sign_flip = SignFlip(probability=.1)

transforms = [
    freq_shift,
    sign_flip
]

# Send model to GPU
if cuda:
    model.cuda()

The model is now trained as in the trial-wise example. The AugmentedDataLoader is used as the train iterator and the list of transforms is passed to it as an argument.

lr = 0.0625 * 0.01
weight_decay = 0

batch_size = 64
n_epochs = 4

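clf = EEGClassifier(
    model,
    iterator_train=AugmentedDataLoader,  # This tells EEGClassifier to use a custom DataLoader
    iterator_train__transforms=transforms,  # This sets the augmentations to use
    criterion=torch.nn.NLLLoss,  # assumed, as in the trial-wise decoding example
    optimizer=torch.optim.AdamW,  # assumed, as in the trial-wise decoding example
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)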


  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  ----------------  ------------  ------  ------
      1            0.2569        1.6225            0.2431        5.5634  0.0006  6.0187
      2            0.2535        1.2218            0.2535        6.3996  0.0005  5.8492
      3            0.2535        1.0973            0.2535        5.2946  0.0002  5.8053
      4            0.2639        1.0922            0.2535        4.0276  0.0000  5.8025

<class 'braindecode.classifier.EEGClassifier'>[initialized](
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat)
    (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(40, 40, kernel_size=(1, 22), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square)
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log)
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(40, 4, kernel_size=(69, 1), stride=(1, 1))
    (softmax): LogSoftmax(dim=1)
    (squeeze): Expression(expression=squeeze_final_output)
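
The `lr` column of the training log can be checked against the closed-form cosine annealing schedule. The sketch below assumes the learning rate at epoch `t` is `0.5 * base_lr * (1 + cos(pi * t / T_max))` with a minimum of zero; `cosine_annealing_lr` is an illustrative helper, not a skorch or PyTorch function.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


base_lr = 0.0625 * 0.01
n_epochs = 4
# One value per epoch, formatted like the `lr` column of the training log.
schedule = [f"{cosine_annealing_lr(base_lr, epoch, n_epochs - 1):.4f}"
            for epoch in range(n_epochs)]
```

With `T_max=n_epochs - 1`, the schedule reaches zero exactly at the last epoch, which is why the final row of the log reports a learning rate of 0.0000.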

Manually composing Transforms

It would be equivalent (although more verbose) to pass to EEGClassifier a composition of the same transforms:

from braindecode.augmentation import Compose

composed_transforms = Compose(transforms=transforms)
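
Conceptually, composing transforms means calling them one after the other on the same data. The toy class below sketches that behaviour in plain Python; `ToyCompose`, `flip_sign` and `add_one` are hypothetical names and do not reproduce braindecode's `Compose`.

```python
class ToyCompose:
    """Chain several callables: the output of one is the input of the next."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


def flip_sign(x):
    return [-v for v in x]


def add_one(x):
    return [v + 1.0 for v in x]


flip_then_add = ToyCompose([flip_sign, add_one])([1.0, 2.0])
add_then_flip = ToyCompose([add_one, flip_sign])([1.0, 2.0])
```

The two results differ, which is a reminder that the order of the list passed to the data loader matters whenever the transforms do not commute.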

Setting the data augmentation at the Dataset level

Note that most transforms can also be passed directly to the WindowsDataset object through the transform argument, as is commonly done in other libraries. However, it is advised to use the AugmentedDataLoader as above, as it is compatible with all transforms and can be more efficient.
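
The difference between the two placements can be sketched with toy classes: a dataset-level transform is called once per example, whereas a loader-level transform is called once per batch. `CountingTransform`, `ToyDataset` and `iter_batches` are hypothetical helpers written for this illustration and say nothing about how braindecode implements either option.

```python
class CountingTransform:
    """Sign flip that counts its calls; accepts one sample or a batch of samples."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        if x and isinstance(x[0], list):  # a batch: list of samples
            return [[-v for v in row] for row in x]
        return [-v for v in x]  # a single sample


class ToyDataset:
    """Minimal dataset with an optional per-sample transform."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        x = self.samples[index]
        return self.transform(x) if self.transform is not None else x


def iter_batches(dataset, batch_size, batch_transform=None):
    """Yield batches, optionally applying a transform to each whole batch."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]
        yield batch_transform(batch) if batch_transform is not None else batch


samples = [[float(i), float(i + 1)] for i in range(8)]

per_sample = CountingTransform()
dataset_level = list(iter_batches(ToyDataset(samples, transform=per_sample), batch_size=4))

per_batch = CountingTransform()
loader_level = list(iter_batches(ToyDataset(samples), batch_size=4, batch_transform=per_batch))
```

Both placements produce the same augmented batches here, but the dataset-level transform is called 8 times and the batch-level one only twice. Fewer, larger calls is one plausible reason why batch-level augmentation can be more efficient.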

