Source code for braindecode.datasets.bcicomp

# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
#          Mohammed Fattouh <mo.fattouh@gmail.com>
#
# License: BSD (3-clause)

import glob
import os
import os.path as osp
from os import remove
from shutil import unpack_archive

import mne
import numpy as np
from mne.utils import verbose
from scipy.io import loadmat

from braindecode.datasets import BaseDataset, BaseConcatDataset

DATASET_URL = 'https://stacks.stanford.edu/file/druid:zk881ps0522/' \
              'BCI_Competion4_dataset4_data_fingerflexions.zip'


class BCICompetitionIVDataset4(BaseConcatDataset):
    """BCI competition IV dataset 4.

    Contains ECoG recordings for three patients moving fingers during the
    experiment. Targets correspond to the time courses of the flexion of each
    of five fingers. See http://www.bbci.de/competition/iv/desc_4.pdf and
    http://www.bbci.de/competition/iv/ for the dataset and competition
    description. ECoG library containing the dataset:
    https://searchworks.stanford.edu/view/zk881ps0522

    Notes
    -----
    When using this dataset please cite [1]_.

    Parameters
    ----------
    subject_ids : list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects. Should be in range 1-3.

    References
    ----------
    .. [1] Miller, Kai J. "A library of human electrocorticographic data and
       analyses." Nature Human Behaviour 3, no. 11 (2019): 1225-1235.
       https://doi.org/10.1038/s41562-019-0678-3
    """
    possible_subjects = [1, 2, 3]

    def __init__(self, subject_ids=None):
        data_path = self.download()
        if isinstance(subject_ids, int):
            subject_ids = [subject_ids]
        if subject_ids is None:
            subject_ids = self.possible_subjects
        self._validate_subjects(subject_ids)
        files_list = [f'{data_path}/sub{i}_comp.mat' for i in subject_ids]
        datasets = []
        for file_path in files_list:
            raw_train, raw_test = self._load_data_to_mne(file_path)
            desc_train = dict(
                subject=file_path.split('/')[-1].split('sub')[1][0],
                file_name=file_path.split('/')[-1],
                session='train'
            )
            desc_test = dict(
                subject=file_path.split('/')[-1].split('sub')[1][0],
                file_name=file_path.split('/')[-1],
                session='test'
            )
            datasets.append(BaseDataset(raw_train, description=desc_train))
            datasets.append(BaseDataset(raw_test, description=desc_test))
        super().__init__(datasets)
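
    # A minimal usage sketch (an illustration, not part of the class: it
    # assumes the soft dependencies ``moabb`` and ``pooch`` are installed and
    # that the data can be downloaded):
    #
    #     dataset = BCICompetitionIVDataset4(subject_ids=[1])
    #     print(dataset.description)  # one 'train' and one 'test' row per subject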

    @staticmethod
    def download(path=None, force_update=False, verbose=None):
        """Download the dataset.

        Parameters
        ----------
        path : None | str
            Location of where to look for the data storing location. If None,
            the environment variable or config parameter
            MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the
            "~/mne_data" directory is used. If the dataset is not found under
            the given path, the data will be automatically downloaded to the
            specified folder.
        force_update : bool
            Force update of the dataset even if a local copy exists.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see mne.verbose()).

        Returns
        -------
        destination : str
            Path to the directory containing the unpacked .mat files.
        """
        signature = 'BCICompetitionIVDataset4'
        folder_name = 'BCI_Competion4_dataset4_data_fingerflexions'
        # Check if the dataset already exists (unpacked). We have to do that
        # manually because we remove the .zip file from disk to save space.
        from moabb.datasets.download import get_dataset_path  # keep soft dependency
        path = get_dataset_path(signature, path)
        key_dest = "MNE-{:s}-data".format(signature.lower())
        # We do not use mne's _url_to_local_path due to the ':' in the url,
        # which causes problems on Windows
        destination = osp.join(path, key_dest, folder_name)
        if len(list(glob.glob(osp.join(destination, '*.mat')))) == 6:
            return destination
        data_path = _data_dl(DATASET_URL,
                             osp.join(destination, folder_name, signature),
                             force_update=force_update)
        unpack_archive(data_path, osp.dirname(destination))
        # remove the .zip file that the data was unpacked from
        remove(data_path)
        return destination
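
    # Usage sketch for the method above (hedged: only the documented
    # signature is exercised; network access and the ``moabb``/``pooch``
    # soft dependencies are assumed):
    #
    #     folder = BCICompetitionIVDataset4.download()                # cached after first run
    #     folder = BCICompetitionIVDataset4.download(force_update=True)  # re-download
    #
    # ``folder`` is the directory holding the six ``*.mat`` files.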

    @staticmethod
    def _prepare_targets(upsampled_targets, targets_stride):
        original_targets = np.full_like(upsampled_targets, np.nan)
        original_targets[::targets_stride] = upsampled_targets[::targets_stride]
        return original_targets

    def _load_data_to_mne(self, file_path):
        data = loadmat(file_path)
        test_labels = loadmat(file_path.replace('comp.mat', 'testlabels.mat'))
        train_data = data['train_data']
        test_data = data['test_data']
        upsampled_train_targets = data['train_dg']
        upsampled_test_targets = test_labels['test_dg']

        signal_sfreq = 1000
        original_target_sfreq = 25
        targets_stride = int(signal_sfreq / original_target_sfreq)

        original_targets = self._prepare_targets(upsampled_train_targets,
                                                 targets_stride)
        original_test_targets = self._prepare_targets(upsampled_test_targets,
                                                      targets_stride)

        ch_names = [f'{i}' for i in range(train_data.shape[1])]
        ch_names += [f'target_{i}' for i in range(original_targets.shape[1])]
        ch_types = ['ecog' for _ in range(train_data.shape[1])]
        ch_types += ['misc' for _ in range(original_targets.shape[1])]

        info = mne.create_info(sfreq=signal_sfreq, ch_names=ch_names,
                               ch_types=ch_types)
        info['temp'] = dict(target_sfreq=original_target_sfreq)

        train_data = np.concatenate([train_data, original_targets], axis=1)
        test_data = np.concatenate([test_data, original_test_targets], axis=1)

        raw_train = mne.io.RawArray(train_data.T, info=info)
        raw_test = mne.io.RawArray(test_data.T, info=info)
        # TODO: show how to resample targets
        return raw_train, raw_test

    def _validate_subjects(self, subject_ids):
        if isinstance(subject_ids, (list, tuple)):
            if not all(subject in self.possible_subjects
                       for subject in subject_ids):
                raise ValueError(
                    f'Wrong subject_ids parameter. Possible values: '
                    f'{self.possible_subjects}. Provided {subject_ids}.'
                )
        else:
            raise ValueError(
                'Wrong subject_ids format. Expected types: None, list, tuple, int.'
            )
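

# A sketch of how the original 25 Hz finger-flexion targets can be recovered
# from the 'misc' channels, touching on the TODO above. The targets are
# stored at the 1000 Hz signal rate with NaNs between the original samples,
# so taking every ``stride``-th column drops the padding. Variable names here
# are illustrative, not part of the braindecode API:
#
#     raw = dataset.datasets[0].raw
#     target_sfreq = raw.info['temp']['target_sfreq']   # 25 Hz
#     stride = int(raw.info['sfreq'] // target_sfreq)   # 1000 / 25 = 40
#     targets = raw.get_data(picks='misc')[:, ::stride]  # (5, n_target_samples)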


@verbose
def _data_dl(url, destination, force_update=False, verbose=None):
    # Code taken from moabb (moabb/datasets/download.py) due to a problem
    # with ':' occurring in the path: on Windows ':' is forbidden in folder
    # names
    from pooch import file_hash, retrieve  # keep soft dependency

    if not osp.isfile(destination) or force_update:
        if osp.isfile(destination):
            os.remove(destination)
        if not osp.isdir(osp.dirname(destination)):
            os.makedirs(osp.dirname(destination))
        known_hash = None
    else:
        known_hash = file_hash(destination)
    data_path = retrieve(
        url, known_hash, fname=osp.basename(url), path=osp.dirname(destination)
    )
    return data_path
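

# A minimal end-to-end sketch (an illustration, not part of the module: it
# assumes network access plus the soft dependencies ``moabb`` and ``pooch``,
# and downloads the data on first run).
if __name__ == '__main__':
    dataset = BCICompetitionIVDataset4(subject_ids=1)
    # Each subject contributes one 'train' and one 'test' recording; ECoG
    # channels are typed 'ecog' and the finger-flexion targets 'misc'.
    print(dataset.description)
    print(dataset.datasets[0].raw)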