.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_data_augmentation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_data_augmentation.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_data_augmentation.py:


Data Augmentation on BCIC IV 2a Dataset
=======================================

This tutorial shows how to train EEG deep models with data augmentation. It
follows the trial-wise decoding example and also illustrates the effect of a
transform on the input signals.

.. contents:: This example covers:
   :local:
   :depth: 2

.. GENERATED FROM PYTHON SOURCE LINES 14-20

.. code-block:: default


    # Authors: Simon Brandt <simonbrandt@protonmail.com>
    #          Cédric Rommel <cedric.rommel@inria.fr>
    #
    # License: BSD (3-clause)


.. GENERATED FROM PYTHON SOURCE LINES 21-26

Loading and preprocessing the dataset
-------------------------------------

Loading
~~~~~~~


.. GENERATED FROM PYTHON SOURCE LINES 26-36

.. code-block:: default


    from skorch.helper import predefined_split
    from skorch.callbacks import LRScheduler

    from braindecode import EEGClassifier
    from braindecode.datasets import MOABBDataset

    subject_id = 3
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])


.. GENERATED FROM PYTHON SOURCE LINES 37-40

Preprocessing
~~~~~~~~~~~~~


.. GENERATED FROM PYTHON SOURCE LINES 40-60

.. code-block:: default


    from braindecode.preprocessing import (
        exponential_moving_standardize, preprocess, Preprocessor, scale)

    low_cut_hz = 4.  # low cut frequency for filtering
    high_cut_hz = 38.  # high cut frequency for filtering
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
        Preprocessor(scale, factor=1e6, apply_on_array=True),  # Convert from V to uV
        Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
        Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                     factor_new=factor_new, init_block_size=init_block_size)
    ]

    preprocess(dataset, preprocessors)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function scale is deprecated; will be removed in 0.7.0. Use numpy.multiply instead.
      warnings.warn(msg, category=FutureWarning)

    <braindecode.datasets.moabb.MOABBDataset object at 0x7f7490e7ced0>


.. GENERATED FROM PYTHON SOURCE LINES 61-64

Extracting windows
~~~~~~~~~~~~~~~~~~


.. GENERATED FROM PYTHON SOURCE LINES 64-83

.. code-block:: default


    from braindecode.preprocessing import create_windows_from_events

    trial_start_offset_seconds = -0.5
    # Extract the sampling frequency and check that it is the same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    # Create windows using the dedicated braindecode function. It needs
    # parameters to define how trials should be used.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        preload=True,
    )


.. GENERATED FROM PYTHON SOURCE LINES 84-87

Split dataset into train and valid
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


.. GENERATED FROM PYTHON SOURCE LINES 87-92

.. code-block:: default


    splitted = windows_dataset.split('session')
    train_set = splitted['session_T']
    valid_set = splitted['session_E']


.. GENERATED FROM PYTHON SOURCE LINES 93-102

Defining a Transform
--------------------

Data can be manipulated by transforms, which are callable objects. A
transform is usually handled by a custom data loader, but can also be called
directly on input data, as demonstrated below for illustrative purposes.

First, we need to define a Transform. Here we choose ``FrequencyShift``,
which randomly translates all frequencies within a given range.

.. GENERATED FROM PYTHON SOURCE LINES 102-111

.. code-block:: default


    from braindecode.augmentation import FrequencyShift

    transform = FrequencyShift(
        probability=1.,  # defines the probability of actually modifying the input
        sfreq=sfreq,
        max_delta_freq=2.  # the frequency shifts are sampled now between -2 and 2 Hz
    )


.. GENERATED FROM PYTHON SOURCE LINES 112-118

Manipulating one session and visualizing the transformed data
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, let us augment one session to show the resulting frequency shift. The
data of an ``mne.Epochs`` object is used here so that MNE functions can be
applied to it.

.. GENERATED FROM PYTHON SOURCE LINES 118-127

.. code-block:: default


    import torch

    epochs = train_set.datasets[0].windows  # original epochs
    X = epochs.get_data()
    # Apply the transform with a fixed shift (10 Hz) for visualization,
    # instead of sampling the shift randomly between -2 and 2 Hz
    X_tr, _ = transform.operation(torch.as_tensor(X).float(), None, 10., sfreq)


.. GENERATED FROM PYTHON SOURCE LINES 128-130

The PSD of the transformed session is now shifted by 10 Hz, as can be seen
in the PSD plot.

.. GENERATED FROM PYTHON SOURCE LINES 130-153

.. code-block:: default


    import mne
    import matplotlib.pyplot as plt
    import numpy as np


    def plot_psd(data, axis, label, color):
        psds, freqs = mne.time_frequency.psd_array_multitaper(data, sfreq=sfreq,
                                                              fmin=0.1, fmax=100)
        psds = 10. * np.log10(psds)
        psds_mean = psds.mean(0).mean(0)
        axis.plot(freqs, psds_mean, color=color, label=label)


    _, ax = plt.subplots()
    plot_psd(X, ax, 'original', 'k')
    plot_psd(X_tr.numpy(), ax, 'shifted', 'r')

    ax.set(title='Multitaper PSD (gradiometers)', xlabel='Frequency (Hz)',
           ylabel='Power Spectral Density (dB)')
    ax.legend()
    plt.show()


.. image-sg:: /auto_examples/images/sphx_glr_plot_data_augmentation_001.png
   :alt: Multitaper PSD (gradiometers)
   :srcset: /auto_examples/images/sphx_glr_plot_data_augmentation_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 154-165

Training a model with data augmentation
---------------------------------------

Now that we know how to instantiate ``Transforms``, it is time to learn how
to use them to train a model and try to improve its generalization power.
Let's first create a model.

Create model
~~~~~~~~~~~~

The model to be trained is defined as usual.

.. GENERATED FROM PYTHON SOURCE LINES 165-196

.. code-block:: default


    from braindecode.util import set_random_seeds
    from braindecode.models import ShallowFBCSPNet

    cuda = torch.cuda.is_available()  # check if a GPU is available; if so, use it
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    # Set random seed to be able to roughly reproduce results
    # Note that with cudnn benchmark set to True, GPU indeterminism
    # may still make results substantially different between runs.
    # To obtain more consistent results at the cost of increased computation time,
    # you can set `cudnn_benchmark=False` in `set_random_seeds`
    # or remove `torch.backends.cudnn.benchmark = True`
    seed = 20200220
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4

    # Extract number of chans and time steps from dataset
    n_channels = train_set[0][0].shape[0]
    input_window_samples = train_set[0][0].shape[1]

    model = ShallowFBCSPNet(
        n_channels,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length='auto',
    )


.. GENERATED FROM PYTHON SOURCE LINES 197-203

Create an EEGClassifier with the desired augmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To train with data augmentation, a custom data loader can be used for
training. Multiple transforms can be passed to it and will be applied
sequentially to the batched data within the ``AugmentedDataLoader`` object.

.. GENERATED FROM PYTHON SOURCE LINES 203-223

.. code-block:: default


    from braindecode.augmentation import AugmentedDataLoader, SignFlip

    freq_shift = FrequencyShift(
        probability=.5,
        sfreq=sfreq,
        max_delta_freq=2.  # the frequency shifts are sampled now between -2 and 2 Hz
    )

    sign_flip = SignFlip(probability=.1)

    transforms = [
        freq_shift,
        sign_flip
    ]

    # Send model to GPU
    if cuda:
        model.cuda()


.. GENERATED FROM PYTHON SOURCE LINES 224-227

The model is now trained as in the trial-wise example. The
``AugmentedDataLoader`` is used as the train iterator and the list of
transforms is passed to it as an argument.

.. GENERATED FROM PYTHON SOURCE LINES 227-254

.. code-block:: default


    lr = 0.0625 * 0.01
    weight_decay = 0

    batch_size = 64
    n_epochs = 4

    clf = EEGClassifier(
        model,
        iterator_train=AugmentedDataLoader,  # This tells EEGClassifier to use a custom DataLoader
        iterator_train__transforms=transforms,  # This sets the augmentations to use
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),  # using valid_set for validation
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is already
    # supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

      epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
    -------  ----------------  ------------  ----------------  ------------  ------  ------
          1            0.2569        1.6225            0.2431        5.5634  0.0006  6.0187
          2            0.2535        1.2218            0.2535        6.3996  0.0005  5.8492
          3            0.2535        1.0973            0.2535        5.2946  0.0002  5.8053
          4            0.2639        1.0922            0.2535        4.0276  0.0000  5.8025

    <class 'braindecode.classifier.EEGClassifier'>[initialized](
      module_=ShallowFBCSPNet(
        (ensuredims): Ensure4d()
        (dimshuffle): Expression(expression=transpose_time_to_spat)
        (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
        (conv_spat): Conv2d(40, 40, kernel_size=(1, 22), stride=(1, 1), bias=False)
        (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv_nonlin_exp): Expression(expression=square)
        (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
        (pool_nonlin_exp): Expression(expression=safe_log)
        (drop): Dropout(p=0.5, inplace=False)
        (conv_classifier): Conv2d(40, 4, kernel_size=(69, 1), stride=(1, 1))
        (softmax): LogSoftmax(dim=1)
        (squeeze): Expression(expression=squeeze_final_output)
      ),
    )


.. GENERATED FROM PYTHON SOURCE LINES 255-260

Manually composing Transforms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It would be equivalent (although more verbose) to pass to ``EEGClassifier`` a
composition of the same transforms:

.. GENERATED FROM PYTHON SOURCE LINES 260-265

.. code-block:: default


    from braindecode.augmentation import Compose

    composed_transforms = Compose(transforms=transforms)


.. GENERATED FROM PYTHON SOURCE LINES 266-274

Setting the data augmentation at the Dataset level
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Note that most transforms can also be passed directly to the
``WindowsDataset`` object through its ``transform`` argument, as is commonly
done in other libraries. However, it is advised to use the
``AugmentedDataLoader`` as above, as it is compatible with all transforms and
can be more efficient.

.. GENERATED FROM PYTHON SOURCE LINES 274-276

.. code-block:: default


    train_set.transform = composed_transforms


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  38.415 seconds)

**Estimated memory usage:**  1572 MB


.. _sphx_glr_download_auto_examples_plot_data_augmentation.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_data_augmentation.py <plot_data_augmentation.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_data_augmentation.ipynb <plot_data_augmentation.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_