Multiple discrete targets with the TUH EEG Corpus

In this example, we showcase the usage of multiple discrete targets per recording with the TUH EEG Corpus.

# Author: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD (3-clause)

import mne
from torch.utils.data import DataLoader

from braindecode.datasets import TUH
from braindecode.preprocessing import create_fixed_length_windows

mne.set_log_level('ERROR')  # avoid messages every time a window is extracted

If you want to try this code with the actual data, please delete the next line. We need to mock some dataset functionality, since the data is not available at the time this example is built.

from braindecode.datasets.tuh import _TUHMock as TUH  # noqa F811

We start by creating a TUH dataset. Instead of a single str, we pass a tuple of strings as target_name. Each of these strings has to exist as a column in the description DataFrame.

TUH_PATH = 'please insert actual path to data here'
tuh = TUH(
    path=TUH_PATH,
    recording_ids=None,
    target_name=('age', 'gender'),  # use both age and gender as decoding target
    preload=False,
    add_physician_reports=False,
)
tuh.description
   path                                               version  year  month  day  subject  session  segment  age  gender
0  tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001...  v1.1.0   2003  2      5    58       1        0        0    M
1  tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004...  v1.1.0   2014  9      30   9932     4        13       53   F
2  tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s0...  v1.1.0   2014  12     14   12331    3        2        39   M
3  tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001...  v1.1.0   2015  12     30   0        1        0        37   M
4  tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s0...  v1.2.0   2016  1      15   14928    4        7        83   F
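Conceptually, each recording's target is read from its own row of the description table, one value per requested column. A minimal plain-Python sketch of that lookup, assuming the rows are plain dicts holding a few columns copied from the table; targets_for is a hypothetical helper, not braindecode's internal code:

```python
# Hypothetical sketch: build multi-column targets from description rows.
# The rows copy the subject, age and gender columns of the table above.
description = [
    {"subject": 58, "age": 0, "gender": "M"},
    {"subject": 9932, "age": 53, "gender": "F"},
    {"subject": 12331, "age": 39, "gender": "M"},
    {"subject": 0, "age": 37, "gender": "M"},
    {"subject": 14928, "age": 83, "gender": "F"},
]


def targets_for(rows, target_name):
    """Return one target per row; a tuple of names yields a list per row."""
    if isinstance(target_name, str):
        return [row[target_name] for row in rows]
    return [[row[name] for name in target_name] for row in rows]


targets = targets_for(description, ("age", "gender"))
print(targets[-1])  # [83, 'F']
```

With a single str the result is one scalar per recording; with a tuple it is one list per recording, in the order the names were given.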


Iterating through the dataset gives x as an ndarray of shape (n_channels, 1), i.e. a single time sample of all channels, together with the target [age of the subject, gender of the subject]. Let's look at the last example, as it has more interesting age/gender labels (compare with the last row of the DataFrame above).

x, y = tuh[-1]
print('x:', x)
print('y:', y)

Out:

x: [[-0.48388163]
 [-1.1033349 ]
 [-0.00548946]
 [-0.69145748]
 [-0.72950636]
 [-0.6732013 ]
 [-0.02884033]
 [-0.09684461]
 [ 0.66150905]
 [ 1.35850294]
 [-1.54706468]
 [ 0.81112458]
 [ 0.48616393]
 [ 0.26901556]
 [ 1.02706921]
 [-0.46342266]
 [-0.43525863]
 [-1.02658337]
 [-0.4584042 ]
 [ 0.45492769]
 [ 1.21383652]]
y: [83, 'F']
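Indexing with tuh[-1] works because the TUH dataset concatenates one dataset per recording and, like torch.utils.data.ConcatDataset, resolves a global index through cumulative lengths. A minimal sketch of that bookkeeping, assuming made-up recording lengths (only the last one, 3600 samples, is consistent with the window indices shown later); locate is a hypothetical helper:

```python
import bisect
from itertools import accumulate


def locate(global_idx, lengths):
    """Map a global index to (recording index, sample index), as a
    concatenated dataset does; negative indices count from the end."""
    cumulative = list(accumulate(lengths))
    if global_idx < 0:
        global_idx += cumulative[-1]
    if not 0 <= global_idx < cumulative[-1]:
        raise IndexError("index out of range")
    # First recording whose cumulative length exceeds the index.
    rec = bisect.bisect_right(cumulative, global_idx)
    offset = cumulative[rec - 1] if rec > 0 else 0
    return rec, global_idx - offset


# Hypothetical recording lengths in samples.
lengths = [1200, 2500, 3000, 1800, 3600]
print(locate(-1, lengths))    # (4, 3599): last sample of the last recording
print(locate(1200, lengths))  # (1, 0): first sample of the second recording
```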

We skip preprocessing steps here, since they are not the aim of this example. Instead, we directly compute windows. We also specify a mapping from the genders 'M' and 'F' to integers, since decoding requires numeric targets.

tuh_windows = create_fixed_length_windows(
    tuh,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=1000,
    window_stride_samples=1000,
    drop_last_window=False,
    mapping={'M': 0, 'F': 1},  # map non-digit targets
)
# store the number of windows required for loading later on
tuh_windows.set_description({
    "n_windows": [len(d) for d in tuh_windows.datasets]})
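With drop_last_window=False, a recording whose length is not covered exactly by the strided windows still gets a final window that ends at its last sample, overlapping the previous one. A minimal sketch of that logic, assuming no start/stop offsets; window_bounds is a hypothetical helper (its empty result for recordings shorter than one window is also an assumption), not braindecode's implementation, but for a 3600-sample recording it reproduces the window boundaries printed in batch_ind further below:

```python
def window_bounds(n_samples, size, stride, drop_last=False):
    """Return (start, stop) pairs of fixed-length windows over a recording.

    If drop_last is False and the strided windows do not reach the end,
    one extra window ending exactly at n_samples is appended; it overlaps
    the previous window."""
    if n_samples < size:
        return []
    starts = list(range(0, n_samples - size + 1, stride))
    if not drop_last and starts[-1] + size < n_samples:
        starts.append(n_samples - size)
    return [(s, s + size) for s in starts]


bounds = window_bounds(3600, size=1000, stride=1000)
print(bounds)       # [(0, 1000), (1000, 2000), (2000, 3000), (2600, 3600)]
print(len(bounds))  # 4, the n_windows value such a recording would get
```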

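Iterating through the windows dataset gives x as an ndarray of shape (n_channels, 1000), y as [age, gender] with gender now mapped to an integer, and ind as [window index within the recording, start sample, stop sample]. Let's look at the last example again.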

x, y, ind = tuh_windows[-1]
print('x:', x)
print('y:', y)
print('ind:', ind)

Out:

x: [[ 5.6389427e-01 -2.1618271e+00 -9.9437243e-01 ... -6.4533629e-02
   3.9639103e-01 -4.8388162e-01]
 [ 1.1334016e-04 -2.4711089e-01  2.3326023e-01 ... -5.3718823e-01
   1.1165446e+00 -1.1033349e+00]
 [ 2.5976139e-01 -1.6312467e+00 -5.4536062e-01 ... -6.4550507e-01
  -3.1091759e-01 -5.4894560e-03]
 ...
 [ 2.1103388e-01  2.1207649e-01  1.0596663e+00 ...  1.1248783e+00
   2.2101052e+00 -4.5840421e-01]
 [ 2.6553613e-01 -1.0722766e+00 -1.8160485e+00 ... -4.7655761e-01
  -2.3370227e-02  4.5492768e-01]
 [ 6.8648207e-01  1.2309586e-01  3.9327252e-01 ...  9.7762001e-01
  -4.7603920e-01  1.2138366e+00]]
y: [83, 1]
ind: [3, 2600, 3600]
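The mapping only touches target values that appear as its keys, which is why the age stays 83 while 'F' becomes 1. A minimal sketch of that element-wise replacement and of reading ind as (window index, start sample, stop sample); map_target is a hypothetical helper and not necessarily how braindecode applies mapping internally. In braindecode's window metadata these three values are called i_window_in_trial, i_start_in_trial and i_stop_in_trial.

```python
mapping = {"M": 0, "F": 1}


def map_target(target, mapping):
    """Replace entries that are keys of mapping; keep others (e.g. age)."""
    return [mapping.get(value, value) for value in target]


y = map_target([83, "F"], mapping)
print(y)  # [83, 1]

# ind of the last window: window index, start sample, stop sample.
i_window, i_start, i_stop = [3, 2600, 3600]
print(i_stop - i_start)  # 1000, the window size in samples
```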

We pass the dataset to a PyTorch DataLoader, so that it can be used for model training.

dl = DataLoader(
    dataset=tuh_windows,
    batch_size=4,
)

Iterating through the DataLoader gives batch_X as a tensor of shape (4, n_channels, 1000), batch_y as [tensor of 4 ages, tensor of 4 genders], and batch_ind. We iterate to the end to look at the last batch, which contains the last example again.

for batch_X, batch_y, batch_ind in dl:
    pass
print('batch_X:', batch_X)
print('batch_y:', batch_y)
print('batch_ind:', batch_ind)

Out:

batch_X: tensor([[[ 1.9264e-01, -2.8769e-01, -4.0477e-02,  ...,  4.3451e-01,
           2.3285e-01, -3.0400e-01],
         [-5.6241e-01, -2.4511e+00, -1.5853e+00,  ..., -1.4923e+00,
           1.1025e+00,  4.7152e-01],
         [ 4.5288e-01,  2.9770e-01, -7.7068e-03,  ...,  1.6793e-01,
          -5.4024e-01,  2.3311e+00],
         ...,
         [-1.4093e-01,  2.1644e-01, -7.2651e-02,  ..., -2.2531e+00,
          -2.3257e+00, -1.0198e-01],
         [ 1.7482e+00,  6.3536e-01, -1.3564e+00,  ..., -1.0846e-01,
           7.7717e-02,  5.7999e-01],
         [-5.4359e-01, -1.0553e+00,  2.1270e-01,  ...,  8.6473e-01,
          -1.0241e+00, -5.6435e-01]],

        [[ 2.7934e-01, -5.5462e-01, -2.3934e+00,  ..., -6.4195e-01,
           1.2517e+00,  1.4091e+00],
         [ 1.1977e+00,  7.7382e-01, -1.2499e+00,  ..., -5.1294e-01,
           1.3692e+00, -1.0125e+00],
         [-2.1263e+00, -5.8350e-02, -2.3486e-01,  ..., -6.6659e-01,
          -3.5822e-02,  8.5182e-01],
         ...,
         [-1.8836e+00, -5.2328e-01, -1.7144e+00,  ...,  1.9581e+00,
          -3.3173e-01,  5.9458e-01],
         [ 5.3573e-01,  4.7540e-01,  1.8706e+00,  ...,  1.1629e+00,
           7.8696e-01, -1.5714e+00],
         [ 5.6450e-01,  8.2211e-01,  3.2242e-01,  ..., -2.3119e+00,
          -7.1520e-01,  7.7749e-02]],

        [[ 1.4595e+00,  7.5736e-01,  4.0588e-02,  ...,  1.4255e+00,
           6.8046e-01,  5.0423e-01],
         [-8.8447e-01, -1.5425e-01,  7.6564e-01,  ...,  5.5104e-01,
          -8.6491e-01,  7.1067e-01],
         [ 3.9101e-01, -6.7435e-01,  3.1399e-01,  ..., -2.6413e-01,
           6.7261e-01, -4.9560e-01],
         ...,
         [ 1.2032e+00,  3.0923e-01,  4.1398e-01,  ..., -5.7762e-01,
          -4.7420e-02,  4.0071e-01],
         [ 3.6943e-01, -8.9819e-01,  1.0731e+00,  ...,  2.2911e-01,
           2.1890e-01,  2.2932e+00],
         [ 1.0741e+00,  1.6643e+00,  5.2559e-01,  ...,  1.2460e-01,
          -1.6045e+00,  2.4247e+00]],

        [[ 5.6389e-01, -2.1618e+00, -9.9437e-01,  ..., -6.4534e-02,
           3.9639e-01, -4.8388e-01],
         [ 1.1334e-04, -2.4711e-01,  2.3326e-01,  ..., -5.3719e-01,
           1.1165e+00, -1.1033e+00],
         [ 2.5976e-01, -1.6312e+00, -5.4536e-01,  ..., -6.4551e-01,
          -3.1092e-01, -5.4895e-03],
         ...,
         [ 2.1103e-01,  2.1208e-01,  1.0597e+00,  ...,  1.1249e+00,
           2.2101e+00, -4.5840e-01],
         [ 2.6554e-01, -1.0723e+00, -1.8160e+00,  ..., -4.7656e-01,
          -2.3370e-02,  4.5493e-01],
         [ 6.8648e-01,  1.2310e-01,  3.9327e-01,  ...,  9.7762e-01,
          -4.7604e-01,  1.2138e+00]]])
batch_y: [tensor([83, 83, 83, 83]), tensor([1, 1, 1, 1])]
batch_ind: [tensor([0, 1, 2, 3]), tensor([   0, 1000, 2000, 2600]), tensor([1000, 2000, 3000, 3600])]
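Because each per-window target is a list, PyTorch's default collate function transposes the batch: it yields one tensor per target field instead of one [age, gender] pair per window, and ind is treated the same way. A stdlib-only sketch of that transposition, using the targets and indices of the four windows in the last batch; collate_lists is a hypothetical helper that omits the tensor stacking the real collate function performs:

```python
def collate_lists(batch):
    """Transpose a batch of equal-length lists into one list per field.

    PyTorch's default collate does the same for sequence elements and then
    stacks each field into a tensor; here the fields stay plain lists."""
    return [list(field) for field in zip(*batch)]


# Targets of the four windows in the last batch: [age, gender] each.
batch_targets = [[83, 1], [83, 1], [83, 1], [83, 1]]
print(collate_lists(batch_targets))  # [[83, 83, 83, 83], [1, 1, 1, 1]]

# Indices of the same windows: [window index, start, stop] each.
batch_inds = [[0, 0, 1000], [1, 1000, 2000], [2, 2000, 3000], [3, 2600, 3600]]
print(collate_lists(batch_inds))
# [[0, 1, 2, 3], [0, 1000, 2000, 2600], [1000, 2000, 3000, 3600]]
```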
