Simple training on MNE epochs#

The Braindecode library gives you access to a large number of neural network architectures that were developed for EEG data decoding. This tutorial shows how you can use any of these models to decode your own data. In particular, we assume that you have your data in an MNE format and want to train one of the Braindecode models on it.

# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)

Finding the model you want#

Exploring the braindecode online documentation#

Let’s suppose you recently stumbled upon the Schirrmeister 2017 article [1]. In this article, the authors mention that their novel architecture, ShallowConvNet, performs well on the BCI Competition IV 2a dataset, and you would like to use it on your own data. Fortunately, the authors also mention that they published their architecture in Braindecode!

To use this architecture, you first need to find its exact name in Braindecode. To do so, you can visit the Braindecode online documentation, which lists all the available models.

Models list: https://braindecode.org/stable/api.html#models

Alternatively, the API also provides a dictionary with all available models:

from braindecode.models.util import models_dict

print(f"All the Braindecode models:\n{list(models_dict.keys())}")
All the Braindecode models:
['ATCNet', 'AttentionBaseNet', 'BDTCN', 'BIOT', 'CTNet', 'ContraWR', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGMiner', 'EEGNeX', 'EEGNetv1', 'EEGNetv4', 'EEGResNet', 'EEGSimpleConv', 'EEGTCNet', 'Labram', 'MSVTNet', 'SCCNet', 'SPARCNet', 'ShallowFBCSPNet', 'SincShallowNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SleepStagerEldele2021', 'SyncNet', 'TIDNet', 'TSceptionV1', 'USleep']
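
With this many entries, a case-insensitive substring search can narrow the list down. The sketch below works on a hard-coded copy of the names printed above rather than on `models_dict` itself, so it runs without Braindecode installed; `find_models` is a hypothetical helper, not part of the Braindecode API.

```python
# Model names copied from the printed list above (not queried from Braindecode).
model_names = [
    "ATCNet", "AttentionBaseNet", "BDTCN", "BIOT", "CTNet", "ContraWR",
    "Deep4Net", "DeepSleepNet", "EEGConformer", "EEGITNet", "EEGInceptionERP",
    "EEGInceptionMI", "EEGMiner", "EEGNeX", "EEGNetv1", "EEGNetv4", "EEGResNet",
    "EEGSimpleConv", "EEGTCNet", "Labram", "MSVTNet", "SCCNet", "SPARCNet",
    "ShallowFBCSPNet", "SincShallowNet", "SleepStagerBlanco2020",
    "SleepStagerChambon2018", "SleepStagerEldele2021", "SyncNet", "TIDNet",
    "TSceptionV1", "USleep",
]


def find_models(names, fragment):
    """Return the names that contain `fragment`, ignoring case (hypothetical helper)."""
    fragment = fragment.lower()
    return [name for name in names if fragment in name.lower()]


matches = find_models(model_names, "shallow")
print(matches)
```

With a Braindecode installation, the same helper could presumably be applied to `list(models_dict.keys())` instead of the hard-coded list.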

After your investigation, you found out that the model you are looking for is ShallowFBCSPNet. You can now import it from Braindecode:

from braindecode.models import ShallowFBCSPNet

Examining the model#

Now that you have found your model, you must check which parameters it expects. You can find this information either in the online documentation for braindecode.models.ShallowFBCSPNet or directly in the class’s docstring:

print(ShallowFBCSPNet.__doc__)
Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

.. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
    :align: center
    :alt: ShallowNet Architecture

Model described in [Schirrmeister2017]_.

Parameters
----------
n_chans : int
    Number of EEG channels.
n_outputs : int
    Number of outputs of the model. This is the number of classes
    in the case of classification.
n_times : int
    Number of time samples of the input window.
n_filters_time: int
    Number of temporal filters.
filter_time_length: int
    Length of the temporal filter.
n_filters_spat: int
    Number of spatial filters.
pool_time_length: int
    Length of temporal pooling filter.
pool_time_stride: int
    Length of stride between temporal pooling filters.
final_conv_length: int | str
    Length of the final convolution layer.
    If set to "auto", length of the input signal must be specified.
conv_nonlin: callable
    Non-linear function to be used after convolution layers.
pool_mode: str
    Method to use on pooling layers. "max" or "mean".
activation_pool_nonlin: callable
    Non-linear function to be used after pooling layers.
split_first_layer: bool
    Split first layer into temporal and spatial layers (True) or just use temporal (False).
    There would be no non-linearity between the split layers.
batch_norm: bool
    Whether to use batch normalisation.
batch_norm_alpha: float
    Momentum for BatchNorm2d.
drop_prob: float
    Dropout probability.
chs_info : list of dict
    Information about each individual EEG channel. This should be filled with
    ``info["chs"]``. Refer to :class:`mne.Info` for more details.
input_window_seconds : float
    Length of the input window in seconds.
sfreq : float
    Sampling frequency of the EEG recordings.

Raises
------
ValueError: If some input signal-related parameters are not specified
            and can not be inferred.

FutureWarning: If add_log_softmax is True, since LogSoftmax final layer
               will be removed in the future.

Notes
-----
If some input signal-related parameters are not specified,
there will be an attempt to infer them from the other parameters.

References
----------
.. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
   L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
   & Ball, T. (2017).
   Deep learning with convolutional neural networks for EEG decoding and
   visualization.
   Human Brain Mapping , Aug. 2017.
   Online: http://dx.doi.org/10.1002/hbm.23730

Additionally, you might be interested in visualizing the model’s architecture. This can be done by initializing the model and printing it, which calls its __str__() method. To initialize it, we need to specify some parameters, which we set to arbitrary values for now:

model = ShallowFBCSPNet(
    n_chans=32,
    n_times=1000,
    n_outputs=2,
    final_conv_length="auto",
)
print(model)
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 32, 1000]             [1, 2]                    --                        --
├─SafeLog (pool_nonlin_exp): 1-1         [1, 32, 1000]             [1, 32, 1000]             --                        --
├─Ensure4d (ensuredims): 1-2             [1, 32, 1000]             [1, 32, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-3            [1, 32, 1000, 1]          [1, 1, 1000, 32]          --                        --
├─CombinedConv (conv_time_spat): 1-4     [1, 1, 1000, 32]          [1, 40, 976, 1]           52,240                    --
├─BatchNorm2d (bnorm): 1-5               [1, 40, 976, 1]           [1, 40, 976, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-6      [1, 40, 976, 1]           [1, 40, 976, 1]           --                        --
├─AvgPool2d (pool): 1-7                  [1, 40, 976, 1]           [1, 40, 61, 1]            --                        [75, 1]
├─SafeLog (pool_nonlin_exp): 1-8         [1, 40, 61, 1]            [1, 40, 61, 1]            --                        --
├─Dropout (drop): 1-9                    [1, 40, 61, 1]            [1, 40, 61, 1]            --                        --
├─Sequential (final_layer): 1-10         [1, 40, 61, 1]            [1, 2]                    --                        --
│    └─Conv2d (conv_classifier): 2-1     [1, 40, 61, 1]            [1, 2, 1, 1]              4,882                     [61, 1]
│    └─Expression (squeeze): 2-2         [1, 2, 1, 1]              [1, 2]                    --                        --
============================================================================================================================================
Total params: 57,202
Trainable params: 57,202
Non-trainable params: 0
Total mult-adds (M): 0.00
============================================================================================================================================
Input size (MB): 0.13
Forward/backward pass size (MB): 0.31
Params size (MB): 0.02
Estimated Total Size (MB): 0.46
============================================================================================================================================
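
The shapes and parameter counts in this summary can be checked by hand. The following is a minimal sketch of that arithmetic, not Braindecode code: the values 40 filters, temporal filter length 25 and pooling stride 15 are assumptions inferred from the printed shapes (only the pooling kernel of 75 is shown directly), and the choice of which layers carry a bias is likewise an assumption made so that the totals match. `shallow_shapes` is a hypothetical helper.

```python
def shallow_shapes(n_chans, n_times, n_outputs, n_filters=40,
                   filter_time_length=25, pool_time_length=75,
                   pool_time_stride=15):
    """Reproduce the summary's shapes and parameter counts (assumed hyperparameters)."""
    # Unpadded temporal convolution shortens the signal by (kernel - 1) samples.
    conv_len = n_times - filter_time_length + 1
    # Average pooling with kernel 75 and (assumed) stride 15.
    pool_len = (conv_len - pool_time_length) // pool_time_stride + 1
    # Temporal filters: one weight per tap plus one bias per filter (assumed).
    temporal = n_filters * filter_time_length + n_filters
    # Spatial filters mix every temporal filter across all channels, no bias (assumed).
    spatial = n_filters * n_filters * n_chans
    # Batch norm learns one scale and one shift per filter.
    batch_norm = 2 * n_filters
    # The classifier convolution spans the whole pooled time axis, with bias.
    classifier = n_filters * pool_len * n_outputs + n_outputs
    return {
        "conv_len": conv_len,
        "pool_len": pool_len,
        "conv_time_spat": temporal + spatial,
        "bnorm": batch_norm,
        "conv_classifier": classifier,
        "total": temporal + spatial + batch_norm + classifier,
    }


summary = shallow_shapes(n_chans=32, n_times=1000, n_outputs=2)
print(summary)
```

For the 32-channel, 1000-sample model above, this gives a convolution output length of 976, a pooled length of 61 and 57,202 parameters in total, which agrees with the printed summary.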

Loading your own data with MNE#

In this tutorial, we demonstrate how to train the model on MNE data. MNE is quite a popular library for EEG data analysis as it provides methods to load data from many different file formats and a large collection of algorithms to preprocess it. However, Braindecode is not limited to MNE and can be used with numpy arrays or PyTorch tensors/datasets.

For this example, we generate random data containing 100 examples, each with 3 channels and 1024 time points. We also generate random labels for our data that simulate a 4-class classification problem.

import mne
import numpy as np

info = mne.create_info(ch_names=["C3", "C4", "Cz"], sfreq=256.0, ch_types="eeg")
X = np.random.randn(100, 3, 1024)  # 100 epochs, 3 channels, 4 seconds (@256Hz)
epochs = mne.EpochsArray(X, info=info)
y = np.random.randint(0, 4, size=100)  # 4 classes
print(epochs)
Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
<EpochsArray | 100 events (all good), 0 – 3.996 s (baseline off), ~2.4 MB, data loaded,
 '1': 100>
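
The printed time range ends at 3.996 s rather than 4 s because the epoch holds 1024 samples starting at time 0, so the last sample sits one sampling period before the 4-second mark. A small sketch of that arithmetic, assuming the epochs start at 0 s as in the output above:

```python
sfreq = 256.0   # sampling frequency in Hz, as passed to mne.create_info
n_times = 1024  # samples per epoch

# Total window length covered by the samples.
duration = n_times / sfreq
# Time stamp of the last sample when the first sample is at 0 s.
last_sample_time = (n_times - 1) / sfreq

print(f"{duration=} s, last sample at {last_sample_time:.3f} s")
```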

Training your model (scikit-learn compatible)#

Now that you know which model you want to use and how to instantiate it, and we have some fake data, it is time to train the model!

Note

Skorch is a library that allows you to wrap any PyTorch module into a scikit-learn-compatible classifier or regressor. Braindecode provides wrappers that inherit from the original Skorch ones and simply implement a few additional features that facilitate the use of Braindecode models.

The easiest way to train a Braindecode model is to use Braindecode’s Skorch wrappers, braindecode.EEGClassifier and braindecode.EEGRegressor. As our fake data defines a classification task, we will use the former.

The wrapper braindecode.EEGClassifier expects a model class as its first argument, but for convenience you can also pass the name of any Braindecode model as a string. The wrapper then finds and instantiates the model for you. If you want to pass parameters to your model, you can give them to the wrapper with the prefix module__.

from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

net = EEGClassifier(
    "ShallowFBCSPNet",
    module__final_conv_length="auto",
    train_split=ValidSplit(0.2),
    # Hold out 20% of the training data as a validation set.
)

In this example, we passed one additional parameter to the wrapper: module__final_conv_length that will be forwarded to the model (without the prefix module__).
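
To make the prefix convention concrete, here is a minimal sketch of how keyword arguments could be routed by such a prefix. It only illustrates the idea and is not Skorch's or Braindecode's actual implementation; `split_prefixed_kwargs` is a hypothetical helper, and the argument values are chosen for illustration.

```python
def split_prefixed_kwargs(kwargs, prefix="module__"):
    """Separate prefixed arguments (prefix stripped) from the rest (hypothetical helper)."""
    forwarded = {}
    remaining = {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            # Strip the prefix so the model receives its own parameter name.
            forwarded[key[len(prefix):]] = value
        else:
            remaining[key] = value
    return forwarded, remaining


module_kwargs, wrapper_kwargs = split_prefixed_kwargs(
    {
        "module__final_conv_length": "auto",
        "module__drop_prob": 0.5,  # illustrative value
        "max_epochs": 10,          # illustrative wrapper-level argument
    }
)
print(module_kwargs)
print(wrapper_kwargs)
```

Here the two prefixed entries end up as `final_conv_length` and `drop_prob`, the names listed in the model's docstring, while `max_epochs` stays with the wrapper.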

Note that the parameters n_chans, n_times and n_outputs were not specified, even though braindecode.models.ShallowFBCSPNet needs them to be initialized. This is because the wrapper automatically infers them, along with some other signal-related parameters, from the input data at training time.

Now that we have our model wrapped in a scikit-learn-compatible classifier, we can train it by simply calling the fit method:

net.fit(epochs, y)
  epoch    valid_acc    valid_loss     dur
-------  -----------  ------------  ------
      1       0.2500        2.3901  0.0184
      2       0.2500        2.3901  0.0093
      3       0.2500        2.3901  0.0088
      4       0.2500        2.3901  0.0085
      5       0.2500        2.3901  0.0085
      6       0.2500        2.3901  0.0084
      7       0.2500        2.3901  0.0084
      8       0.2500        2.3901  0.0084
      9       0.2500        2.3901  0.0085
     10       0.2500        2.3901  0.0083

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=============================================================================================================================================
  Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
  ============================================================================================================================================
  ShallowFBCSPNet (ShallowFBCSPNet)        [1, 3, 1024]              [1, 4]                    --                        --
  ├─SafeLog (pool_nonlin_exp): 1-1         [1, 3, 1024]              [1, 3, 1024]              --                        --
  ├─Ensure4d (ensuredims): 1-2             [1, 3, 1024]              [1, 3, 1024, 1]           --                        --
  ├─Rearrange (dimshuffle): 1-3            [1, 3, 1024, 1]           [1, 1, 1024, 3]           --                        --
  ├─CombinedConv (conv_time_spat): 1-4     [1, 1, 1024, 3]           [1, 40, 1000, 1]          5,840                     --
  ├─BatchNorm2d (bnorm): 1-5               [1, 40, 1000, 1]          [1, 40, 1000, 1]          80                        --
  ├─Expression (conv_nonlin_exp): 1-6      [1, 40, 1000, 1]          [1, 40, 1000, 1]          --                        --
  ├─AvgPool2d (pool): 1-7                  [1, 40, 1000, 1]          [1, 40, 62, 1]            --                        [75, 1]
  ├─SafeLog (pool_nonlin_exp): 1-8         [1, 40, 62, 1]            [1, 40, 62, 1]            --                        --
  ├─Dropout (drop): 1-9                    [1, 40, 62, 1]            [1, 40, 62, 1]            --                        --
  ├─Sequential (final_layer): 1-10         [1, 40, 62, 1]            [1, 4]                    --                        --
  │    └─Conv2d (conv_classifier): 2-1     [1, 40, 62, 1]            [1, 4, 1, 1]              9,924                     [62, 1]
  │    └─Expression (squeeze): 2-2         [1, 4, 1, 1]              [1, 4]                    --                        --
  ============================================================================================================================================
  Total params: 15,844
  Trainable params: 15,844
  Non-trainable params: 0
  Total mult-adds (M): 0.01
  ============================================================================================================================================
  Input size (MB): 0.01
  Forward/backward pass size (MB): 0.32
  Params size (MB): 0.04
  Estimated Total Size (MB): 0.37
  ============================================================================================================================================,
)

The trained model is accessible via the module_ attribute:

print(net.module_)
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 3, 1024]              [1, 4]                    --                        --
├─SafeLog (pool_nonlin_exp): 1-1         [1, 3, 1024]              [1, 3, 1024]              --                        --
├─Ensure4d (ensuredims): 1-2             [1, 3, 1024]              [1, 3, 1024, 1]           --                        --
├─Rearrange (dimshuffle): 1-3            [1, 3, 1024, 1]           [1, 1, 1024, 3]           --                        --
├─CombinedConv (conv_time_spat): 1-4     [1, 1, 1024, 3]           [1, 40, 1000, 1]          5,840                     --
├─BatchNorm2d (bnorm): 1-5               [1, 40, 1000, 1]          [1, 40, 1000, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-6      [1, 40, 1000, 1]          [1, 40, 1000, 1]          --                        --
├─AvgPool2d (pool): 1-7                  [1, 40, 1000, 1]          [1, 40, 62, 1]            --                        [75, 1]
├─SafeLog (pool_nonlin_exp): 1-8         [1, 40, 62, 1]            [1, 40, 62, 1]            --                        --
├─Dropout (drop): 1-9                    [1, 40, 62, 1]            [1, 40, 62, 1]            --                        --
├─Sequential (final_layer): 1-10         [1, 40, 62, 1]            [1, 4]                    --                        --
│    └─Conv2d (conv_classifier): 2-1     [1, 40, 62, 1]            [1, 4, 1, 1]              9,924                     [62, 1]
│    └─Expression (squeeze): 2-2         [1, 4, 1, 1]              [1, 4]                    --                        --
============================================================================================================================================
Total params: 15,844
Trainable params: 15,844
Non-trainable params: 0
Total mult-adds (M): 0.01
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.32
Params size (MB): 0.04
Estimated Total Size (MB): 0.37
============================================================================================================================================

And we can see that all the following parameters were automatically inferred from the training data:

print(
    f"{net.module_.n_chans=}\n{net.module_.n_times=}\n{net.module_.n_outputs=}"
    f"\n{net.module_.input_window_seconds=}\n{net.module_.sfreq=}\n{net.module_.chs_info=}"
)
net.module_.n_chans=3
net.module_.n_times=1024
net.module_.n_outputs=4
net.module_.input_window_seconds=4.0
net.module_.sfreq=256.0
net.module_.chs_info=[{'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C3', 'scanno': 1, 'logno': 1}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C4', 'scanno': 2, 'logno': 2}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'Cz', 'scanno': 3, 'logno': 3}]

Depending on the type of data used for training, some parameters might not be possible to infer. For example, if you pass a numpy array or a braindecode.datasets.WindowsDataset with target_from="metadata", then only n_chans, n_times and n_outputs will be inferred.

And if you pass other types of datasets, only n_chans and n_times will be inferred. In these cases, you will have to pass the missing parameters manually (with the prefix module__).
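
The sketch below illustrates why the set of inferable parameters depends on the input. It is a simplified illustration under stated assumptions, not Braindecode's actual inference logic: `infer_signal_params` is a hypothetical helper, and counting distinct labels to get `n_outputs` is an assumption that only makes sense for classification.

```python
def infer_signal_params(x_shape, y=None, sfreq=None):
    """Derive signal-related parameters from whatever is available (hypothetical helper)."""
    _, n_chans, n_times = x_shape  # (n_examples, n_chans, n_times)
    params = {"n_chans": n_chans, "n_times": n_times}
    if y is not None:
        # Assumption: one output per distinct class label.
        params["n_outputs"] = len(set(y))
    if sfreq is not None:
        # Only data that carries a sampling frequency allows these two.
        params["sfreq"] = sfreq
        params["input_window_seconds"] = n_times / sfreq
    return params


labels = [0, 1, 2, 3] * 25  # 100 labels covering 4 classes
from_epochs = infer_signal_params((100, 3, 1024), y=labels, sfreq=256.0)
from_array = infer_signal_params((100, 3, 1024), y=labels)
from_shape_only = infer_signal_params((100, 3, 1024))
print(from_epochs, from_array, from_shape_only, sep="\n")
```

The three calls mirror the three cases described above: data with sampling information, an array with labels, and data from which only the shape is known.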

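References#

[1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017). Deep learning with convolutional neural networks for EEG decoding and visualization. Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730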
