"""Dataset objects for some public datasets.
"""
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
# Lukas Gemein <l.gemein@gmail.com>
# Simon Brandt <simonbrandt@protonmail.com>
# David Sabbagh <dav.sabbagh@gmail.com>
# Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
import pandas as pd
import mne
from .base import BaseDataset, BaseConcatDataset
from braindecode.util import _update_moabb_docstring


def _find_dataset_in_moabb(dataset_name, dataset_kwargs=None):
    """Return an instance of the moabb dataset class named `dataset_name`."""
    # soft dependency on moabb
    from moabb.datasets.utils import dataset_list
    for dataset in dataset_list:
        if dataset_name == dataset.__name__:
            # return an instance of the found dataset class
            if dataset_kwargs is None:
                return dataset()
            else:
                return dataset(**dataset_kwargs)
    raise ValueError("'dataset_name' not found in moabb datasets")


def _fetch_and_unpack_moabb_data(dataset, subject_ids):
    """Fetch data with moabb and unpack it into a flat list of raws and a
    description DataFrame with one row per subject/session/run.
    """
    data = dataset.get_data(subject_ids)
    raws, subject_ids, session_ids, run_ids = [], [], [], []
    for subj_id, subj_data in data.items():
        for sess_id, sess_data in subj_data.items():
            for run_id, raw in sess_data.items():
                # set annotations if the recording has none
                if len(raw.annotations) == 0:
                    annots = _annotations_from_moabb_stim_channel(raw, dataset)
                    raw.set_annotations(annots)
                raws.append(raw)
                subject_ids.append(subj_id)
                session_ids.append(sess_id)
                run_ids.append(run_id)
    description = pd.DataFrame({
        'subject': subject_ids,
        'session': session_ids,
        'run': run_ids
    })
    return raws, description


def _annotations_from_moabb_stim_channel(raw, dataset):
    """Create annotations from stim channel events, using the event_id and
    trial interval defined by the moabb dataset."""
    # find events from the stim channel
    events = mne.find_events(raw)
    # get annotations from events
    event_desc = {k: v for v, k in dataset.event_id.items()}
    annots = mne.annotations_from_events(events, raw.info['sfreq'], event_desc)
    # shift trial onsets and set durations using the interval given by moabb
    onset, offset = dataset.interval
    annots.onset += onset
    annots.duration += offset - onset
    return annots


def fetch_data_with_moabb(dataset_name, subject_ids, dataset_kwargs=None):
    # ToDo: update path to where moabb downloads / looks for the data
    """Fetch data using moabb.

    Parameters
    ----------
    dataset_name: str
        the name of a dataset included in moabb
    subject_ids: list(int) | int
        (list of) int of subject(s) to be fetched
    dataset_kwargs: dict, optional
        optional dictionary containing keyword arguments
        to pass to the moabb dataset when instantiating it.

    Returns
    -------
    raws: list of mne.io.Raw
        the fetched recordings
    description: pandas.DataFrame
        description of the recordings, with one row (subject, session and
        run) per entry in ``raws``
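
    Examples
    --------
    A minimal usage sketch (assumes moabb is installed and the requested
    dataset can be downloaded)::

        >>> raws, description = fetch_data_with_moabb(
        ...     "BNCI2014001", subject_ids=[1])
        >>> # description has one row (subject, session, run) per raw in raws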
"""
dataset = _find_dataset_in_moabb(dataset_name, dataset_kwargs)
subject_id = [subject_ids] if isinstance(subject_ids, int) else subject_ids
return _fetch_and_unpack_moabb_data(dataset, subject_id)


class MOABBDataset(BaseConcatDataset):
"""A class for moabb datasets.
Parameters
----------
dataset_name: str
name of dataset included in moabb to be fetched
subject_ids: list(int) | int | None
(list of) int of subject(s) to be fetched. If None, data of all
subjects is fetched.
dataset_kwargs: dict, optional
optional dictionary containing keyword arguments
to pass to the moabb dataset when instantiating it.
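
    Examples
    --------
    A minimal usage sketch (assumes moabb is installed; the data is
    downloaded on first use)::

        >>> ds = MOABBDataset("BNCI2014001", subject_ids=[1])
        >>> # ds holds one BaseDataset per fetched recording (subject/session/run)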
"""
def __init__(self, dataset_name, subject_ids, dataset_kwargs=None):
raws, description = fetch_data_with_moabb(dataset_name, subject_ids, dataset_kwargs)
all_base_ds = [BaseDataset(raw, row)
for raw, (_, row) in zip(raws, description.iterrows())]
super().__init__(all_base_ds)


class BNCI2014001(MOABBDataset):
doc = """See moabb.datasets.bnci.BNCI2014001
Parameters
----------
subject_ids: list(int) | int | None
(list of) int of subject(s) to be fetched. If None, data of all
subjects is fetched.
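
    Examples
    --------
    A minimal usage sketch (assumes moabb is installed)::

        >>> ds = BNCI2014001(subject_ids=[1])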
"""
try:
from moabb.datasets import BNCI2014001
__doc__ = _update_moabb_docstring(BNCI2014001, doc)
except ModuleNotFoundError:
pass # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
def __init__(self, subject_ids):
super().__init__("BNCI2014001", subject_ids=subject_ids)


class HGD(MOABBDataset):
doc = """See moabb.datasets.schirrmeister2017.Schirrmeister2017
Parameters
----------
subject_ids: list(int) | int | None
(list of) int of subject(s) to be fetched. If None, data of all
subjects is fetched.
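
    Examples
    --------
    A minimal usage sketch (assumes moabb is installed)::

        >>> ds = HGD(subject_ids=[1])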
"""
try:
from moabb.datasets import Schirrmeister2017
__doc__ = _update_moabb_docstring(Schirrmeister2017, doc)
except ModuleNotFoundError:
pass # keep moabb soft dependency, otherwise crash on loading of datasets.__init__.py
def __init__(self, subject_ids):
super().__init__("Schirrmeister2017", subject_ids=subject_ids)