Hyperparameter tuning with scikit-learn#

Braindecode provides some compatibility with scikit-learn, which allows us to use scikit-learn functionality to find the best hyperparameters for our model. This is especially useful when tuning hyperparameters for one decoding task or a specific dataset.

In this tutorial, we will use the standard decoding approach to show the impact of the learning rate and dropout probability on the model’s performance.

Loading and preprocessing the dataset#

Loading#

First, we load the data. In this tutorial, we use braindecode's MOABB [2] loader to fetch the BCI Competition IV 2a data [3].

Note

To load your own datasets either via mne or from preprocessed X/y numpy arrays, see MNE Dataset Tutorial and Numpy Dataset Tutorial.

from braindecode.datasets.moabb import MOABBDataset

subject_id = 3
# "BNCI2014001" has been renamed to "BNCI2014_001"; the old name will be
# removed in MOABB version 1.1.
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

Preprocessing#

In this example, preprocessing consists of rescaling the signal from V to µV, bandpass filtering (low and high cut-off frequencies of 4 and 38 Hz), and standardization using an exponential moving mean and variance. You can apply functions provided by mne.Raw or mne.Epochs, or your own functions, either to the MNE object or to the underlying numpy array.

Note

These preprocessing steps are applied directly to the loaded data, not on the fly as transformations, as is done in PyTorch libraries such as torchvision.

from braindecode.preprocessing.preprocess import (
    exponential_moving_standardize, preprocess, Preprocessor)
from numpy import multiply

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
    # Keep EEG sensors
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    # Convert from V to uV
    Preprocessor(lambda data: multiply(data, factor)),
    # Bandpass filter
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Exponential moving standardization
    Preprocessor(exponential_moving_standardize,
                 factor_new=factor_new, init_block_size=init_block_size)
]

# Preprocess the data
preprocess(dataset, preprocessors, n_jobs=-1)
/home/runner/work/braindecode/braindecode/braindecode/preprocessing/preprocess.py:55: UserWarning: Preprocessing choices with lambda functions cannot be saved.
  warn('Preprocessing choices with lambda functions cannot be saved.')

<braindecode.datasets.moabb.MOABBDataset object at 0x7f4542305600>
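To make the last preprocessing step more concrete, here is a simplified, pure-Python sketch of exponential moving standardization on a single channel. It assumes the usual recursion (running mean and variance updated with weight factor_new, then division by the running standard deviation) and leaves out the init_block_size handling, so it is an illustration rather than braindecode's implementation. The function name, the eps floor, and the toy signal are hypothetical.

```python
import math


def exponential_moving_standardize_1d(signal, factor_new=1e-3, eps=1e-4):
    """Standardize one channel with exponentially weighted running statistics.

    Simplified sketch: the special treatment of the first samples
    (init_block_size) is omitted.
    """
    mean = signal[0]
    var = 0.0
    out = []
    for x in signal:
        # Update the running mean, then the running variance around it.
        mean = factor_new * x + (1 - factor_new) * mean
        var = factor_new * (x - mean) ** 2 + (1 - factor_new) * var
        # eps guards against division by a vanishing standard deviation.
        out.append((x - mean) / max(math.sqrt(var), eps))
    return out


# Toy signal (hypothetical data): a slow drift plus an oscillation.
signal = [0.01 * t + math.sin(0.3 * t) for t in range(2000)]
standardized = exponential_moving_standardize_1d(signal, factor_new=0.05)

print(f"raw range:          {min(signal):.2f} .. {max(signal):.2f}")
print(f"standardized range: {min(standardized):.2f} .. {max(standardized):.2f}")
```

A small factor_new, such as the 1e-3 used in this tutorial, makes the running statistics adapt slowly, so only slow drifts are removed; the larger value in the sketch is chosen only so that the effect is visible on a short toy signal.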

Extraction of the Compute Windows#

Extraction of the Windows#

Extraction of the trials (windows) from the time series is based on the events inside the dataset. An event marks the stimulus onset or the beginning of the trial. In this example, we want to analyse the 0.5 s before each event plus the duration of the event itself. Therefore, we set trial_start_offset_seconds to -0.5 s and trial_stop_offset_seconds to 0 s.

We also extract the sampling frequency from the dataset and check that it is the same for all recordings, which is the case here.

Note

The trial_start_offset_seconds and trial_stop_offset_seconds are defined in seconds and need to be converted into samples (multiplication with the sampling frequency), relative to the event. This variable is dataset dependent.
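As a small worked example of this conversion, the sketch below turns the offsets into samples and derives the resulting window length. The sampling frequency of 250 Hz and the trial duration of 4 s are assumptions used for illustration; in the tutorial the sampling frequency is read from the data.

```python
# Assumed values for illustration only.
sfreq = 250.0                  # sampling frequency in Hz (assumption)
trial_duration_seconds = 4.0   # duration of one event (assumption)

trial_start_offset_seconds = -0.5
trial_stop_offset_seconds = 0.0

# Offsets in samples, relative to the event onset and offset.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
trial_stop_offset_samples = int(trial_stop_offset_seconds * sfreq)

# Window length: event duration, extended backwards by the start offset
# and forwards by the stop offset.
trial_samples = int(trial_duration_seconds * sfreq)
window_samples = (trial_samples
                  - trial_start_offset_samples
                  + trial_stop_offset_samples)

print(trial_start_offset_samples, trial_stop_offset_samples, window_samples)
# -125 0 1125
```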

from braindecode.preprocessing.windowers import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

Split dataset into train and valid#

We can easily split the dataset using additional info stored in the description attribute, in this case the session column. We select 0train for training and 1test for evaluation.

splitted = windows_dataset.split('session')
train_set = splitted['0train']  # Session train
eval_set = splitted['1test']  # Session evaluation

Create model#

Now we create the deep learning model! Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the ShallowFBCSPNet model from Deep learning with convolutional neural networks for EEG decoding and visualization [4]. These models are pure PyTorch deep learning models, so to use your own model, it only needs to be a normal PyTorch nn.Module.

from functools import partial
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

# check if GPU is available, if True chooses to use it
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200220  # random seed to make results reproducible
# Set random seed to be able to reproduce results
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
input_window_samples = train_set[0][0].shape[1]

# To analyze the impact of the different parameters inside the torch model,
# we create a partial initialisation. During the grid search the model is
# re-created for every candidate from the parameters we want to tune. If we
# passed an already instantiated model, it would be re-initialized with the
# tuned parameters only, without the fixed ones, which results in an error.
# Binding the fixed arguments with functools.partial avoids this.
model = partial(ShallowFBCSPNet, n_chans, n_classes,
                input_window_samples=input_window_samples,
                final_conv_length='auto', )

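# No explicit model.cuda() call is needed here (a functools.partial object
# has no such method): EEGClassifier moves the module to the right device
# through its device argument.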
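The comment in the code above explains why functools.partial is used. The following sketch illustrates the idea with a hypothetical stand-in class, ToyNet, instead of a real network: the fixed arguments are bound once, and each candidate of a search only supplies the tuned value. The class, its argument names, and the values 22 and 4 are assumptions for illustration.

```python
from functools import partial


class ToyNet:
    """Hypothetical stand-in for a model with required and tunable arguments."""

    def __init__(self, n_chans, n_outputs, drop_prob=0.5):
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.drop_prob = drop_prob


# Bind the fixed arguments once; the tunable one stays open.
make_net = partial(ToyNet, 22, 4)

# A search can now build one model per candidate from the tuned value alone.
nets = [make_net(drop_prob=p) for p in (0.2, 0.5, 0.8)]
print([(net.n_chans, net.n_outputs, net.drop_prob) for net in nets])

# Without the bound arguments, the same call fails.
try:
    ToyNet(drop_prob=0.2)
except TypeError as err:
    print("TypeError:", err)
```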

Training#

Now we train the network! EEGClassifier is a Braindecode object responsible for managing the training of neural networks. It inherits from skorch.NeuralNetClassifier, so the training logic is the same as in Skorch.

from skorch.callbacks import LRScheduler
from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

batch_size = 16
n_epochs = 2

clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=[],  # This will be handled by GridSearchCV
    batch_size=batch_size,
    train_split=ValidSplit(0.2, random_state=seed),
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
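The classifier above uses a cosine annealing learning rate schedule with T_max = n_epochs - 1. As a rough illustration of what that implies for a two-epoch run, the sketch below evaluates the closed-form cosine annealing formula per epoch. It assumes a minimum learning rate of 0 and that the schedule is stepped once per epoch; the helper function is hypothetical and not part of skorch or PyTorch.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


n_epochs = 2
base_lr = 0.00625

schedule = [cosine_annealing_lr(base_lr, epoch, t_max=n_epochs - 1)
            for epoch in range(n_epochs)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

With only two epochs, the sketch yields the full learning rate for the first epoch and 0 for the second. For a real experiment you would train for more epochs, so that the learning rate decays gradually.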

We use scikit-learn's GridSearchCV to tune hyperparameters. To do this, we slice the braindecode dataset, which by default returns a 3-tuple per item, so that it yields X and y separately.

Note

The KFold object splits the datasets based on their length which corresponds to the number of compute windows. In this (trialwise) example this is fine to do. In a cropped setting this is not advisable since this might split compute windows of a single trial into both train and valid set.
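The following pure-Python sketch illustrates the concern in the note with a hypothetical cropped setting of 3 trials with 4 compute windows each. It compares a window-level split, as a shuffled KFold would produce, with a trial-level split that keeps all windows of a trial in the same fold. All names and sizes are assumptions for illustration; in scikit-learn, GroupKFold provides this kind of grouped splitting.

```python
import random

# Hypothetical cropped setting: 3 trials, each cut into 4 compute windows.
n_trials, crops_per_trial, n_splits = 3, 4, 2
trial_of_window = [t for t in range(n_trials) for _ in range(crops_per_trial)]
n_windows = len(trial_of_window)


def leaking_trials(folds):
    """Count trials whose windows land in more than one fold."""
    leaks = 0
    for trial in range(n_trials):
        folds_seen = {fold for fold, windows in enumerate(folds)
                      for w in windows if trial_of_window[w] == trial}
        leaks += len(folds_seen) > 1
    return leaks


rng = random.Random(42)

# Window-level split: shuffle the windows, then cut them into equal folds.
window_ids = list(range(n_windows))
rng.shuffle(window_ids)
fold_size = n_windows // n_splits
window_folds = [window_ids[i * fold_size:(i + 1) * fold_size]
                for i in range(n_splits)]

# Trial-level split: shuffle the trials and keep each trial's windows together.
trial_ids = list(range(n_trials))
rng.shuffle(trial_ids)
trial_folds = [[] for _ in range(n_splits)]
for position, trial in enumerate(trial_ids):
    trial_folds[position % n_splits].extend(
        w for w in range(n_windows) if trial_of_window[w] == trial)

naive_leaks = leaking_trials(window_folds)
grouped_leaks = leaking_trials(trial_folds)
print(f"window-level split: {naive_leaks} trial(s) spread over both folds")
print(f"trial-level split:  {grouped_leaks} trial(s) spread over both folds")
```

In this toy setup each fold holds 6 windows while a trial has 4, so the window-level split necessarily separates the windows of at least one trial, whereas the trial-level split never does.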

from sklearn.model_selection import GridSearchCV, KFold
from skorch.helper import SliceDataset
from numpy import array
import pandas as pd

train_X = SliceDataset(train_set, idx=0)
train_y = array([y for y in SliceDataset(train_set, idx=1)])
cv = KFold(n_splits=2, shuffle=True, random_state=42)

learning_rates = [0.00625, 0.0000625]
drop_probs = [0.2, 0.5, 0.8]

fit_params = {'epochs': n_epochs}
param_grid = {
    'optimizer__lr': learning_rates,
    'module__drop_prob': drop_probs
}

# With n_jobs=1 the fits run sequentially, so the training output is printed
# in order. Setting n_jobs=-1 would run the grid search on all processors,
# but then the output of the training process is not printed sequentially.
search = GridSearchCV(
    estimator=clf,
    param_grid=param_grid,
    cv=cv,
    return_train_score=True,
    scoring='accuracy',
    refit=True,
    verbose=1,
    error_score='raise',
    n_jobs=1,
)

search.fit(train_X, train_y, **fit_params)

# Extract the results into a DataFrame
search_results = pd.DataFrame(search.cv_results_)
Fitting 2 folds for each of 6 candidates, totalling 12 fits
Can only infer signal shape of numpy arrays or and Datasets, got <class 'skorch.helper.SliceDataset'>.
/home/runner/work/braindecode/braindecode/braindecode/models/base.py:23: UserWarning: ShallowFBCSPNet: 'input_window_samples' is depreciated. Use 'n_times' instead.
  warnings.warn(
/home/runner/work/braindecode/braindecode/braindecode/models/base.py:180: UserWarning: LogSoftmax final layer will be removed! Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
  warnings.warn("LogSoftmax final layer will be removed! " +
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.5304        2.0774       0.1724            0.1724        3.0998  0.0063  0.5656
      2            0.6174        1.1537       0.2069            0.2069        2.8704  0.0000  0.5503
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.3565        2.5936       0.3448            0.3448        2.3154  0.0063  0.5648
      2            0.5130        1.1784       0.3448            0.3448        1.8569  0.0000  0.5300
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        1.5185       0.2759            0.2759        3.0333  0.0001  0.5275
      2            0.2348        1.4272       0.2759            0.2759        2.3804  0.0000  0.5491
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2000        1.6046       0.3103            0.3103        4.8912  0.0001  0.5509
      2            0.2000        1.3911       0.3103            0.3103        3.7741  0.0000  0.5318
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.5304        1.7828       0.1724            0.1724        2.9789  0.0063  0.5332
      2            0.5826        1.1250       0.1724            0.1724        2.5773  0.0000  0.5752
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2609        2.5664       0.3793            0.3793        3.1215  0.0063  0.5432
      2            0.4696        1.4854       0.4483            0.4483        1.7465  0.0000  0.5506
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2261        1.7070       0.3103            0.3103        2.1473  0.0001  0.5476
      2            0.2435        1.6492       0.3103            0.3103        1.8568  0.0000  0.5280
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2783        1.5837       0.2069            0.2069        2.4035  0.0001  0.5286
      2            0.2696        1.5464       0.2069            0.2069        2.1272  0.0000  0.5565
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2609        2.1099       0.3103            0.3103        3.2171  0.0063  0.5508
      2            0.4696        1.8860       0.3448            0.3448        1.9000  0.0000  0.5290
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.4087        2.1420       0.4138            0.4138        1.5093  0.0063  0.5286
      2            0.5130        1.5210       0.3793            0.3793        1.3900  0.0000  0.5512
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2348        2.1793       0.2759            0.2759        9.8493  0.0001  0.5529
      2            0.2348        1.9960       0.2759            0.2759        7.3377  0.0000  0.5363
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2000        2.2516       0.3103            0.3103        2.5048  0.0001  0.5304
      2            0.1913        2.1263       0.3448            0.3448        2.0232  0.0000  0.5541
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------  ------
      1            0.2739        2.2403       0.2586            0.2586        6.1557  0.0063  1.0682
      2            0.5304        1.7535       0.3448            0.3448        1.9383  0.0000  1.0583
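The log above reports 6 candidates and 12 fits. As a quick check of where these numbers come from, the sketch below enumerates the grid by hand with itertools.product; scikit-learn's ParameterGrid produces the same combinations. The variable names mirror those in the tutorial, and the ordering of the candidates is not meant to match the order used by GridSearchCV.

```python
from itertools import product

param_grid = {
    'optimizer__lr': [0.00625, 0.0000625],
    'module__drop_prob': [0.2, 0.5, 0.8],
}
n_splits = 2  # number of cross-validation folds

# Every combination of one value per parameter is one candidate.
names = sorted(param_grid)
candidates = [dict(zip(names, values))
              for values in product(*(param_grid[name] for name in names))]

# Each candidate is fitted once per fold.
n_fits = len(candidates) * n_splits

for candidate in candidates:
    print(candidate)
print(len(candidates), n_fits)  # 6 12
```

Because refit=True, the best candidate is trained once more on the full training set after the search, which is why the log contains a thirteenth training table.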

Plotting the results#

import matplotlib.pyplot as plt
import seaborn as sns


# Create a pivot table for the heatmap
pivot_table = search_results.pivot(index='param_optimizer__lr',
                                   columns='param_module__drop_prob',
                                   values='mean_test_score')
# Create the heatmap
fig, ax = plt.subplots()
sns.heatmap(pivot_table, annot=True, fmt=".3f",
            cmap="YlGnBu", cbar=True)
plt.title('Grid Search Mean Test Scores')
plt.ylabel('Learning Rate')
plt.xlabel('Dropout Probability')
plt.tight_layout()
plt.show()
Figure: heatmap "Grid Search Mean Test Scores" (learning rate on the y-axis, dropout probability on the x-axis).
/home/runner/work/braindecode/braindecode/examples/model_building/plot_hyperparameter_tuning_with_scikit-learn.py:332: FutureWarning: In a future version, the Index constructor will not infer numeric dtypes when passed object-dtype sequences (matching Series behavior)
  pivot_table = search_results.pivot(index='param_optimizer__lr',

Get the best hyperparameters#

best_run = search_results[search_results['rank_test_score'] == 1].squeeze()
print(
    f"Best hyperparameters were {best_run['params']} which gave a validation "
    f"accuracy of {best_run['mean_test_score'] * 100:.2f}% (training "
    f"accuracy of {best_run['mean_train_score'] * 100:.2f}%).")

eval_X = SliceDataset(eval_set, idx=0)
eval_y = SliceDataset(eval_set, idx=1)
score = search.score(eval_X, eval_y)
print(f"Eval accuracy is {score * 100:.2f}%.")
Best hyperparameters were {'module__drop_prob': 0.2, 'optimizer__lr': 0.00625} which gave a validation accuracy of 32.64% (training accuracy of 50.69%).
Eval accuracy is 34.72%.

References#
