.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_sleep_staging_chambon2018.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_sleep_staging_chambon2018.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_sleep_staging_chambon2018.py:


Sleep staging on the Sleep Physionet dataset using Chambon2018 network
=======================================================================

This tutorial shows how to train and test a sleep staging neural network with
Braindecode. We adapt the time distributed approach of [1]_ to learn on
sequences of EEG windows using the openly accessible Sleep Physionet dataset
[2]_ [3]_.

References
----------

.. [1] Chambon, S., Galtier, M., Arnal, P., Wainrib, G. and Gramfort, A. (2018).
       A Deep Learning Architecture for Temporal Sleep Stage Classification
       Using Multivariate and Multimodal Time Series. IEEE Trans. on Neural
       Systems and Rehabilitation Engineering 26:758-769.

.. [2] Kemp, B., Zwinderman, A.H., Tuk, B., Kamphuisen, H.A.C. and Oberyé, J.J.L.
       (2000). Analysis of a sleep-dependent neuronal feedback loop: the
       slow-wave microcontinuity of the EEG. IEEE-BME 47(9):1185-1194.

.. [3] Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.Ch.,
       Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K. and Stanley, H.E. (2000).
       PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research
       Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.

.. GENERATED FROM PYTHON SOURCE LINES 28-33

.. code-block:: default


    # Authors: Hubert Banville
    #
    # License: BSD (3-clause)


.. GENERATED FROM PYTHON SOURCE LINES 34-48

Loading and preprocessing the dataset
--------------------------------------

Loading
~~~~~~~

First, we load the data using the
:class:`braindecode.datasets.sleep_physionet.SleepPhysionet` class. We load
two recordings from two different individuals: we will use the first one to
train our network and the second one to evaluate performance (as in the
`MNE`_ sleep staging example).

.. _MNE: https://mne.tools/stable/auto_tutorials/sample-datasets/plot_sleep.html

.. GENERATED FROM PYTHON SOURCE LINES 48-57

.. code-block:: default


    from numbers import Integral

    from braindecode.datasets import SleepPhysionet

    subject_ids = [0, 1]

    dataset = SleepPhysionet(
        subject_ids=subject_ids, recording_ids=[2], crop_wake_mins=30)


.. GENERATED FROM PYTHON SOURCE LINES 58-64

Preprocessing
~~~~~~~~~~~~~

Next, we preprocess the raw data. We convert the data to microvolts and apply
a lowpass filter. We omit the downsampling step of [1]_, as the Sleep
Physionet data is already sampled at 100 Hz.

.. GENERATED FROM PYTHON SOURCE LINES 64-78

.. code-block:: default


    from braindecode.preprocessing import preprocess, Preprocessor, scale

    high_cut_hz = 30

    preprocessors = [
        Preprocessor(scale, factor=1e6, apply_on_array=True),
        Preprocessor('filter', l_freq=None, h_freq=high_cut_hz)
    ]

    # Transform the data
    preprocess(dataset, preprocessors)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function scale is deprecated; will be removed in 0.7.0. Use numpy.multiply instead.
      warnings.warn(msg, category=FutureWarning)

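Before extracting windows, we can quickly check the effect of the
preprocessing by looking at the underlying MNE ``Raw`` objects. The short
sketch below assumes that each recording in the concatenated dataset exposes
its ``Raw`` object through a ``raw`` attribute.

.. code-block:: default


    # Quick sanity check (sketch): after preprocessing, the recordings should
    # still be sampled at 100 Hz and contain the expected channels.
    raw = dataset.datasets[0].raw
    print(raw.info['sfreq'])     # expected: 100.0
    print(raw.info['ch_names'])  # channels kept after loading
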
.. GENERATED FROM PYTHON SOURCE LINES 79-83

Extract windows
~~~~~~~~~~~~~~~

We extract 30-s windows to be used in the classification task.

.. GENERATED FROM PYTHON SOURCE LINES 83-111

.. code-block:: default


    from braindecode.preprocessing import create_windows_from_events


    mapping = {  # We merge stages 3 and 4 following AASM standards.
        'Sleep stage W': 0,
        'Sleep stage 1': 1,
        'Sleep stage 2': 2,
        'Sleep stage 3': 3,
        'Sleep stage 4': 3,
        'Sleep stage R': 4
    }

    window_size_s = 30
    sfreq = 100
    window_size_samples = window_size_s * sfreq

    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_size_samples,
        preload=True,
        mapping=mapping
    )


.. GENERATED FROM PYTHON SOURCE LINES 112-117

Window preprocessing
~~~~~~~~~~~~~~~~~~~~

We also preprocess the windows by applying channel-wise z-score normalization
in each window.

.. GENERATED FROM PYTHON SOURCE LINES 117-123

.. code-block:: default


    from sklearn.preprocessing import scale as standard_scale

    preprocess(windows_dataset, [Preprocessor(standard_scale, channel_wise=True)])


.. GENERATED FROM PYTHON SOURCE LINES 124-129

Split dataset into train and valid
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We split the dataset into training and validation sets, taking every other
subject for training and the remaining ones for validation.

.. GENERATED FROM PYTHON SOURCE LINES 129-134

.. code-block:: default


    split_ids = dict(train=subject_ids[::2], valid=subject_ids[1::2])
    splits = windows_dataset.split(split_ids)
    train_set, valid_set = splits["train"], splits["valid"]


.. GENERATED FROM PYTHON SOURCE LINES 135-148

Create sequence samplers
------------------------

Following the time distributed approach of [1]_, we need to provide our
neural network with sequences of windows, such that the embeddings of
multiple consecutive windows can be concatenated and provided to a final
classifier. We can achieve this by defining Sampler objects that return
sequences of window indices.
To simplify the example, we train the whole model end-to-end on sequences,
rather than using the two-step approach of [1]_ (i.e. training the feature
extractor on single windows, then freezing its weights and training the
classifier).

.. GENERATED FROM PYTHON SOURCE LINES 148-162

.. code-block:: default


    import numpy as np

    from braindecode.samplers import SequenceSampler

    n_windows = 3  # Sequences of 3 consecutive windows
    n_windows_stride = 3  # Stride equal to the sequence length, i.e. non-overlapping sequences

    train_sampler = SequenceSampler(
        train_set.get_metadata(), n_windows, n_windows_stride)
    valid_sampler = SequenceSampler(
        valid_set.get_metadata(), n_windows, n_windows_stride)

    # Print number of sequences in each set
    print('Training examples: ', len(train_sampler))
    print('Validation examples: ', len(valid_sampler))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Training examples:  372
    Validation examples:  383


.. GENERATED FROM PYTHON SOURCE LINES 163-165

We also implement a transform to extract the label of the center window of a
sequence to use it as target.

.. GENERATED FROM PYTHON SOURCE LINES 165-177

.. code-block:: default


    # Use label of center window in the sequence
    def get_center_label(x):
        if isinstance(x, Integral):
            return x
        # Middle element of the sequence, e.g. index 1 for a sequence of 3 windows
        return x[len(x) // 2] if len(x) > 1 else x


    train_set.target_transform = get_center_label
    valid_set.target_transform = get_center_label

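As a quick illustration of the transform, with the integer-encoded labels
defined above, a sequence of three window labels is mapped to the label of its
middle window:

.. code-block:: default


    # Toy sequence of three window labels; the transform returns the label of
    # the center window.
    print(get_center_label([0, 2, 2]))  # -> 2
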
.. GENERATED FROM PYTHON SOURCE LINES 178-182

Finally, since some sleep stages appear a lot more often than others (e.g.
most of the night is spent in the N2 stage), the classes are imbalanced. To
avoid biasing the network towards the most frequent classes, we compute
weights that we will provide to the loss function when training.

.. GENERATED FROM PYTHON SOURCE LINES 182-189

.. code-block:: default


    from sklearn.utils import compute_class_weight

    y_train = [train_set[idx][1] for idx in train_sampler]
    class_weights = compute_class_weight(
        'balanced', classes=np.unique(y_train), y=y_train)


.. GENERATED FROM PYTHON SOURCE LINES 190-199

Create model
------------

We can now create the deep learning model. In this tutorial, we use the sleep
staging architecture introduced in [1]_, which is a four-layer convolutional
neural network. We use the time distributed version of the model, where the
feature vectors of a sequence of windows are concatenated and passed to a
linear layer for classification.

.. GENERATED FROM PYTHON SOURCE LINES 199-243

.. code-block:: default


    import torch
    from torch import nn

    from braindecode.util import set_random_seeds
    from braindecode.models import SleepStagerChambon2018, TimeDistributed

    cuda = torch.cuda.is_available()  # check if GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True

    # Set random seed to be able to roughly reproduce results
    # Note that with cudnn benchmark set to True, GPU indeterminism
    # may still make results substantially different between runs.
    # To obtain more consistent results at the cost of increased computation time,
    # you can set `cudnn_benchmark=False` in `set_random_seeds`
    # or remove `torch.backends.cudnn.benchmark = True`
    set_random_seeds(seed=31, cuda=cuda)

    n_classes = 5
    # Extract number of channels and time steps from dataset
    n_channels, input_size_samples = train_set[0][0].shape

    feat_extractor = SleepStagerChambon2018(
        n_channels,
        sfreq,
        n_classes=n_classes,
        input_size_s=input_size_samples / sfreq,
        return_feats=True
    )

    model = nn.Sequential(
        TimeDistributed(feat_extractor),  # apply model on each 30-s window
        nn.Sequential(  # apply linear layer on concatenated feature vectors
            nn.Flatten(start_dim=1),
            nn.Dropout(0.5),
            nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes)
        )
    )

    # Send model to GPU
    if cuda:
        model.cuda()

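Before training, it is worth checking that the sequence model produces outputs
of the expected shape. The sketch below simply passes a random batch of two
sequences through the model:

.. code-block:: default


    # Sanity check (sketch): a batch of 2 sequences of `n_windows` windows
    # should be mapped to one logit per sleep stage.
    with torch.no_grad():
        dummy_input = torch.randn(2, n_windows, n_channels, input_size_samples)
        if cuda:
            dummy_input = dummy_input.cuda()
        print(model(dummy_input).shape)  # expected: torch.Size([2, 5])
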
.. GENERATED FROM PYTHON SOURCE LINES 244-260

Training
--------

We can now train our network. :class:`braindecode.EEGClassifier` is a
braindecode object responsible for managing the training of neural networks.
It inherits from :class:`skorch.NeuralNetClassifier`, so the training logic
is the same as in `Skorch <https://skorch.readthedocs.io/en/stable/>`__.

.. note::
   We use different hyperparameters from [1]_, as these hyperparameters were
   optimized on a different dataset (MASS SS3) and with a different number of
   recordings. Generally speaking, it is recommended to perform
   hyperparameter optimization if reusing this code on a different dataset or
   with more recordings.

.. GENERATED FROM PYTHON SOURCE LINES 260-299

.. code-block:: default


    from skorch.helper import predefined_split
    from skorch.callbacks import EpochScoring
    from braindecode import EEGClassifier

    lr = 1e-3
    batch_size = 32
    n_epochs = 10

    train_bal_acc = EpochScoring(
        scoring='balanced_accuracy', on_train=True, name='train_bal_acc',
        lower_is_better=False)
    valid_bal_acc = EpochScoring(
        scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',
        lower_is_better=False)
    callbacks = [
        ('train_bal_acc', train_bal_acc),
        ('valid_bal_acc', valid_bal_acc)
    ]

    clf = EEGClassifier(
        model,
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=torch.Tensor(class_weights).to(device),
        optimizer=torch.optim.Adam,
        iterator_train__shuffle=False,
        iterator_train__sampler=train_sampler,
        iterator_valid__sampler=valid_sampler,
        train_split=predefined_split(valid_set),  # using valid_set for validation
        optimizer__lr=lr,
        batch_size=batch_size,
        callbacks=callbacks,
        device=device
    )
    # Model training for a specified number of epochs. `y` is None as it is
    # already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

      epoch    train_bal_acc    train_loss    valid_bal_acc    valid_loss     dur
    -------  ---------------  ------------  ---------------  ------------  ------
          1           0.1199        1.6479           0.1962        1.6122  3.6697
          2           0.1818        1.6060           0.2113        1.6092  3.2308
          3           0.2338        1.5971           0.2000        1.6173  3.2405
          4           0.2100        1.5664           0.2000        1.6141  3.3304
          5           0.2300        1.5202           0.2000        1.6408  3.3049
          6           0.2200        1.4405           0.2000        1.6752  3.2532
          7           0.2563        1.3843           0.2000        1.7847  3.2445
          8           0.2620        1.3200           0.2056        1.8087  3.3092
          9           0.2818        1.2864           0.2113        1.8648  3.2535
         10           0.3058        1.2526           0.2258        1.8307  3.2599

    [initialized](
      module_=Sequential(
        (0): TimeDistributed(
          (module): SleepStagerChambon2018(
            (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
            (feature_extractor): Sequential(
              (0): Conv2d(1, 8, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
              (1): Identity()
              (2): ReLU()
              (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
              (4): Conv2d(8, 8, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
              (5): Identity()
              (6): ReLU()
              (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
            )
          )
        )
        (1): Sequential(
          (0): Flatten(start_dim=1, end_dim=-1)
          (1): Dropout(p=0.5, inplace=False)
          (2): Linear(in_features=816, out_features=5, bias=True)
        )
      ),
    )


.. GENERATED FROM PYTHON SOURCE LINES 300-306

Plot results
------------

We use the history stored by Skorch during training to plot the performance
of the model throughout training. Specifically, we plot the loss and the
balanced accuracy for the training and validation sets.

.. GENERATED FROM PYTHON SOURCE LINES 306-323

.. code-block:: default


    import matplotlib.pyplot as plt
    import pandas as pd

    # Extract loss and balanced accuracy values for plotting from history object
    df = pd.DataFrame(clf.history.to_list())
    df.index.name = "Epoch"
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 7), sharex=True)
    df[['train_loss', 'valid_loss']].plot(color=['r', 'b'], ax=ax1)
    df[['train_bal_acc', 'valid_bal_acc']].plot(color=['r', 'b'], ax=ax2)
    ax1.set_ylabel('Loss')
    ax2.set_ylabel('Balanced accuracy')
    ax1.legend(['Train', 'Valid'])
    ax2.legend(['Train', 'Valid'])
    fig.tight_layout()
    plt.show()


.. image-sg:: /auto_examples/images/sphx_glr_plot_sleep_staging_chambon2018_001.png
   :alt: plot sleep staging chambon2018
   :srcset: /auto_examples/images/sphx_glr_plot_sleep_staging_chambon2018_001.png
   :class: sphx-glr-single-img
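Since the Skorch history behaves like a list of per-epoch records, it can also
be queried directly. For instance, the following sketch retrieves the best
validation balanced accuracy reached during training:

.. code-block:: default


    # Query the history directly (sketch): best validation balanced accuracy
    # across all training epochs.
    best_valid_bal_acc = max(clf.history[:, 'valid_bal_acc'])
    print(f'Best validation balanced accuracy: {best_valid_bal_acc:.3f}')
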
.. GENERATED FROM PYTHON SOURCE LINES 324-326

Finally, we also display the confusion matrix and classification report:

.. GENERATED FROM PYTHON SOURCE LINES 326-336

.. code-block:: default


    from sklearn.metrics import confusion_matrix, classification_report

    y_true = [valid_set[[i]][1][0] for i in range(len(valid_sampler))]
    y_pred = clf.predict(valid_set)

    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [[  0   0   3  58   0]
     [  0   0   9  12   0]
     [  1   0  23 170   0]
     [  0   0   7  68   0]
     [  0   0   6  26   0]]
    /usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
      _warn_prf(average, modifier, msg_start, len(result))
    /usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
      _warn_prf(average, modifier, msg_start, len(result))
    /usr/share/miniconda/envs/braindecode/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
      _warn_prf(average, modifier, msg_start, len(result))
                  precision    recall  f1-score   support

               0       0.00      0.00      0.00        61
               1       0.00      0.00      0.00        21
               2       0.48      0.12      0.19       194
               3       0.20      0.91      0.33        75
               4       0.00      0.00      0.00        32

        accuracy                           0.24       383
       macro avg       0.14      0.21      0.10       383
    weighted avg       0.28      0.24      0.16       383


.. GENERATED FROM PYTHON SOURCE LINES 337-347

Our model was able to learn despite the low amount of data available (only
two recordings in this example), reaching a balanced accuracy slightly above
the 20% chance level (about 23% on the held-out recording) in this 5-class
classification task.

.. note::
   To further improve performance, more recordings should be included in the
   training set, and hyperparameters should be selected accordingly.
   Increasing the sequence length was also shown in [1]_ to help improve
   performance, especially when few EEG channels are available.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  43.225 seconds)

**Estimated memory usage:**  410 MB


.. _sphx_glr_download_auto_examples_plot_sleep_staging_chambon2018.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_sleep_staging_chambon2018.py <plot_sleep_staging_chambon2018.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_sleep_staging_chambon2018.ipynb <plot_sleep_staging_chambon2018.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_