"""
Dataset classes for the Temple University Hospital (TUH) EEG Corpus and the
TUH Abnormal EEG Corpus.
"""
# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD (3-clause)
import re
import os
import glob
from unittest import mock
from datetime import datetime, timezone
import pandas as pd
import numpy as np
import mne
from joblib import Parallel, delayed
from .base import BaseDataset, BaseConcatDataset
class TUH(BaseConcatDataset):
"""Temple University Hospital (TUH) EEG Corpus
(www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).

    Parameters
----------
path: str
Parent directory of the dataset.
    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read. Order matters and
        overwrites the default chronological order: e.g. if
        recording_ids=[1, 0], the first recording returned by this class
        will be chronologically later than the second one. Provide
        recording_ids in ascending order to preserve chronological order.
    target_name: str
        Can be 'gender' or 'age'.
preload: bool
If True, preload the data of the Raw objects.
add_physician_reports: bool
If True, the physician reports will be read from disk and added to the
description.
n_jobs: int
Number of jobs to be used to read files in parallel.
"""
def __init__(self, path, recording_ids=None, target_name=None,
preload=False, add_physician_reports=False, n_jobs=1):
# create an index of all files and gather easily accessible info
# without actually touching the files
file_paths = glob.glob(os.path.join(path, '**/*.edf'), recursive=True)
descriptions = _create_chronological_description(file_paths)
# limit to specified recording ids before doing slow stuff
        if recording_ids is not None:
            # a single int is documented as valid, so normalize it to a list
            # to keep the DataFrame column selection below working
            if isinstance(recording_ids, int):
                recording_ids = [recording_ids]
            descriptions = descriptions[recording_ids]
        # second loop (slow): create the datasets, gathering more info about
        # the files by actually touching them, i.e. reading the raws and
        # potentially preloading the data
        # joblib is bypassed for n_jobs=1 since mocking seems to fail with it
        # in the tests
if n_jobs == 1:
base_datasets = [self._create_dataset(
descriptions[i], target_name, preload, add_physician_reports)
for i in descriptions.columns]
else:
base_datasets = Parallel(n_jobs)(delayed(
self._create_dataset)(
descriptions[i], target_name, preload, add_physician_reports
) for i in descriptions.columns)
super().__init__(base_datasets)
@staticmethod
def _create_dataset(description, target_name, preload,
add_physician_reports):
file_path = description.loc['path']
# parse age and gender information from EDF header
age, gender = _parse_age_and_gender_from_edf_header(file_path)
raw = mne.io.read_raw_edf(file_path, preload=preload)
# Use recording date from path as EDF header is sometimes wrong
meas_date = datetime(1, 1, 1, tzinfo=timezone.utc) \
if raw.info['meas_date'] is None else raw.info['meas_date']
raw.set_meas_date(meas_date.replace(
*description[['year', 'month', 'day']]))
        # store the metadata parsed from the EDF header in the description
d = {
'age': int(age),
'gender': gender,
}
if add_physician_reports:
physician_report = _read_physician_report(file_path)
d['report'] = physician_report
additional_description = pd.Series(d)
description = pd.concat([description, additional_description])
base_dataset = BaseDataset(raw, description,
target_name=target_name)
return base_dataset
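# A minimal usage sketch for `TUH` (kept as comments so importing this module
# has no side effects; the path below is hypothetical and must point at the
# parent directory of a local copy of the corpus):
#
#     tuh = TUH(path='datasets/tuh_eeg/', recording_ids=[0, 1],
#               target_name='age', preload=False, n_jobs=1)
#     print(tuh.description)  # one row per recording, in chronological order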
def _create_chronological_description(file_paths):
# this is the first loop (fast)
descriptions = []
for file_path in file_paths:
description = _parse_description_from_file_path(file_path)
descriptions.append(pd.Series(description))
descriptions = pd.concat(descriptions, axis=1)
# order descriptions chronologically
descriptions.sort_values(
["year", "month", "day", "subject", "session", "segment"],
axis=1, inplace=True)
# https://stackoverflow.com/questions/42284617/reset-column-index-pandas
descriptions = descriptions.T.reset_index(drop=True).T
return descriptions
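# Sketch of the layout built above (kept as comments; paths taken from
# `_TUH_EEG_PATHS` at the bottom of this module): each recording is one
# *column* of the DataFrame and the fields parsed from its path are the rows.
#
#     df = _create_chronological_description([
#         'tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf',  # noqa E501
#         'tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf',  # noqa E501
#     ])
#     # columns are relabelled 0..n-1 after sorting, so df[0] holds the 2014
#     # recording and df[1] the 2015 one, although they were passed reversed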
def _parse_description_from_file_path(file_path):
# stackoverflow.com/questions/3167154/how-to-split-a-dos-path-into-its-components-in-python # noqa
file_path = os.path.normpath(file_path)
tokens = file_path.split(os.sep)
    # expect file paths as tuh_eeg/version/file_type/reference/subset/
    # subject/recording session/file
    # e.g. tuh_eeg/v1.1.0/edf/01_tcp_ar/027/00002729/
    # s001_2006_04_12/00002729_s001.edf
version = tokens[-7]
year, month, day = tokens[-2].split('_')[1:]
subject_id = tokens[-3]
session = tokens[-2].split('_')[0]
segment = tokens[-1].split('_')[-1].split('.')[-2]
return {
'path': file_path,
'version': version,
'year': int(year),
'month': int(month),
'day': int(day),
'subject': int(subject_id),
'session': int(session[1:]),
'segment': int(segment[1:]),
}
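# Worked example for the parser above (kept as comments), using a real path
# from `_TUH_EEG_PATHS` at the bottom of this module:
#
#     _parse_description_from_file_path(
#         'tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/'
#         's001_2015_12_30/00000000_s001_t000.edf')
#     # -> {'path': ..., 'version': 'v1.1.0', 'year': 2015, 'month': 12,
#     #     'day': 30, 'subject': 0, 'session': 1, 'segment': 0}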
def _read_physician_report(file_path):
directory = os.path.dirname(file_path)
txt_file = glob.glob(os.path.join(directory, '**/*.txt'), recursive=True)
# check that there is at most one txt file in the same directory
assert len(txt_file) in [0, 1]
report = ''
if txt_file:
txt_file = txt_file[0]
# somewhere in the corpus, encoding apparently changed
# first try to read as utf-8, if it does not work use latin-1
try:
with open(txt_file, 'r', encoding='utf-8') as f:
report = f.read()
except UnicodeDecodeError:
with open(txt_file, 'r', encoding='latin-1') as f:
report = f.read()
return report
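# The fallback above is the usual "try utf-8 first" idiom: a report saved as
# latin-1 raises UnicodeDecodeError on the first attempt and is then re-read
# successfully, e.g. b'caf\xe9'.decode('latin-1') gives 'café' while
# b'caf\xe9'.decode('utf-8') raises.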
def _read_edf_header(file_path):
    # the first 88 bytes cover the fixed-size fields up to and including
    # the local patient identification (see EDF specification)
    with open(file_path, "rb") as f:
        header = f.read(88)
    return header
def _parse_age_and_gender_from_edf_header(file_path):
header = _read_edf_header(file_path)
# bytes 8 to 88 contain ascii local patient identification
# see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
patient_id = header[8:].decode("ascii")
age = -1
found_age = re.findall(r"Age:(\d+)", patient_id)
if len(found_age) == 1:
age = int(found_age[0])
gender = "X"
    # a single 'F' or 'M' framed by whitespace marks the gender
    found_gender = re.findall(r"\s([FM])\s", patient_id)
if len(found_gender) == 1:
gender = found_gender[0]
return age, gender
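# Example for the parser above, based on a header stored in `_TUH_EEG_PATHS`
# below: in b'0 00000000 M 01-JAN-1978 00000000 Age:37 ...' the substring
# 'Age:37' yields age 37 and the single 'M' framed by whitespace yields
# gender 'M'; headers lacking these patterns fall back to (-1, 'X').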
class TUHAbnormal(TUH):
"""Temple University Hospital (TUH) Abnormal EEG Corpus.
see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab

    Parameters
----------
path: str
Parent directory of the dataset.
    recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read. Order matters and
        overwrites the default chronological order: e.g. if
        recording_ids=[1, 0], the first recording returned by this class
        will be chronologically later than the second one. Provide
        recording_ids in ascending order to preserve chronological order.
target_name: str
Can be 'pathological', 'gender', or 'age'.
preload: bool
If True, preload the data of the Raw objects.
    add_physician_reports: bool
        If True, the physician reports will be read from disk and added to the
        description.
    n_jobs: int
        Number of jobs to be used to read files in parallel.
"""
def __init__(self, path, recording_ids=None, target_name='pathological',
preload=False, add_physician_reports=False, n_jobs=1):
super().__init__(path=path, recording_ids=recording_ids,
preload=preload, target_name=target_name,
add_physician_reports=add_physician_reports,
n_jobs=n_jobs)
additional_descriptions = []
for file_path in self.description.path:
additional_description = (
self._parse_additional_description_from_file_path(file_path))
additional_descriptions.append(additional_description)
additional_descriptions = pd.DataFrame(additional_descriptions)
self.set_description(additional_descriptions, overwrite=True)
@staticmethod
def _parse_additional_description_from_file_path(file_path):
file_path = os.path.normpath(file_path)
tokens = file_path.split(os.sep)
# expect paths as version/file type/data_split/pathology status/
# reference/subset/subject/recording session/file
# e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
# s004_2013_08_15/00000021_s004_t000.edf
assert ('abnormal' in tokens or 'normal' in tokens), (
'No pathology labels found.')
assert ('train' in tokens or 'eval' in tokens), (
'No train or eval set information found.')
return {
'version': tokens[-9],
'train': 'train' in tokens,
'pathological': 'abnormal' in tokens,
}
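# Worked example for the parser above (kept as comments), using a real path
# from `_TUH_EEG_ABNORMAL_PATHS` below:
#
#     TUHAbnormal._parse_additional_description_from_file_path(
#         'tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/'
#         's001_2011_07_05/00007871_s001_t001.edf')
#     # -> {'version': 'v2.0.0', 'train': True, 'pathological': False}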
def _fake_raw(*args, **kwargs):
sfreq = 10
ch_names = [
'EEG A1-REF', 'EEG A2-REF',
'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF',
'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF',
'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',
'EEG T6-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF']
duration_min = 6
data = np.random.randn(len(ch_names), duration_min * sfreq * 60)
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
raw = mne.io.RawArray(data=data, info=info)
return raw
def _get_header(*args, **kwargs):
all_paths = {**_TUH_EEG_PATHS, **_TUH_EEG_ABNORMAL_PATHS}
return all_paths[args[0]]
_TUH_EEG_PATHS = {
# These are actual file paths and edf headers from the TUH EEG Corpus (v1.1.0 and v1.2.0)
'tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001_2015_12_30/00000000_s001_t000.edf': b'0 00000000 M 01-JAN-1978 00000000 Age:37 ', # noqa E501
'tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004_2014_09_30/00009932_s004_t013.edf': b'0 00009932 F 01-JAN-1961 00009932 Age:53 ', # noqa E501
'tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001_2003_02_05/00000058_s001_t000.edf': b'0 00000058 M 01-JAN-2003 00000058 Age:0.0109 ', # noqa E501
'tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s003_2014_12_14/00012331_s003_t002.edf': b'0 00012331 M 01-JAN-1975 00012331 Age:39 ', # noqa E501
'tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s004_2016_01_15/00014928_s004_t007.edf': b'0 00014928 F 01-JAN-1933 00014928 Age:83 ', # noqa E501
}
_TUH_EEG_ABNORMAL_PATHS = {
# these are actual file paths and edf headers from TUH Abnormal EEG Corpus (v2.0.0)
'tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/00007871/s001_2011_07_05/00007871_s001_t001.edf': b'0 00007871 F 01-JAN-1988 00007871 Age:23 ', # noqa E501
'tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/097/00009777/s001_2012_09_17/00009777_s001_t000.edf': b'0 00009777 M 01-JAN-1986 00009777 Age:26 ', # noqa E501
'tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/083/00008393/s002_2012_02_21/00008393_s002_t000.edf': b'0 00008393 M 01-JAN-1960 00008393 Age:52 ', # noqa E501
'tuh_abnormal_eeg/v2.0.0/edf/train/abnormal/01_tcp_ar/012/00001200/s003_2010_12_06/00001200_s003_t000.edf': b'0 00001200 M 01-JAN-1963 00001200 Age:47 ', # noqa E501
'tuh_abnormal_eeg/v2.0.0/edf/eval/abnormal/01_tcp_ar/059/00005932/s004_2013_03_14/00005932_s004_t000.edf': b'0 00005932 M 01-JAN-1963 00005932 Age:50 ', # noqa E501
}
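# `_get_header` above looks these verbatim headers up by path during tests,
# e.g. (header bytes truncated here):
#
#     _get_header('tuh_abnormal_eeg/v2.0.0/edf/train/normal/01_tcp_ar/078/'
#                 '00007871/s001_2011_07_05/00007871_s001_t001.edf')
#     # -> b'0 00007871 F 01-JAN-1988 00007871 Age:23 ...'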
class _TUHMock(TUH):
"""Mocked class for testing and examples."""
@mock.patch('glob.glob', return_value=_TUH_EEG_PATHS.keys())
@mock.patch('mne.io.read_raw_edf', new=_fake_raw)
@mock.patch('braindecode.datasets.tuh._read_edf_header',
new=_get_header)
def __init__(self, mock_glob, path, recording_ids=None, target_name=None,
preload=False, add_physician_reports=False, n_jobs=1):
super().__init__(path=path, recording_ids=recording_ids,
target_name=target_name, preload=preload,
add_physician_reports=add_physician_reports,
n_jobs=n_jobs)
class _TUHAbnormalMock(TUHAbnormal):
"""Mocked class for testing and examples."""
@mock.patch('glob.glob', return_value=_TUH_EEG_ABNORMAL_PATHS.keys())
@mock.patch('mne.io.read_raw_edf', new=_fake_raw)
@mock.patch('braindecode.datasets.tuh._read_edf_header',
new=_get_header)
@mock.patch('braindecode.datasets.tuh._read_physician_report',
return_value='simple_test')
def __init__(self, mock_glob, mock_report, path, recording_ids=None,
target_name='pathological', preload=False,
add_physician_reports=False, n_jobs=1):
super().__init__(path=path, recording_ids=recording_ids,
target_name=target_name, preload=preload,
add_physician_reports=add_physician_reports,
n_jobs=n_jobs)
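# A minimal sketch of how the mocked classes can be used in tests or examples
# (kept as comments; the path argument is effectively ignored because
# `glob.glob` is patched to return the fixed paths above):
#
#     tuh = _TUHMock(path='tuh_eeg/', target_name='age')
#     assert len(tuh.datasets) == len(_TUH_EEG_PATHS)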