Cropped Decoding on BCIC IV 2a Dataset

Building on the trialwise decoding tutorial, we now do more data-efficient cropped decoding!

Braindecode supports two configurations for training models: trialwise decoding and cropped decoding. We explain them visually by comparing trialwise to cropped decoding.

[Figure: trialwise decoding (left) and cropped decoding (right)]

On the left, you see trialwise decoding:

  1. A complete trial is pushed through the network.

  2. The network produces a prediction.

  3. The prediction is compared to the target (label) for that trial to compute the loss.

On the right, you see cropped decoding:

  1. Instead of a complete trial, crops are pushed through the network.

  2. For computational efficiency, multiple neighbouring crops are pushed through the network simultaneously (these neighbouring crops are called compute windows).

  3. Therefore, the network produces multiple predictions (one per crop in the window).

  4. The individual crop predictions are AVERAGED before computing the loss function.

This averaging of predictions of small sub-windows is the key difference between trialwise and cropped decoding. It was introduced in [1]. Because the loss is computed on the averaged prediction, every crop in the window contributes to the gradient that updates the network's parameters.

It is important to note that averaging inside the loss only happens during training. During testing, the network is still applied to crops, and the crop predictions are averaged afterwards to obtain the final prediction.
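The averaging step can be sketched in plain Python. This is a minimal illustration with made-up scores for one compute window, not Braindecode's `CroppedLoss` implementation. The helper names `average_over_crops` and `cross_entropy` are hypothetical. The sketch averages first and then computes a single cross-entropy loss, which is the loss this tutorial later uses. At test time the same averaged scores are reduced to a class with argmax.

```python
import math

# Hypothetical class scores for ONE compute window: 3 crops x 4 classes.
# (A Braindecode model returns a tensor of shape (batch, n_classes, n_preds);
# plain nested lists, one row per crop, keep this sketch dependency-free.)
crop_scores = [
    [2.0, 0.5, 0.1, -1.0],
    [1.5, 0.8, 0.0, -0.5],
    [0.4, 1.2, 0.3, -0.2],
]
target = 0  # label of the trial this window was cut from


def average_over_crops(scores):
    """Mean score per class across all crops of one window."""
    n_crops = len(scores)
    n_classes = len(scores[0])
    return [sum(crop[c] for crop in scores) / n_crops for c in range(n_classes)]


def cross_entropy(logits, label):
    """Cross-entropy (log-softmax + negative log likelihood) for one sample."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[label]


# Training: average the crop predictions FIRST, then compute a single loss.
avg_scores = average_over_crops(crop_scores)
train_loss = cross_entropy(avg_scores, target)

# Testing: average the crop predictions, then pick the highest-scoring class.
predicted_class = max(range(len(avg_scores)), key=lambda c: avg_scores[c])

print([round(s, 3) for s in avg_scores], predicted_class)
# [1.3, 0.833, 0.133, -0.567] 0
```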

Note

  • The network architecture implicitly defines the crop size (it is the receptive field size, i.e., the number of timesteps the network uses to make a single prediction).

  • The window size is a user-defined hyperparameter, called n_times in Braindecode. It mostly affects runtime (larger window sizes should be faster). As a rule of thumb, you can set it to two times the crop size.

  • Crop size and window size together define how many predictions the network makes per window: #window - #crop + 1 = #predictions
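The relation between window size, crop size and number of predictions is simple to check. A window of `n_window` samples contains one full crop for every start index from 0 to `n_window - n_crop`. This small sketch uses hypothetical helper names and toy sizes to enumerate those crops.

```python
def n_predictions(n_window: int, n_crop: int) -> int:
    """Number of crop predictions a dense network makes for one window."""
    if n_crop > n_window:
        raise ValueError("the window must be at least as long as the crop")
    return n_window - n_crop + 1


def crop_bounds(n_window: int, n_crop: int) -> list:
    """(start, stop) sample indices of every full crop inside one window."""
    return [(start, start + n_crop) for start in range(n_window - n_crop + 1)]


# Toy example: a 10-sample window and a 4-sample crop (receptive field).
bounds = crop_bounds(10, 4)
print(n_predictions(10, 4), bounds[0], bounds[-1])  # 7 (0, 4) (6, 10)

# Rule of thumb from above: with a window of twice the crop size,
# the network makes crop_size + 1 predictions per window.
print(n_predictions(2 * 500, 500))  # 501
```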

Note

For cropped decoding, the above training setup is mathematically similar to sampling crops from your dataset, pushing them through the network, and training directly on the individual crops. One difference is that randomly positioned crops would be less correlated than the neighbouring crops taken from a single window. At the same time, the above training setup is much faster, as it avoids redundant computations by using dilated convolutions, see [2]. However, the two setups are only mathematically equivalent if (1) your network does not use any padding, or only left padding, and (2) your loss function leads to the same gradients when using the averaged output. The first is true for our shallow and deep ConvNet models, and the second is true for the log-softmax outputs and negative log likelihood loss that are typically used for classification in PyTorch.
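Condition (2) can be checked numerically. The negative log likelihood only picks out and negates one entry of the log-probabilities, so it is linear. Two quantities are therefore equal:

* averaging per-crop NLL losses;
* applying NLL once to averaged log-softmax outputs.

This is a plain-Python sketch with hypothetical logits for three neighbouring crops. For contrast, it also averages raw logits before cross-entropy. By convexity of log-sum-exp, that loss can only be smaller than or equal to the per-crop average, and in general it differs from it.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one logit vector."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def mean(values):
    return sum(values) / len(values)


# Hypothetical logits of 3 neighbouring crops for a 3-class problem.
crop_logits = [[1.0, 2.0, 0.5], [0.2, 1.5, 1.0], [2.5, 0.3, 0.1]]
target = 1
n_classes = 3
log_probs = [log_softmax(logits) for logits in crop_logits]

# (a) Train on individual crops: one NLL loss per crop, then average.
mean_of_losses = mean([-lp[target] for lp in log_probs])

# (b) Cropped decoding: average the log-softmax outputs, then one NLL loss.
avg_log_probs = [mean([lp[c] for lp in log_probs]) for c in range(n_classes)]
loss_of_mean = -avg_log_probs[target]

print(math.isclose(mean_of_losses, loss_of_mean))  # True

# Averaging RAW logits before cross-entropy is a different objective.
avg_logits = [mean([logits[c] for logits in crop_logits]) for c in range(n_classes)]
ce_of_avg_logits = -log_softmax(avg_logits)[target]
print(ce_of_avg_logits <= mean_of_losses + 1e-12)  # True
```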

Loading and preprocessing the dataset

Loading and preprocessing stays the same as in the Trialwise decoding tutorial.

from braindecode.datasets import MOABBDataset

subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

from numpy import multiply

from braindecode.preprocessing import (
    Preprocessor,
    exponential_moving_standardize,
    preprocess,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
    # Keep EEG sensors
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),
    # Convert from V to uV
    Preprocessor(lambda data: multiply(data, factor)),
    # Bandpass filter
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),
    # Exponential moving standardization
    Preprocessor(
        exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, preprocessors, n_jobs=-1)
48 events found on stim channel stim
Event IDs: [1 2 3 4]
/home/runner/work/braindecode/braindecode/braindecode/preprocessing/preprocess.py:69: UserWarning: Preprocessing choices with lambda functions cannot be saved.
  warn("Preprocessing choices with lambda functions cannot be saved.")

<braindecode.datasets.moabb.MOABBDataset object at 0x7fe7e9a46650>
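Exponential moving standardization is the least obvious step in the preprocessing pipeline. Here is a simplified single-channel sketch of the idea.

* It keeps exponentially weighted running estimates of the mean and of the squared deviation, with weight `factor_new` on each new sample.
* It divides the demeaned signal by the running standard deviation.
* The first `init_block_size` samples have unreliable running estimates, so they are standardized with the plain mean and std of that block.

This is not Braindecode's exact `exponential_moving_standardize`. It uses the simple recursive form of the moving average, and Braindecode's function may weight the earliest samples differently. The function name and the `eps` default are assumptions for illustration.

```python
import math


def exp_moving_standardize(x, factor_new=1e-3, init_block_size=None, eps=1e-4):
    """Simplified single-channel sketch of exponential moving standardization."""
    mean, var = x[0], 0.0  # running estimates start at the first sample
    out = []
    for value in x:
        # exponential moving average of the signal ...
        mean = (1 - factor_new) * mean + factor_new * value
        demeaned = value - mean
        # ... and of the squared deviation from that moving average
        var = (1 - factor_new) * var + factor_new * demeaned ** 2
        out.append(demeaned / max(math.sqrt(var), eps))
    if init_block_size:
        # Running estimates are unreliable at the start, so standardize the
        # initial block with its own plain mean and standard deviation.
        block = x[:init_block_size]
        block_mean = sum(block) / len(block)
        block_std = math.sqrt(sum((v - block_mean) ** 2 for v in block) / len(block))
        for i, value in enumerate(block):
            out[i] = (value - block_mean) / max(block_std, eps)
    return out


# Toy signal with an offset of 3 and amplitude 10 (hypothetical data).
signal = [3.0 + 10.0 * math.sin(i / 5) for i in range(2000)]
standardized = exp_moving_standardize(signal, factor_new=1e-3, init_block_size=1000)
```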

Create model and compute windowing parameters

In contrast to trialwise decoding, we first have to create the model before we can cut the dataset into windows. This is because the window stride depends on how many predictions the network makes per window, which is determined by its architecture.

We first choose the compute/input window size that will be fed to the network during training. This has to be larger than the network's receptive field size and can otherwise be chosen for computational efficiency (see explanations in the beginning of this tutorial). Here we choose 1000 samples, which are 4 seconds for the 250 Hz sampling rate.

n_times = 1000

Now we create the model. To enable it to be used in cropped decoding efficiently, we manually set the length of the final convolution layer to a value that makes the receptive field of the ConvNet smaller than n_times (see final_conv_length=30 in the model definition).

import torch

from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))
# Extract the number of channels from the dataset
n_chans = dataset[0][0].shape[0]

model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    n_times=n_times,
    final_conv_length=30,
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    _ = model.cuda()
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 1000]             [1, 4, 32]                --                        --
├─SafeLog (pool_nonlin_exp): 1-1         [1, 22, 1000]             [1, 22, 1000]             --                        --
├─Ensure4d (ensuredims): 1-2             [1, 22, 1000]             [1, 22, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-3            [1, 22, 1000, 1]          [1, 1, 1000, 22]          --                        --
├─CombinedConv (conv_time_spat): 1-4     [1, 1, 1000, 22]          [1, 40, 976, 1]           36,240                    --
├─BatchNorm2d (bnorm): 1-5               [1, 40, 976, 1]           [1, 40, 976, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-6      [1, 40, 976, 1]           [1, 40, 976, 1]           --                        --
├─AvgPool2d (pool): 1-7                  [1, 40, 976, 1]           [1, 40, 61, 1]            --                        [75, 1]
├─SafeLog (pool_nonlin_exp): 1-8         [1, 40, 61, 1]            [1, 40, 61, 1]            --                        --
├─Dropout (drop): 1-9                    [1, 40, 61, 1]            [1, 40, 61, 1]            --                        --
├─Sequential (final_layer): 1-10         [1, 40, 61, 1]            [1, 4, 32]                --                        --
│    └─Conv2d (conv_classifier): 2-1     [1, 40, 61, 1]            [1, 4, 32, 1]             4,804                     [30, 1]
│    └─Expression (squeeze): 2-2         [1, 4, 32, 1]             [1, 4, 32]                --                        --
============================================================================================================================================
Total params: 41,124
Trainable params: 41,124
Non-trainable params: 0
Total mult-adds (M): 0.15
============================================================================================================================================
Input size (MB): 0.09
Forward/backward pass size (MB): 0.31
Params size (MB): 0.02
Estimated Total Size (MB): 0.42
============================================================================================================================================

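Next, we transform the model with strides into a model that outputs dense predictions, so we can use it to obtain predictions for all crops.

model.to_dense_prediction_model()

To know the model's output shape, we compute the shape of the model output for a dummy input. Its last dimension is the number of predictions per input window, which we use as the window stride below.

n_preds_per_input = model.get_output_shape()[2]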
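The crop size (receptive field) can also be estimated by hand from the model table above. This sketch reads off three values:

* the temporal convolution kernel of 25, implied by 1000 → 976 samples;
* the pooling kernel of 75;
* the classifier kernel of 30.

The pooling stride of 15 is an assumption (the ShallowFBCSPNet default), since the table does not print strides. As a sanity check, the strided output length computed below reproduces the 32 outputs shown in the table. After densification, the stride becomes 1, so the number of predictions per window should then be `n_times - crop_size + 1`.

```python
# (name, kernel_size, stride) of the temporal layers, read off the table.
# ASSUMPTION: pooling stride 15 (not shown in the table).
layers = [
    ("conv_time", 25, 1),
    ("pool", 75, 15),
    ("conv_classifier", 30, 1),
]


def receptive_field(layers):
    """Receptive field and total stride of a stack of 1-D layers."""
    rf, jump = 1, 1
    for _, kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k-1) input jumps
        jump *= stride
    return rf, jump


n_times = 1000
crop_size, total_stride = receptive_field(layers)
strided_outputs = (n_times - crop_size) // total_stride + 1  # before densification
dense_outputs = n_times - crop_size + 1                      # one per crop
print(crop_size, strided_outputs, dense_outputs)  # 534 32 467
```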

Cut the data into windows

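In contrast to trialwise decoding, we have to supply an explicit window size and window stride to the create_windows_from_events function. We use the number of predictions per input window as the stride, so that consecutive windows yield predictions for consecutive crops without gaps.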

from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract the sampling frequency and check that it is the same in all recordings
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])

# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    window_size_samples=n_times,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    preload=True,
)
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
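To see why a stride of `n_preds_per_input` covers every crop, here is a hypothetical sketch of fixed-size window placement within one trial. The helper `window_starts` is not Braindecode's code. The assumed `drop_last_window=False` behaviour adds one final window aligned with the trial end. The sketch also assumes a 4 s motor-imagery trial plus the 0.5 s pre-trial offset at 250 Hz, and a receptive field of 534 samples.

```python
def window_starts(trial_len, window_size, stride, drop_last_window=False):
    """Hypothetical sketch of placing fixed-size windows inside one trial."""
    if trial_len < window_size:
        raise ValueError("trial is shorter than one window")
    starts = list(range(0, trial_len - window_size + 1, stride))
    if not drop_last_window and starts[-1] + window_size < trial_len:
        # add one extra window aligned with the end of the trial
        starts.append(trial_len - window_size)
    return starts


sfreq = 250
n_times = 1000                                   # window size in samples
crop_size = 534                                  # ASSUMED receptive field
n_preds_per_input = n_times - crop_size + 1      # 467
trial_start_offset_samples = int(-0.5 * sfreq)   # -125
trial_len = 4 * sfreq - trial_start_offset_samples  # ASSUMED 4 s trial + 0.5 s

starts = window_starts(trial_len, n_times, n_preds_per_input)
print(starts)  # [0, 125]

# Last sample index of every crop that gets a prediction: each window
# predicts one value per full crop inside it.
predicted_crop_ends = set()
for start in starts:
    for i in range(n_preds_per_input):
        predicted_crop_ends.add(start + crop_size - 1 + i)

# Every possible crop end from crop_size - 1 to trial_len - 1 is covered.
print(len(predicted_crop_ends) == trial_len - crop_size + 1)  # True
```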

Split the dataset

This code is the same as in trialwise decoding.

splitted = windows_dataset.split("session")
train_set = splitted["0train"]  # Session train
valid_set = splitted["1test"]  # Session evaluation
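Splitting by "session" groups recordings by one field of their description. Below is a plain-Python sketch of that grouping for one BCIC IV 2a subject, who has two sessions of six runs each. The description field names and values are assumptions that mirror the session keys used above. The helper `split_by` is hypothetical, not Braindecode's implementation.

```python
from collections import defaultdict

# Hypothetical descriptions of the 12 recordings of one subject:
# 2 sessions x 6 runs (field names and values are assumptions).
descriptions = [
    {"subject": 3, "session": session, "run": str(run)}
    for session in ("0train", "1test")
    for run in range(6)
]


def split_by(descs, key):
    """Group recording indices by the value of one description field."""
    groups = defaultdict(list)
    for idx, desc in enumerate(descs):
        groups[desc[key]].append(idx)
    return dict(groups)


splitted = split_by(descriptions, "session")
print(splitted)
# {'0train': [0, 1, 2, 3, 4, 5], '1test': [6, 7, 8, 9, 10, 11]}
```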

Training

In contrast to trialwise decoding, we now supply cropped=True to the EEGClassifier, CroppedLoss as the criterion, and criterion__loss_function as the loss function applied to the averaged predictions.

Note

In this tutorial, we use some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data.

from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier
from braindecode.training import CroppedLoss

# Values we found to work well for the shallow network:
lr = 0.0625 * 0.01
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001

batch_size = 64
n_epochs = 2

clf = EEGClassifier(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.cross_entropy,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=[
        "accuracy",
        ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
    ],
    device=device,
    classes=classes,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
_ = clf.fit(train_set, y=None, epochs=n_epochs)
  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  ----------------  ------------  ------  ------
      1            0.2569        1.2896            0.2847        1.6368  0.0006  8.1099
      2            0.5069        1.1231            0.4549        1.2402  0.0000  8.0370
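The `lr` column in the log follows from the cosine annealing schedule with `T_max=n_epochs - 1`. This sketch uses the closed-form schedule documented for PyTorch's `CosineAnnealingLR`, assuming the scheduler is stepped once per epoch. The helper name `cosine_annealing_lr` is hypothetical. It reproduces the printed values: 0.000625 (shown rounded as 0.0006) for the first epoch and 0 for the second.

```python
import math


def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate at a given epoch index."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


lr = 0.0625 * 0.01
n_epochs = 2
lrs = [cosine_annealing_lr(epoch, lr, t_max=n_epochs - 1) for epoch in range(n_epochs)]
print([round(x, 4) for x in lrs])  # [0.0006, 0.0]
```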

Plot Results

Note

Note that the classification error and loss drop further than in the trialwise decoding tutorial.

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D

# Extract loss and accuracy values for plotting from history object
results_columns = ["train_loss", "valid_loss", "train_accuracy", "valid_accuracy"]
df = pd.DataFrame(
    clf.history[:, results_columns],
    columns=results_columns,
    index=clf.history[:, "epoch"],
)

# get percent of misclass for better visual comparison to loss
df = df.assign(
    train_misclass=100 - 100 * df.train_accuracy,
    valid_misclass=100 - 100 * df.valid_accuracy,
)

fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ["train_loss", "valid_loss"]].plot(
    ax=ax1, style=["-", ":"], marker="o", color="tab:blue", legend=False, fontsize=14
)

ax1.tick_params(axis="y", labelcolor="tab:blue", labelsize=14)
ax1.set_ylabel("Loss", color="tab:blue", fontsize=14)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

df.loc[:, ["train_misclass", "valid_misclass"]].plot(
    ax=ax2, style=["-", ":"], marker="o", color="tab:red", legend=False
)
ax2.tick_params(axis="y", labelcolor="tab:red", labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color="tab:red", fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)

# Build a custom legend: solid lines are training, dotted lines are validation
handles = []
handles.append(
    Line2D([0], [0], color="black", linewidth=1, linestyle="-", label="Train")
)
handles.append(
    Line2D([0], [0], color="black", linewidth=1, linestyle=":", label="Valid")
)
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
[Figure: training and validation loss and misclassification rate per epoch]

Plot Confusion Matrix

Generate a confusion matrix as in [2].

from sklearn.metrics import confusion_matrix

from braindecode.visualization import plot_confusion_matrix

# generate confusion matrices
# get the targets
y_true = valid_set.get_metadata().target
y_pred = clf.predict(valid_set)

# generating confusion matrix
confusion_mat = confusion_matrix(y_true, y_pred)

# add class labels
# label_dict is class_name : str -> i_class : int
label_dict = valid_set.datasets[0].window_kwargs[0][1]["mapping"]
# sort the labels by values (values are integer class labels)
labels = [k for k, v in sorted(label_dict.items(), key=lambda kv: kv[1])]

# plot the basic conf. matrix
plot_confusion_matrix(confusion_mat, class_names=labels)
[Figure: confusion matrix on the validation set]
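The two steps behind the plot are counting predictions per (true, predicted) pair and ordering class names by their integer label. Both can be sketched without sklearn. The mapping, targets and predictions below are hypothetical. Rows are true classes and columns are predicted classes, which is the same convention as `sklearn.metrics.confusion_matrix`. The mapping is deliberately stored out of order to show why the labels are sorted by value.

```python
# Hypothetical class-name -> integer mapping, deliberately stored out of order.
label_dict = {"tongue": 3, "feet": 0, "right_hand": 2, "left_hand": 1}
# Sort names by their integer label so row/column i matches class i.
labels = [k for k, v in sorted(label_dict.items(), key=lambda kv: kv[1])]


def confusion(y_true, y_pred, n_classes):
    """Rows = true class, columns = predicted class."""
    mat = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        mat[t][p] += 1
    return mat


y_true = [0, 0, 1, 2, 3, 3]  # hypothetical targets
y_pred = [0, 1, 1, 2, 3, 0]  # hypothetical predictions
mat = confusion(y_true, y_pred, len(labels))
accuracy = sum(mat[i][i] for i in range(len(labels))) / len(y_true)
print(labels)  # ['feet', 'left_hand', 'right_hand', 'tongue']
print(mat)     # [[1, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 1]]
```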

References

Total running time of the script: (0 minutes 40.585 seconds)

Estimated memory usage: 1565 MB
