# Source code for braindecode.datasets.sleep_physionet

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#
# License: BSD (3-clause)


import os
import warnings

import numpy as np
import pandas as pd
import mne
from mne.datasets.sleep_physionet.age import fetch_data

from .base import BaseDataset, BaseConcatDataset


class SleepPhysionet(BaseConcatDataset):
    """Sleep Physionet dataset.

    Sleep dataset from https://physionet.org/content/sleep-edfx/1.0.0/.
    Contains overnight recordings from 78 healthy subjects.

    See [MNE example](https://mne.tools/stable/auto_tutorials/sample-datasets/plot_sleep.html).

    Parameters
    ----------
    subject_ids: list(int) | int | None
        (list of) int of subject(s) to be loaded. If None, load all available
        subjects.
    recording_ids: list(int) | int | None
        Recordings to load per subject (most subjects have two recordings; see
        ``mne.datasets.sleep_physionet.age.fetch_data`` for the exceptions).
        Can be [1], [2] or [1, 2] (same as None). A single int is treated as a
        one-element list. Requested recordings that are unavailable are
        skipped with a warning instead of raising.
    preload: bool
        If True, preload the data of the Raw objects.
    load_eeg_only: bool
        If True, only load the EEG channels and discard the others (EOG, EMG,
        temperature, respiration) to avoid resampling the other signals.
    crop_wake_mins: float
        Number of minutes of wake time to keep before the first sleep event
        and after the end of the last sleep event. Used to reduce the
        imbalance in this dataset. Default of 30 mins. If <= 0, no wake
        cropping is applied. A recording without any sleep-stage annotation
        is left uncropped and a warning is emitted.
    crop : None | tuple
        If not None crop the raw files (e.g. to use only the first 3h).
        Example: ``crop=(0, 3600*3)`` to keep only the first 3h. Applied after
        the wake cropping, so times are relative to the start of the
        already-cropped recording.
    """
    def __init__(self, subject_ids=None, recording_ids=None, preload=False,
                 load_eeg_only=True, crop_wake_mins=30, crop=None):
        if subject_ids is None:
            subject_ids = range(83)
        elif isinstance(subject_ids, (int, np.integer)):
            # The documented single-int form must be wrapped: fetch_data
            # expects an iterable of subject ids.
            subject_ids = [subject_ids]
        if recording_ids is None:
            recording_ids = [1, 2]
        elif isinstance(recording_ids, (int, np.integer)):
            recording_ids = [recording_ids]

        # on_missing='warn': unavailable subjects/recordings are skipped with
        # a warning rather than aborting the whole load.
        paths = fetch_data(
            subject_ids, recording=recording_ids, on_missing='warn')

        all_base_ds = list()
        for p in paths:
            # p[0] is the PSG file, p[1] the matching hypnogram file.
            raw, desc = self._load_raw(
                p[0], p[1], preload=preload, load_eeg_only=load_eeg_only,
                crop_wake_mins=crop_wake_mins, crop=crop)
            all_base_ds.append(BaseDataset(raw, desc))
        super().__init__(all_base_ds)

    @staticmethod
    def _load_raw(raw_fname, ann_fname, preload, load_eeg_only=True,
                  crop_wake_mins=0, crop=None):
        """Load one PSG recording together with its hypnogram annotations.

        Parameters
        ----------
        raw_fname : str
            Path to the PSG EDF file.
        ann_fname : str
            Path to the hypnogram file read with ``mne.read_annotations``.
        preload : bool
            Passed to ``mne.io.read_raw_edf``.
        load_eeg_only : bool
            If True, the non-EEG channels listed in ``ch_mapping`` are
            excluded when reading; otherwise they are kept and their channel
            types are set according to ``ch_mapping``.
        crop_wake_mins : float
            Minutes to keep before the first and after the end of the last
            sleep-stage annotation. No wake cropping if <= 0.
        crop : None | tuple
            Extra arguments to ``raw.crop`` applied after wake cropping.

        Returns
        -------
        raw : mne.io.Raw
            The loaded (and possibly cropped) recording with annotations set
            and the 'EEG ' prefix stripped from EEG channel names.
        desc : pd.Series
            ``subject`` and ``recording`` numbers parsed from the file name.
        """
        ch_mapping = {
            'EOG horizontal': 'eog',
            'Resp oro-nasal': 'misc',
            'EMG submental': 'misc',
            'Temp rectal': 'misc',
            'Event marker': 'misc'
        }
        exclude = list(ch_mapping.keys()) if load_eeg_only else ()

        raw = mne.io.read_raw_edf(raw_fname, preload=preload, exclude=exclude)
        annots = mne.read_annotations(ann_fname)
        # Suppress MNE's warning when annotations falling outside the signal
        # are cropped or omitted.
        raw.set_annotations(annots, emit_warning=False)

        if crop_wake_mins > 0:
            SleepPhysionet._crop_around_sleep(raw, annots, crop_wake_mins)

        # Rename EEG channels: drop the 'EEG ' prefix.
        ch_names = {
            i: i.replace('EEG ', '') for i in raw.ch_names if 'EEG' in i}
        raw.rename_channels(ch_names)

        if not load_eeg_only:
            raw.set_channel_types(ch_mapping)

        if crop is not None:
            raw.crop(*crop)

        # Characters 3-4 of the file name are parsed as the subject number
        # and character 5 as the recording (night) number; this assumes the
        # Sleep-EDF file naming scheme returned by fetch_data.
        basename = os.path.basename(raw_fname)
        subj_nb = int(basename[3:5])
        sess_nb = int(basename[5])
        desc = pd.Series({'subject': subj_nb, 'recording': sess_nb}, name='')
        return raw, desc

    @staticmethod
    def _crop_around_sleep(raw, annots, crop_wake_mins):
        """Crop ``raw`` in place to the sleep period plus a wake margin.

        Keeps ``crop_wake_mins`` minutes before the onset of the first
        sleep-stage annotation (stages 1-4 and R) and after the end
        (onset + duration) of the last one, clipped to the recording bounds.
        If ``annots`` has no sleep-stage annotation, ``raw`` is left
        untouched and a warning is emitted.
        """
        # Sleep-stage descriptions are identified by their last character
        # (presumably labels like 'Sleep stage 2' / 'Sleep stage R'); other
        # labels such as wake or unknown stages do not match. endswith also
        # copes with an empty description.
        is_sleep = np.array(
            [d.endswith(('1', '2', '3', '4', 'R')) for d in annots.description],
            dtype=bool)
        sleep_event_inds = np.flatnonzero(is_sleep)
        if sleep_event_inds.size == 0:
            warnings.warn(
                'No sleep-stage annotation found; skipping wake cropping.')
            return

        first, last = sleep_event_inds[0], sleep_event_inds[-1]
        margin = crop_wake_mins * 60
        tmin = annots.onset[first] - margin
        # Measure the trailing margin from the END of the last sleep event:
        # a single annotation may cover a long stretch of one stage, and
        # using its onset alone could drop part of that sleep period.
        tmax = annots.onset[last] + annots.duration[last] + margin
        raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))