Simple training on MNE epochs#
The Braindecode library gives you access to a large number of neural network architectures that were developed for EEG data decoding. This tutorial shows how you can easily use any of these models to decode your own data. In particular, we assume that you have your data in an MNE format and want to train one of the Braindecode models on it.
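This example covers finding the model you want, loading your own data with MNE, and training the model with a scikit-learn-compatible wrapper.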
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
Finding the model you want#
Exploring the braindecode online documentation#
Let’s suppose you recently stumbled upon the Schirrmeister 2017 article [1]. In this article, the authors mention that their novel architecture ShallowConvNet is performing well on the BCI Competition IV 2a dataset and you would like to use it on your own data. Fortunately, the authors also mentioned they published their architecture on Braindecode!
To use this architecture, you first need to find its exact name in Braindecode. To do so, you can visit the Braindecode online documentation, which lists all the available models.
Models list: https://braindecode.org/stable/api.html#models
Alternatively, the API also provides a dictionary with all available models:
from braindecode.models.util import models_dict
print(f'All the Braindecode models:\n{list(models_dict.keys())}')
All the Braindecode models:
['ATCNet', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInception', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGNetv1', 'EEGNetv4', 'EEGResNet', 'HybridNet', 'ShallowFBCSPNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SleepStagerEldele2021', 'TCN', 'TIDNet', 'USleep']
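With a list this long, a quick substring search can narrow down the candidates. The sketch below is illustrative only: it copies the names printed above into a plain list (with Braindecode installed, `list(models_dict.keys())` would supply them instead) and uses a hypothetical helper, `find_models`, which is not part of Braindecode.

```python
# Names copied from the output above; with Braindecode installed they
# would come from list(models_dict.keys()) instead.
model_names = [
    'ATCNet', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet',
    'EEGInception', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGNetv1',
    'EEGNetv4', 'EEGResNet', 'HybridNet', 'ShallowFBCSPNet',
    'SleepStagerBlanco2020', 'SleepStagerChambon2018',
    'SleepStagerEldele2021', 'TCN', 'TIDNet', 'USleep',
]


def find_models(names, query):
    """Return the names containing `query`, ignoring case (hypothetical helper)."""
    query = query.lower()
    return [name for name in names if query in name.lower()]


matches = find_models(model_names, 'shallow')
print(matches)  # ['ShallowFBCSPNet']
```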
After your investigation, you found out that the model you are looking for is ShallowFBCSPNet. You can now import it from Braindecode:
from braindecode.models import ShallowFBCSPNet
Examining the model#
Now that you found your model, you must check which parameters it expects.
You can find this information either in the online documentation of braindecode.models.ShallowFBCSPNet or directly in the module’s docstring:
print(ShallowFBCSPNet.__doc__)
Shallow ConvNet model from Schirrmeister et al 2017.
Model described in [Schirrmeister2017]_.
Parameters
----------
n_chans : int
Number of EEG channels.
n_outputs : int
Number of outputs of the model. This is the number of classes
in the case of classification.
n_times : int
Number of time samples of the input window.
n_filters_time: int
Number of temporal filters.
filter_time_length: int
Length of the temporal filter.
n_filters_spat: int
Number of spatial filters.
pool_time_length: int
Length of temporal pooling filter.
pool_time_stride: int
Length of stride between temporal pooling filters.
final_conv_length: int | str
Length of the final convolution layer.
If set to "auto", length of the input signal must be specified.
conv_nonlin: callable
Non-linear function to be used after convolution layers.
pool_mode: str
Method to use on pooling layers. "max" or "mean".
pool_nonlin: callable
Non-linear function to be used after pooling layers.
split_first_layer: bool
Split first layer into temporal and spatial layers (True) or just use temporal (False).
There would be no non-linearity between the split layers.
batch_norm: bool
Whether to use batch normalisation.
batch_norm_alpha: float
Momentum for BatchNorm2d.
drop_prob: float
Dropout probability.
chs_info : list of dict
Information about each individual EEG channel. This should be filled with
``info["chs"]``. Refer to :class:`mne.Info` for more details.
input_window_seconds : float
Length of the input window in seconds.
sfreq : float
Sampling frequency of the EEG recordings.
in_chans : int
Alias for `n_chans`.
n_classes: int
Alias for `n_outputs`.
input_window_samples: int | None
Alias for `n_times`.
add_log_softmax: bool
Whether to use log-softmax non-linearity as the output function.
LogSoftmax final layer will be removed in the future.
Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
Check the documentation of the torch.nn loss functions:
https://pytorch.org/docs/stable/nn.html#loss-functions.
Raises
------
ValueError: If some input signal-related parameters are not specified
and can not be inferred.
FutureWarning: If add_log_softmax is True, since LogSoftmax final layer
will be removed in the future.
Notes
-----
If some input signal-related parameters are not specified,
there will be an attempt to infer them from the other parameters.
References
----------
.. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
& Ball, T. (2017).
Deep learning with convolutional neural networks for EEG decoding and
visualization.
Human Brain Mapping , Aug. 2017.
Online: http://dx.doi.org/10.1002/hbm.23730
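If you prefer a programmatic overview of the parameters and their defaults, the standard library's `inspect.signature` can list them. The sketch below is a minimal illustration: `init_parameters` is a hypothetical helper, and `StandInModel` is a made-up stand-in (its parameters and defaults are invented) so that the example runs without Braindecode installed. With Braindecode available, the same helper could presumably be applied to `ShallowFBCSPNet`.

```python
import inspect


def init_parameters(cls):
    """Map each __init__ parameter of `cls` (except self) to its default.

    Parameters without a default map to inspect.Parameter.empty.
    """
    signature = inspect.signature(cls.__init__)
    return {
        name: param.default
        for name, param in signature.parameters.items()
        if name != 'self'
    }


class StandInModel:
    """Hypothetical stand-in; its parameters and defaults are made up."""

    def __init__(self, n_chans, n_outputs, n_times=None, drop_prob=0.5):
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.n_times = n_times
        self.drop_prob = drop_prob


params = init_parameters(StandInModel)
# Parameters that must be given explicitly because they have no default.
required = [name for name, default in params.items()
            if default is inspect.Parameter.empty]
print(required)  # ['n_chans', 'n_outputs']
```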
Additionally, you might be interested in visualizing the model’s architecture. This can be done by initializing the model and printing it, which calls its __str__() method.
To initialize it, we need to specify some parameters, which we set to arbitrary values for now:
model = ShallowFBCSPNet(
    n_chans=32,
    n_times=1000,
    n_outputs=2,
    final_conv_length='auto',
)
print(model)
/home/runner/work/braindecode/braindecode/braindecode/models/base.py:180: UserWarning: LogSoftmax final layer will be removed! Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
warnings.warn("LogSoftmax final layer will be removed! " +
============================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet) [1, 32, 1000] [1, 2] -- --
├─Ensure4d (ensuredims): 1-1 [1, 32, 1000] [1, 32, 1000, 1] -- --
├─Rearrange (dimshuffle): 1-2 [1, 32, 1000, 1] [1, 1, 1000, 32] -- --
├─CombinedConv (conv_time_spat): 1-3 [1, 1, 1000, 32] [1, 40, 976, 1] 52,240 --
├─BatchNorm2d (bnorm): 1-4 [1, 40, 976, 1] [1, 40, 976, 1] 80 --
├─Expression (conv_nonlin_exp): 1-5 [1, 40, 976, 1] [1, 40, 976, 1] -- --
├─AvgPool2d (pool): 1-6 [1, 40, 976, 1] [1, 40, 61, 1] -- [75, 1]
├─Expression (pool_nonlin_exp): 1-7 [1, 40, 61, 1] [1, 40, 61, 1] -- --
├─Dropout (drop): 1-8 [1, 40, 61, 1] [1, 40, 61, 1] -- --
├─Sequential (final_layer): 1-9 [1, 40, 61, 1] [1, 2] -- --
│ └─Conv2d (conv_classifier): 2-1 [1, 40, 61, 1] [1, 2, 1, 1] 4,882 [61, 1]
│ └─LogSoftmax (logsoftmax): 2-2 [1, 2, 1, 1] [1, 2, 1, 1] -- --
│ └─Expression (squeeze): 2-3 [1, 2, 1, 1] [1, 2] -- --
============================================================================================================================================
Total params: 57,202
Trainable params: 57,202
Non-trainable params: 0
Total mult-adds (M): 0.00
============================================================================================================================================
Input size (MB): 0.13
Forward/backward pass size (MB): 0.31
Params size (MB): 0.02
Estimated Total Size (MB): 0.46
============================================================================================================================================
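The shapes and parameter counts in this summary can be recomputed by hand, which is a useful sanity check when you change the input size. The sketch below is not Braindecode code: `shallow_convnet_summary` is a hypothetical helper, and its default hyperparameters (40 filters, temporal kernel of 25 samples, pooling window of 75 with stride 15) as well as the bias layout are assumptions chosen because they reproduce the numbers printed above.

```python
def shallow_convnet_summary(n_chans, n_times, n_outputs,
                            n_filters_time=40, filter_time_length=25,
                            n_filters_spat=40, pool_time_length=75,
                            pool_time_stride=15):
    """Recompute time lengths and parameter counts of the printed summary.

    The defaults and the bias layout are assumptions that match the
    summary table; they are not read from the library.
    """
    # Temporal convolution without padding shortens the time axis.
    conv_len = n_times - filter_time_length + 1
    # Pooling with a window and a stride, no padding (floor division).
    pool_len = (conv_len - pool_time_length) // pool_time_stride + 1
    # Temporal filters (weights + bias) plus spatial filters (assumed bias-free).
    conv_params = (n_filters_time * filter_time_length + n_filters_time
                   + n_filters_spat * n_filters_time * n_chans)
    # Batch norm: one scale and one shift per spatial filter.
    bnorm_params = 2 * n_filters_spat
    # With final_conv_length='auto', the classifier spans the pooled time axis.
    classifier_params = n_filters_spat * pool_len * n_outputs + n_outputs
    return {
        'conv_len': conv_len,
        'pool_len': pool_len,
        'conv_params': conv_params,
        'bnorm_params': bnorm_params,
        'classifier_params': classifier_params,
        'total_params': conv_params + bnorm_params + classifier_params,
    }


summary = shallow_convnet_summary(n_chans=32, n_times=1000, n_outputs=2)
print(summary['conv_len'], summary['pool_len'])  # 976 61
print(summary['total_params'])                   # 57202
```

Under the same assumptions, calling the helper with `n_chans=3, n_times=1024, n_outputs=4` gives a pooled length of 62 and 15,844 parameters in total.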
Loading your own data with MNE#
In this tutorial, we demonstrate how to train the model on MNE data. MNE is quite a popular library for EEG data analysis as it provides methods to load data from many different file formats and a large collection of algorithms to preprocess it. However, Braindecode is not limited to MNE and can be used with numpy arrays or PyTorch tensors/datasets.
For this example, we generate some random data containing 100 examples, each with 3 channels and 1024 time points. We also generate random labels for our data to simulate a 4-class classification problem.
import mne
import numpy as np
info = mne.create_info(ch_names=['C3', 'C4', 'Cz'], sfreq=256., ch_types='eeg')
X = np.random.randn(100, 3, 1024) # 100 epochs, 3 channels, 4 seconds (@256Hz)
epochs = mne.EpochsArray(X, info=info)
y = np.random.randint(0, 4, size=100) # 4 classes
print(epochs)
Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
<EpochsArray | 100 events (all good), 0 – 3.99609 s, baseline off, ~2.4 MB, data loaded,
'1': 100>
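The numbers in this printout follow directly from the array shape and the sampling frequency. As a small worked check (plain arithmetic, assuming the epochs start at 0 s and the data are stored as 64-bit floats):

```python
n_epochs, n_chans, n_times = 100, 3, 1024  # shape of X above
sfreq = 256.0                              # sampling frequency in Hz

# Window length in seconds: number of samples divided by the sampling rate.
window_seconds = n_times / sfreq
# Time stamp of the last sample when the first sample is at 0 s.
last_sample_time = (n_times - 1) / sfreq
# Size of the raw data array, assuming 8 bytes per value (float64).
n_bytes = n_epochs * n_chans * n_times * 8

print(window_seconds)      # 4.0
print(last_sample_time)    # 3.99609375
print(n_bytes / 2 ** 20)   # 2.34375
```

The last sample time matches the `0 – 3.99609 s` range shown above, and the array size of roughly 2.3 MiB is close to the reported `~2.4 MB`.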
Training your model (scikit-learn compatible)#
Now that you know which model you want to use and how to instantiate it, and we have some fake data, it is time to train the model!
Note
Skorch is a library that allows you to wrap any PyTorch module into a scikit-learn-compatible classifier or regressor. Braindecode provides wrappers that inherit from the original Skorch ones and simply implement a few additional features that facilitate the use of Braindecode models.
To train a Braindecode model, the easiest way is to use Braindecode’s Skorch wrappers, braindecode.EEGClassifier and braindecode.EEGRegressor. As our fake data poses a classification task, we will use the former.
The wrapper braindecode.EEGClassifier expects a model class as its first argument but, for convenience, you can also simply pass the name of any Braindecode model as a string. The wrapper automatically finds and instantiates the model for you.
If you want to pass parameters to your model, you can give them to the wrapper with the prefix module__.
from skorch.dataset import ValidSplit
from braindecode import EEGClassifier
net = EEGClassifier(
    'ShallowFBCSPNet',
    module__final_conv_length='auto',
    # Hold out 20% of the training data as a validation set.
    train_split=ValidSplit(0.2),
)
In this example, we passed one additional parameter to the wrapper: module__final_conv_length, which will be forwarded to the model (without the prefix module__).
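To make the prefix convention concrete, here is a rough illustration of the idea. It is a simplified sketch, not Skorch's actual implementation: `split_module_kwargs` is a hypothetical helper that separates the keyword arguments meant for the model from those meant for the wrapper itself, and `max_epochs` is used only as an example of a wrapper-level argument.

```python
def split_module_kwargs(kwargs, prefix='module__'):
    """Split kwargs into (model kwargs, wrapper kwargs) based on the prefix.

    Simplified illustration of the naming convention; hypothetical helper.
    """
    module_kwargs, wrapper_kwargs = {}, {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            # Strip the prefix before handing the argument to the model.
            module_kwargs[key[len(prefix):]] = value
        else:
            wrapper_kwargs[key] = value
    return module_kwargs, wrapper_kwargs


module_kwargs, wrapper_kwargs = split_module_kwargs({
    'module__final_conv_length': 'auto',
    'max_epochs': 10,
})
print(module_kwargs)   # {'final_conv_length': 'auto'}
print(wrapper_kwargs)  # {'max_epochs': 10}
```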
We also note that the parameters n_chans, n_times and n_outputs were not specified, even though braindecode.ShallowFBCSPNet needs them to be initialized. This is because the wrapper will automatically infer them, along with some other signal-related parameters, from the input data at training time.
Now that we have our model wrapped in a scikit-learn-compatible classifier, we can train it by simply calling the fit method on the epochs and labels created earlier, i.e. net.fit(epochs, y):
/home/runner/work/braindecode/braindecode/braindecode/models/base.py:180: UserWarning: LogSoftmax final layer will be removed! Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
warnings.warn("LogSoftmax final layer will be removed! " +
epoch valid_acc valid_loss dur
------- ----------- ------------ ------
1 0.4000 18.7568 0.0190
2 0.4000 18.7568 0.0111
3 0.4000 18.7568 0.0109
4 0.4000 18.7568 0.0097
5 0.4000 18.7568 0.0068
6 0.4000 18.7568 0.0069
7 0.4000 18.7568 0.0066
8 0.4000 18.7568 0.0066
9 0.4000 18.7568 0.0065
10 0.4000 18.7568 0.0065
<class 'braindecode.classifier.EEGClassifier'>[initialized](
module_=============================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet) [1, 3, 1024] [1, 4] -- --
├─Ensure4d (ensuredims): 1-1 [1, 3, 1024] [1, 3, 1024, 1] -- --
├─Rearrange (dimshuffle): 1-2 [1, 3, 1024, 1] [1, 1, 1024, 3] -- --
├─CombinedConv (conv_time_spat): 1-3 [1, 1, 1024, 3] [1, 40, 1000, 1] 5,840 --
├─BatchNorm2d (bnorm): 1-4 [1, 40, 1000, 1] [1, 40, 1000, 1] 80 --
├─Expression (conv_nonlin_exp): 1-5 [1, 40, 1000, 1] [1, 40, 1000, 1] -- --
├─AvgPool2d (pool): 1-6 [1, 40, 1000, 1] [1, 40, 62, 1] -- [75, 1]
├─Expression (pool_nonlin_exp): 1-7 [1, 40, 62, 1] [1, 40, 62, 1] -- --
├─Dropout (drop): 1-8 [1, 40, 62, 1] [1, 40, 62, 1] -- --
├─Sequential (final_layer): 1-9 [1, 40, 62, 1] [1, 4] -- --
│ └─Conv2d (conv_classifier): 2-1 [1, 40, 62, 1] [1, 4, 1, 1] 9,924 [62, 1]
│ └─LogSoftmax (logsoftmax): 2-2 [1, 4, 1, 1] [1, 4, 1, 1] -- --
│ └─Expression (squeeze): 2-3 [1, 4, 1, 1] [1, 4] -- --
============================================================================================================================================
Total params: 15,844
Trainable params: 15,844
Non-trainable params: 0
Total mult-adds (M): 0.01
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.32
Params size (MB): 0.04
Estimated Total Size (MB): 0.37
============================================================================================================================================,
)
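When reading the training log above, it helps to keep the size of the validation set in mind. The following back-of-the-envelope sketch assumes that the split holds out exactly 20% of the 100 examples; the values are taken from the log and from the data generation step.

```python
n_examples = 100       # epochs in the fake dataset
valid_fraction = 0.2   # fraction passed to ValidSplit
n_classes = 4          # number of random label classes
valid_acc = 0.4        # validation accuracy reported in the log

# Number of held-out examples under the assumption stated above.
n_valid = round(n_examples * valid_fraction)
# Number of correctly classified validation examples implied by the accuracy.
n_correct = round(valid_acc * n_valid)
# Accuracy expected from guessing uniformly among the classes.
chance_level = 1 / n_classes
# Smallest possible change in accuracy with this many validation examples.
accuracy_step = 1 / n_valid

print(n_valid, n_correct)           # 20 8
print(chance_level, accuracy_step)  # 0.25 0.05
```

Since the labels are random, no real learning is expected here, and with only 20 validation examples the accuracy estimate is coarse.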
The trained model is accessible via the module_ attribute:
print(net.module_)
============================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet) [1, 3, 1024] [1, 4] -- --
├─Ensure4d (ensuredims): 1-1 [1, 3, 1024] [1, 3, 1024, 1] -- --
├─Rearrange (dimshuffle): 1-2 [1, 3, 1024, 1] [1, 1, 1024, 3] -- --
├─CombinedConv (conv_time_spat): 1-3 [1, 1, 1024, 3] [1, 40, 1000, 1] 5,840 --
├─BatchNorm2d (bnorm): 1-4 [1, 40, 1000, 1] [1, 40, 1000, 1] 80 --
├─Expression (conv_nonlin_exp): 1-5 [1, 40, 1000, 1] [1, 40, 1000, 1] -- --
├─AvgPool2d (pool): 1-6 [1, 40, 1000, 1] [1, 40, 62, 1] -- [75, 1]
├─Expression (pool_nonlin_exp): 1-7 [1, 40, 62, 1] [1, 40, 62, 1] -- --
├─Dropout (drop): 1-8 [1, 40, 62, 1] [1, 40, 62, 1] -- --
├─Sequential (final_layer): 1-9 [1, 40, 62, 1] [1, 4] -- --
│ └─Conv2d (conv_classifier): 2-1 [1, 40, 62, 1] [1, 4, 1, 1] 9,924 [62, 1]
│ └─LogSoftmax (logsoftmax): 2-2 [1, 4, 1, 1] [1, 4, 1, 1] -- --
│ └─Expression (squeeze): 2-3 [1, 4, 1, 1] [1, 4] -- --
============================================================================================================================================
Total params: 15,844
Trainable params: 15,844
Non-trainable params: 0
Total mult-adds (M): 0.01
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.32
Params size (MB): 0.04
Estimated Total Size (MB): 0.37
============================================================================================================================================
And we can see that all the following parameters were automatically inferred from the training data:
print(f'{net.module_.n_chans=}\n{net.module_.n_times=}\n{net.module_.n_outputs=}'
f'\n{net.module_.input_window_seconds=}\n{net.module_.sfreq=}\n{net.module_.chs_info=}')
net.module_.n_chans=3
net.module_.n_times=1024
net.module_.n_outputs=4
net.module_.input_window_seconds=4.0
net.module_.sfreq=256.0
net.module_.chs_info=[{'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C3', 'scanno': 1, 'logno': 1}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'C4', 'scanno': 2, 'logno': 2}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 2 (FIFFV_EEG_CH), 'coil_type': 1 (FIFFV_COIL_EEG), 'unit': 107 (FIFF_UNIT_V), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'Cz', 'scanno': 3, 'logno': 3}]
Depending on the type of data used for training, some parameters might not be possible to infer. For example, if you pass a numpy array or a braindecode.datasets.WindowsDataset with target_from="metadata", then only n_chans, n_times and n_outputs will be inferred.
And if you pass other types of datasets, only n_chans and n_times will be inferred.
In these cases, you will have to pass the missing parameters manually (with the prefix module__).
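As an illustration of the last case, suppose n_outputs could not be inferred and has to be supplied by hand. The sketch below derives it from a toy label list and builds the prefixed keyword arguments; `as_module_kwargs` is a hypothetical helper, not part of Braindecode, and the labels are made up.

```python
def as_module_kwargs(**params):
    """Prefix each model parameter with 'module__' (hypothetical helper)."""
    return {f'module__{name}': value for name, value in params.items()}


labels = [0, 3, 1, 2, 1, 0, 3, 2]  # toy labels for a 4-class problem
n_outputs = len(set(labels))       # number of distinct classes

extra_kwargs = as_module_kwargs(n_outputs=n_outputs,
                                final_conv_length='auto')
print(extra_kwargs)
# {'module__n_outputs': 4, 'module__final_conv_length': 'auto'}

# With Braindecode installed, these could then be unpacked into the wrapper:
# net = EEGClassifier('ShallowFBCSPNet', **extra_kwargs)
```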
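References#
[1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017). Deep learning with convolutional neural networks for EEG decoding and visualization. Human Brain Mapping, Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730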