Hyperparameter tuning with scikit-learn#

Braindecode provides some compatibility with scikit-learn. This allows us to use scikit-learn functionality to find the best hyperparameters for our model, which is especially useful when tuning hyperparameters for a specific decoding task or dataset.

In this tutorial, we will use the standard decoding approach to show the impact of the learning rate and dropout probability on the model’s performance.

Loading and preprocessing the dataset#

Loading#

First, we load the data. In this tutorial, we use braindecode's functionality to fetch the BCI Competition IV 2a data [3] via MOABB [2].

Note

To load your own datasets either via mne or from preprocessed X/y numpy arrays, see MNE Dataset Tutorial and Numpy Dataset Tutorial.
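As a minimal sketch of the numpy route (assuming trials shaped (n_trials, n_channels, n_times) with one integer label each; check the exact signature of create_from_X_y in your braindecode version):

import numpy as np
from braindecode.datasets import create_from_X_y

X = np.random.randn(10, 22, 1125).astype("float32")  # fake EEG trials
y = np.random.randint(0, 4, size=10)  # fake labels
# One window per trial when no window size/stride is given
windows_dataset = create_from_X_y(X, y, drop_last_window=False, sfreq=250.0)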

from braindecode.datasets.moabb import MOABBDataset

subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
BNCI2014001 has been renamed to BNCI2014_001. BNCI2014001 will be removed in version 1.1.
The dataset class name 'BNCI2014001' must be an abbreviation of its code 'BNCI2014-001'. See moabb.datasets.base.is_abbrev for more information.
48 events found on stim channel stim
Event IDs: [1 2 3 4]
(output repeated for each of the 12 runs)

Preprocessing#

In this example, preprocessing includes signal rescaling, bandpass filtering (low and high cut-off frequencies of 4 and 38 Hz) and standardization using an exponential moving mean and variance. You can apply functions provided by mne.Raw or mne.Epochs, or apply your own functions, either to the MNE object or to the underlying numpy array.

Note

These preprocessing steps are applied directly to the loaded data, not on the fly as transformations, as is common in PyTorch libraries like torchvision.

from braindecode.preprocessing.preprocess import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)
from numpy import multiply

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),
    # Keep EEG sensors
    Preprocessor(lambda data: multiply(data, factor)),  # Convert from V to uV
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,
        # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Preprocess the data
preprocess(dataset, preprocessors, n_jobs=-1)
/home/runner/work/braindecode/braindecode/braindecode/preprocessing/preprocess.py:69: UserWarning: Preprocessing choices with lambda functions cannot be saved.
  warn("Preprocessing choices with lambda functions cannot be saved.")

<braindecode.datasets.moabb.MOABBDataset object at 0x7fe7ea370d90>

Extraction of the Compute Windows#

Extraction of the trials (windows) from the time series is based on the events inside the dataset. One event marks the stimulus onset, i.e. the beginning of a trial. In this example, we want to analyse the 0.5 s preceding each event as well as the duration of the event itself. Therefore, we set trial_start_offset_seconds to -0.5 s and trial_stop_offset_seconds to 0 s.

We also extract the sampling frequency from the dataset and verify that it is the same for all recordings.

Note

trial_start_offset_seconds and trial_stop_offset_seconds are defined in seconds and need to be converted into samples (by multiplication with the sampling frequency) relative to the event. These values are dataset dependent.

from braindecode.preprocessing.windowers import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
(output repeated for each of the 12 runs)

Split dataset into train and valid#

We can easily split the dataset using additional info stored in the description attribute, in this case the session column. We select the 0train session for training and the 1test session for evaluation.

splitted = windows_dataset.split("session")
train_set = splitted["0train"]  # Session train
eval_set = splitted["1test"]  # Session evaluation

Create model#

Now we create the deep learning model! Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the ShallowFBCSPNet model from Deep learning with convolutional neural networks for EEG decoding and visualization [4]. These models are pure PyTorch deep learning models, so any normal PyTorch nn.Module can be used in their place.
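For illustration, a minimal custom model might look like the sketch below (a hypothetical example, not one of braindecode's architectures); any nn.Module that maps a (batch, n_chans, n_times) input to (batch, n_classes) outputs fits this interface.

import torch
from torch import nn

class TinyEEGNet(nn.Module):
    # Hypothetical minimal model: temporal convolution, global average
    # pooling, then a linear classifier.
    def __init__(self, n_chans, n_times, n_classes):
        super().__init__()
        self.conv = nn.Conv1d(n_chans, 8, kernel_size=25)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(8, n_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))   # (batch, 8, n_times - 24)
        x = self.pool(x).squeeze(-1)   # (batch, 8)
        return self.fc(x)              # (batch, n_classes)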

from functools import partial
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

# check if GPU is available, if True chooses to use it
cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200220  # random seed to make results reproducible
# Set random seed to be able to reproduce results
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
n_times = train_set[0][0].shape[1]

# To analyze the impact of the different hyperparameters of the torch model,
# we bind the fixed arguments with functools.partial. This is necessary
# because GridSearchCV re-initializes the model for every candidate
# parameter set; without the fixed arguments (number of channels, outputs,
# time steps) bound beforehand, that re-initialization would fail.
model = partial(
    ShallowFBCSPNet,
    n_chans=n_chans,
    n_outputs=n_classes,
    n_times=n_times,
    final_conv_length="auto",
)

# Device placement is handled by EEGClassifier below via its `device`
# argument; `model` is still a functools.partial here and has no .cuda()
# method of its own.

Training#

Now we train the network! EEGClassifier is a Braindecode object responsible for managing the training of neural networks. It inherits from skorch.NeuralNetClassifier, so the training logic is the same as in skorch.

from skorch.callbacks import LRScheduler
from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

batch_size = 16
n_epochs = 2

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=[],  # placeholder value; the actual values are set by GridSearchCV below
    batch_size=batch_size,
    train_split=ValidSplit(0.2, random_state=seed),
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
    ],
    device=device,
)

We use scikit-learn's GridSearchCV to tune the hyperparameters. To be able to do this, we slice the braindecode datasets, which by default return a 3-tuple of (X, y, window indices), so that X and y are exposed separately.

Note

The KFold object splits the datasets based on their length, which corresponds to the number of compute windows. In this (trialwise) example this is fine to do. In a cropped setting this is not advisable, since it might split the compute windows of a single trial into both the train and the valid set; a trial-aware alternative is sketched below.
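For the cropped setting, one option is to group compute windows by trial before splitting, e.g. with scikit-learn's GroupKFold. The sketch below assumes a hypothetical window_trial_ids array mapping each window to its trial; it is not part of this example's pipeline.

import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical: 48 trials with 10 compute windows each
window_trial_ids = np.repeat(np.arange(48), 10)
group_cv = GroupKFold(n_splits=2)
for train_idx, valid_idx in group_cv.split(
    np.zeros((len(window_trial_ids), 1)), groups=window_trial_ids
):
    # No trial contributes windows to both folds
    assert not set(window_trial_ids[train_idx]) & set(window_trial_ids[valid_idx])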

from sklearn.model_selection import GridSearchCV, KFold
from skorch.helper import SliceDataset
from numpy import array
import pandas as pd

train_X = SliceDataset(train_set, idx=0)
train_y = array([y for y in SliceDataset(train_set, idx=1)])
cv = KFold(n_splits=2, shuffle=True, random_state=42)

learning_rates = [0.00625, 0.0000625]
drop_probs = [0.2, 0.5, 0.8]

fit_params = {"epochs": n_epochs}
param_grid = {"optimizer__lr": learning_rates, "module__drop_prob": drop_probs}

# Setting n_jobs=-1 would run the grid search on all available processors,
# at the cost of the training output no longer being printed sequentially;
# here we keep n_jobs=1 so the logs stay readable.
search = GridSearchCV(
    estimator=clf,
    param_grid=param_grid,
    cv=cv,
    return_train_score=True,
    scoring="accuracy",
    refit=True,
    verbose=1,
    error_score="raise",
    n_jobs=1,
)

search.fit(train_X, train_y, **fit_params)

# Extract the results into a DataFrame
search_results = pd.DataFrame(search.cv_results_)
Fitting 2 folds for each of 6 candidates, totalling 12 fits
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2609        1.4822       0.3448            0.3448       12.8587  0.0063  0.5860
      2            0.2609        0.7629       0.3448            0.3448        3.4101  0.0000  0.5789
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        1.4107       0.3448            0.3448       12.3777  0.0063  0.8335
      2            0.2783        0.7427       0.3448            0.3448        4.0815  0.0000  0.5711
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2783        1.4500       0.0690            0.0690        2.0014  0.0001  0.5755
      2            0.3043        1.3866       0.1379            0.1379        1.5656  0.0000  0.5461
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        1.4712       0.3448            0.3448        1.7099  0.0001  0.5457
      2            0.2957        1.3762       0.2759            0.2759        1.4668  0.0000  0.5505
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2609        1.5884       0.3448            0.3448        2.5505  0.0063  0.5476
      2            0.4609        0.8677       0.3448            0.3448        1.6186  0.0000  0.5691
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        1.6299       0.3448            0.3448        5.0177  0.0063  0.5757
      2            0.3739        0.9309       0.4138            0.4138        1.9423  0.0000  0.5747
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2957        1.4859       0.2759            0.2759        1.4425  0.0001  0.5454
      2            0.3304        1.4221       0.2414            0.2414        1.4645  0.0000  0.5453
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2609        1.5369       0.3448            0.3448        1.3557  0.0001  0.5474
      2            0.3304        1.4058       0.4828            0.4828        1.3220  0.0000  0.5993
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        1.7556       0.2759            0.2759        7.6143  0.0063  0.5768
      2            0.2348        1.4904       0.2759            0.2759        2.4326  0.0000  0.5562
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.3739        1.6044       0.3793            0.3793        2.9512  0.0063  0.5533
      2            0.3913        1.3977       0.4483            0.4483        1.5505  0.0000  0.5543
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2696        1.5889       0.2414            0.2414        2.0679  0.0001  0.5485
      2            0.2435        1.5474       0.2069            0.2069        1.5949  0.0000  0.5730
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2870        1.6018       0.1379            0.1379        1.9570  0.0001  0.5749
      2            0.2696        1.4638       0.2069            0.2069        1.5792  0.0000  0.5499
Can only infer signal shape of array-like and Datasets, got <class 'skorch.helper.SliceDataset'>.
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2783        1.4333       0.2586            0.2586        1.4455  0.0001  1.1230
      2            0.3174        1.3830       0.2586            0.2586        1.4360  0.0000  1.1202

Plotting the results#

import matplotlib.pyplot as plt
import seaborn as sns

# Create a pivot table for the heatmap
pivot_table = search_results.pivot(
    index="param_optimizer__lr",
    columns="param_module__drop_prob",
    values="mean_test_score",
)
# Create the heatmap
fig, ax = plt.subplots()
sns.heatmap(pivot_table, annot=True, fmt=".3f", cmap="YlGnBu", cbar=True)
plt.title("Grid Search Mean Test Scores")
plt.ylabel("Learning Rate")
plt.xlabel("Dropout Probability")
plt.tight_layout()
plt.show()
[Figure: heatmap of the grid search mean test scores over learning rate and dropout probability]

Get the best hyperparameters#

best_run = search_results[search_results["rank_test_score"] == 1].squeeze()
print(
    f"Best hyperparameters were {best_run['params']} which gave a validation "
    f"accuracy of {best_run['mean_test_score'] * 100:.2f}% (training "
    f"accuracy of {best_run['mean_train_score'] * 100:.2f}%)."
)

eval_X = SliceDataset(eval_set, idx=0)
eval_y = SliceDataset(eval_set, idx=1)
score = search.score(eval_X, eval_y)
print(f"Eval accuracy is {score * 100:.2f}%.")
Best hyperparameters were {'module__drop_prob': 0.2, 'optimizer__lr': 6.25e-05} which gave a validation accuracy of 30.90% (training accuracy of 28.12%).
Eval accuracy is 28.47%.
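Beyond the overall accuracy, the refitted best estimator can be inspected further, for instance with a per-class breakdown via scikit-learn's confusion_matrix (a sketch reusing the eval_X and eval_y slices from above):

import numpy as np
from sklearn.metrics import confusion_matrix

# Rows are true classes, columns are predicted classes
eval_preds = search.best_estimator_.predict(eval_X)
print(confusion_matrix(np.array(list(eval_y)), eval_preds))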

References#

[2] Jayaram, V., & Barachant, A. (2018). MOABB: trustworthy algorithm benchmarking for BCIs. Journal of Neural Engineering, 15(6), 066011.

[3] Tangermann, M., et al. (2012). Review of the BCI Competition IV. Frontiers in Neuroscience, 6, 55.

[4] Schirrmeister, R. T., et al. (2017). Deep learning with convolutional neural networks for EEG decoding and visualization. Human Brain Mapping, 38(11), 5391-5420.
