Split Dataset Example#

This example shows several ways to split a dataset for training, testing, and evaluating your models.

# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD (3-clause)

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import create_windows_from_events

Loading the dataset#

First, we load a dataset from MOABB using the braindecode MOABBDataset class. In this example, we use Dataset 2a from BCI Competition IV.

dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])

Splitting#

By description information#

The MOABBDataset class has a pandas DataFrame, description, that holds additional information about its internal datasets. This recording information, such as the subject, session, and run of each row, can be used to split the data.

dataset.description
    subject  session  run
0         1   0train    0
1         1   0train    1
2         1   0train    2
3         1   0train    3
4         1   0train    4
5         1   0train    5
6         1    1test    0
7         1    1test    1
8         1    1test    2
9         1    1test    3
10        1    1test    4
11        1    1test    5


Here, we split the data by run. The split method returns a dictionary whose string keys correspond to the unique entries in the chosen column of the description DataFrame.

splits = dataset.split("run")
print(splits)
splits["4"].description
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d1c40>, '1': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d0980>, '2': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d1dc0>, '3': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d0200>, '4': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7ebec1b2c0>, '5': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7ebec1ba70>}
   subject  session  run
0        1   0train    4
1        1    1test    4
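
To make the key-to-rows mapping concrete, the following is a minimal sketch in plain Python. It is not braindecode's implementation: the `description` list is a hand-copied stand-in for `dataset.description`, and `group_rows_by` is a hypothetical helper introduced only for illustration. It groups row indices by the unique values of one column, which is the idea behind splitting on a column name.

```python
from collections import defaultdict

# Hand-copied stand-in for dataset.description (assumption: one dict per row,
# in the same order as the table shown above).
description = [
    {"subject": 1, "session": session, "run": str(run)}
    for session in ("0train", "1test")
    for run in range(6)
]


def group_rows_by(rows, column):
    """Map each unique value in `column` (as a string) to its row indices.

    Hypothetical helper for illustration; not part of braindecode.
    """
    groups = defaultdict(list)
    for index, row in enumerate(rows):
        groups[str(row[column])].append(index)
    return dict(groups)


run_groups = group_rows_by(description, "run")
print(run_groups["4"])  # row indices that share run "4"

session_groups = group_rows_by(description, "session")
print(sorted(session_groups))  # unique session labels
```

Grouping by "run" yields six keys, one per run, and each key collects one row from the training session and one from the test session, which matches the two-row description shown for splits["4"].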


By row index#

Another way to split the dataset is to pass a list of integers corresponding to rows in the description. In this case, the returned dictionary has ‘0’ as its only key.

splits = dataset.split([0, 1, 5])
print(splits)
splits["0"].description
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43a540>}
   subject  session  run
0        1   0train    0
1        1   0train    1
2        1   0train    5


To get multiple splits based on indices, we can instead pass a list of lists of integers. In this case, the dictionary keys are the positions of the inner lists within the outer list, given as strings (‘0’, ‘1’, and so on).

splits = dataset.split([[0, 1, 5], [2, 3, 4], [6, 7, 8, 9, 10, 11]])
print(splits)
splits["2"].description
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d817fe2a0>, '1': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43ab70>, '2': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43a7b0>}
   subject  session  run
0        1    1test    0
1        1    1test    1
2        1    1test    2
3        1    1test    3
4        1    1test    4
5        1    1test    5
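
When index lists are written by hand, it is easy to reuse a row in two splits or to leave one out. The sketch below is an optional sanity check in plain Python; `check_partition` is a hypothetical helper, not a braindecode function, and the row count of 12 is taken from the description table above.

```python
def check_partition(index_lists, n_rows):
    """Return (duplicated, missing) row indices for a list of index lists.

    Hypothetical helper for illustration; not part of braindecode.
    """
    seen = set()
    duplicated = set()
    for indices in index_lists:
        for index in indices:
            if index in seen:
                duplicated.add(index)
            seen.add(index)
    missing = set(range(n_rows)) - seen
    return duplicated, missing


# The same index lists as in the split call above.
split_ids = [[0, 1, 5], [2, 3, 4], [6, 7, 8, 9, 10, 11]]
duplicated, missing = check_partition(split_ids, n_rows=12)
print("duplicated:", duplicated)
print("missing:", missing)
```

Both sets are empty for the lists used here, so every recording lands in exactly one split. A non-empty `duplicated` set would mean the same recording appears in more than one split, which matters when one split is used for training and another for evaluation.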


You can also name each split by passing a dictionary that maps each split name to its list of indices; the output dictionary uses the same keys:

splits = dataset.split(
    {"train": [0, 1, 5], "valid": [2, 3, 4], "test": [6, 7, 8, 9, 10, 11]}
)
print(splits)
splits["test"].description
{'train': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7ebec1b920>, 'valid': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43a420>, 'test': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43b800>}
   subject  session  run
0        1    1test    0
1        1    1test    1
2        1    1test    2
3        1    1test    3
4        1    1test    4
5        1    1test    5
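
The index dictionary does not have to be typed out by hand. As a minimal sketch, assuming the description has the same rows as the table shown earlier, the plain-Python code below derives the train/valid/test indices from the session and run of each row. The `description` list is a hand-copied stand-in, and `named_split_indices` and its `valid_runs` parameter are hypothetical names used only for this illustration.

```python
# Hand-copied stand-in for dataset.description: one (session, run) pair per
# row, in the same order as the table shown earlier (assumption).
description = [
    (session, run) for session in ("0train", "1test") for run in range(6)
]


def named_split_indices(rows, valid_runs):
    """Build a {"train", "valid", "test"} mapping of row indices.

    Rows from session "1test" go to "test"; training-session rows whose run
    is in `valid_runs` go to "valid"; the rest go to "train".
    Hypothetical helper for illustration; not part of braindecode.
    """
    named = {"train": [], "valid": [], "test": []}
    for index, (session, run) in enumerate(rows):
        if session == "1test":
            named["test"].append(index)
        elif run in valid_runs:
            named["valid"].append(index)
        else:
            named["train"].append(index)
    return named


split_ids = named_split_indices(description, valid_runs={2, 3, 4})
print(split_ids)
```

With `valid_runs={2, 3, 4}`, this reproduces the dictionary passed to dataset.split in the example above, so the result could be handed to the split method in the same way.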


Observation#

Similarly, we can split datasets after creating windows using the same methods.

windows = create_windows_from_events(
    dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0
)
# Splitting by different runs
print("Using description info")
splits = windows.split("run")
print(splits)
print()

# Splitting by row index
print("Splitting by row index")
splits = windows.split([4, 8])
print(splits)
print()

print("Multiple row index split")
splits = windows.split([[4, 8], [5, 9, 11]])
print(splits)
print()

# Specifying output's keys
print("Specifying keys")
splits = windows.split(dict(train=[4, 8], test=[5, 9, 11]))
print(splits)
Using description info
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d817fcda0>, '1': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c4393d0>, '2': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7eadf350a0>, '3': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7eadf37b00>, '4': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d1c40>, '5': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7d9d2630>}

Splitting by row index
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c439070>}

Multiple row index split
{'0': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7eadf350a0>, '1': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c4393d0>}

Specifying keys
{'train': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c439070>, 'test': <braindecode.datasets.base.BaseConcatDataset object at 0x7f7d7c43b800>}

