Sleep staging on the Sleep Physionet dataset

This tutorial shows how to train and test a sleep staging neural network with Braindecode. We follow the approach of [1] on the openly accessible Sleep Physionet dataset [2] [3].

References

[1] Chambon, S., Galtier, M., Arnal, P., Wainrib, G., and Gramfort, A. (2018). A Deep Learning Architecture for Temporal Sleep Stage Classification Using Multivariate and Multimodal Time Series. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 26:758-769.

[2] Kemp, B., Zwinderman, A.H., Tuk, B., Kamphuisen, H.A.C., and Oberyé, J.J.L. (2000). Analysis of a Sleep-Dependent Neuronal Feedback Loop: The Slow-Wave Microcontinuity of the EEG. IEEE Transactions on Biomedical Engineering, 47(9):1185-1194.

[3] Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.C., Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K., and Stanley, H.E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation, 101(23):e215-e220.

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#
# License: BSD (3-clause)

Loading and preprocessing the dataset

Loading

First, we load the data using the braindecode.datasets.sleep_physionet.SleepPhysionet class. We load two recordings from two different individuals: we will use the first one to train our network and the second one to evaluate performance (as in the MNE sleep staging example).

Note

To load your own datasets either via MNE or from preprocessed X/y numpy arrays, see the MNE Dataset Tutorial and the Numpy Dataset Tutorial.

from braindecode.datasets.sleep_physionet import SleepPhysionet

dataset = SleepPhysionet(
    subject_ids=[0, 1], recording_ids=[1], crop_wake_mins=30)

Out:

Using default location ~/mne_data for PHYSIONET_SLEEP...
Downloading https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette//SC4001E0-PSG.edf (46.1 MB)

Verifying hash adabd3b01fc7bb75c523a974f38ee3ae4e57b40f.
Downloading https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette//SC4001EC-Hypnogram.edf (5 kB)

Verifying hash 21c998eadc8b1e3ea6727d3585186b8f76e7e70b.
Downloading https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette//SC4011E0-PSG.edf (48.7 MB)

Verifying hash 4d17451f7847355bcab17584de05e7e1df58c660.
Downloading https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette//SC4011EH-Hypnogram.edf (4 kB)

Verifying hash d582a3cbe2db481a362af890bc5a2f5ca7c878dc.
Extracting EDF parameters from /home/circleci/mne_data/physionet-sleep-data/SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /home/circleci/mne_data/physionet-sleep-data/SC4011E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...

Preprocessing

Next, we preprocess the raw data: we convert the signals to microvolts and apply a low-pass filter. We omit the downsampling step of [1], as the Sleep Physionet data is already sampled at a relatively low 100 Hz.

from braindecode.datautil.preprocess import (
    MNEPreproc, NumpyPreproc, preprocess)

high_cut_hz = 30

preprocessors = [
    # convert from volt to microvolt, directly modifying the numpy array
    NumpyPreproc(fn=lambda x: x * 1e6),
    # bandpass filter
    MNEPreproc(fn='filter', l_freq=None, h_freq=high_cut_hz),
]

# Transform the data
preprocess(dataset, preprocessors)

Out:

Reading 0 ... 2508000  =      0.000 ... 25080.000 secs...
Filtering raw data in 1 contiguous segment
Setting up low-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal lowpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 45 samples (0.450 sec)

Reading 0 ... 3261000  =      0.000 ... 32610.000 secs...
Filtering raw data in 1 contiguous segment
Setting up low-pass filter at 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal lowpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 45 samples (0.450 sec)

Extract windows

We extract 30-s windows to be used in the classification task.

from braindecode.datautil.windowers import create_windows_from_events


mapping = {  # We merge stages 3 and 4 following AASM standards.
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4
}

window_size_s = 30
sfreq = 100
window_size_samples = window_size_s * sfreq

windows_dataset = create_windows_from_events(
    dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0,
    window_size_samples=window_size_samples,
    window_stride_samples=window_size_samples, preload=True, mapping=mapping)

Out:

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
837 matching events found
No baseline correction applied
0 projection items activated
Loading data for 837 events and 3000 original time points ...
0 bad epochs dropped
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1088 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1088 events and 3000 original time points ...
0 bad epochs dropped
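As a quick illustration of the label mapping defined above (using a made-up sequence of annotation strings for illustration only, not data from the recordings), 'Sleep stage 3' and 'Sleep stage 4' end up with the same class label:

```python
# Same label mapping as above; stages 3 and 4 share class 3 (merged per AASM)
mapping = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4,
}

# Hypothetical annotation sequence, for illustration only
annotations = ['Sleep stage W', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R']
labels = [mapping[a] for a in annotations]
print(labels)  # [0, 3, 3, 4]
```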

Window preprocessing

We also preprocess the windows by applying channel-wise z-score normalization in each window.

from braindecode.datautil.preprocess import zscore

preprocess(windows_dataset, [MNEPreproc(fn=zscore)])
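The `zscore` helper standardizes each channel of each window independently. A minimal NumPy sketch of the same arithmetic, using a random array standing in for one (n_channels, n_times) window (the actual helper operates on the underlying Epochs data):

```python
import numpy as np

# Stand-in for one 30-s window at 100 Hz with 2 channels: shape (2, 3000)
rng = np.random.default_rng(0)
window = rng.normal(loc=5.0, scale=3.0, size=(2, 3000))

# Channel-wise z-score: subtract each channel's mean, divide by its std
normalized = (window - window.mean(axis=-1, keepdims=True)) \
    / window.std(axis=-1, keepdims=True)

# After normalization, each channel has (approximately) zero mean, unit std
print(normalized.mean(axis=-1))  # close to [0, 0]
print(normalized.std(axis=-1))   # close to [1, 1]
```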

Split dataset into train and valid

We can easily split the dataset using additional info stored in the description attribute of braindecode.datasets.BaseDataset, in this case using the subject column. Here, we split the examples per subject.

splitted = windows_dataset.split('subject')
train_set = splitted['0']
valid_set = splitted['1']

# Print number of examples per class
print(train_set.datasets[0].windows)
print(valid_set.datasets[0].windows)

Out:

<Epochs |  837 events (all good), 0 - 29.99 sec, baseline off, ~38.3 MB, data loaded, with metadata,
 'Sleep stage 1': 58
 'Sleep stage 2': 250
 'Sleep stage 3': 220
 'Sleep stage 4': 220
 'Sleep stage R': 125
 'Sleep stage W': 184>
<Epochs |  1088 events (all good), 0 - 29.99 sec, baseline off, ~49.8 MB, data loaded, with metadata,
 'Sleep stage 1': 109
 'Sleep stage 2': 562
 'Sleep stage 3': 105
 'Sleep stage 4': 105
 'Sleep stage R': 170
 'Sleep stage W': 142>

Create model

We can now create the deep learning model. In this tutorial, we use the sleep staging architecture introduced in [1], a four-layer convolutional neural network.

import torch
from braindecode.util import set_random_seeds
from braindecode.models import SleepStagerChambon2018

cuda = torch.cuda.is_available()  # check if GPU is available
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to reproduce results
set_random_seeds(seed=87, cuda=cuda)

n_classes = 5
# Extract number of channels and time steps from dataset
n_channels = train_set[0][0].shape[0]
input_size_samples = train_set[0][0].shape[1]

model = SleepStagerChambon2018(
    n_channels,
    sfreq,
    n_classes=n_classes,
    input_size_s=input_size_samples / sfreq
)

# Send model to GPU
if cuda:
    model.cuda()

Training

We can now train our network. braindecode.EEGClassifier is a Braindecode object responsible for managing the training of neural networks. It inherits from skorch.NeuralNetClassifier, so the training logic is the same as in Skorch.

Note: We use different hyperparameters from [1], as those were optimized on a different dataset (MASS SS3) and with a different number of recordings. Generally speaking, hyperparameter optimization is recommended when reusing this code on a different dataset or with more recordings.

from skorch.helper import predefined_split
from skorch.callbacks import EpochScoring
from braindecode import EEGClassifier

lr = 5e-4
batch_size = 16
n_epochs = 5

train_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=True, name='train_bal_acc',
    lower_is_better=False)
valid_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',
    lower_is_better=False)
callbacks = [('train_bal_acc', train_bal_acc),
             ('valid_bal_acc', valid_bal_acc)]

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device
)
# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)

Out:

  epoch    train_bal_acc    train_loss    valid_bal_acc    valid_loss     dur
-------  ---------------  ------------  ---------------  ------------  ------
      1           0.2327        1.4823           0.3783        1.3577  4.0851
      2           0.4687        1.0553           0.5187        1.0603  2.8810
      3           0.5600        0.8073           0.5122        0.9819  6.9950
      4           0.5732        0.7450           0.5564        0.9850  2.1156
      5           0.5969        0.6878           0.5173        0.9374  1.8152

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=SleepStagerChambon2018(
    (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
    (feature_extractor): Sequential(
      (0): Conv2d(1, 8, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=(1, 12), stride=(1, 12), padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(8, 8, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=(1, 12), stride=(1, 12), padding=0, dilation=1, ceil_mode=False)
    )
    (fc): Sequential(
      (0): Dropout(p=0.25, inplace=False)
      (1): Linear(in_features=320, out_features=5, bias=True)
    )
  ),
)

Plot results

We use the history stored by Skorch during training to plot the performance of the model throughout training. Specifically, we plot the loss and the balanced misclassification rate (1 - balanced accuracy) for the training and validation sets.

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd

# Extract loss and balanced accuracy values for plotting from history object
df = pd.DataFrame(clf.history.to_list())
df[['train_mis_clf', 'valid_mis_clf']] = 100 - df[
    ['train_bal_acc', 'valid_bal_acc']] * 100

# get percent of misclass for better visual comparison to loss
plt.style.use('seaborn-talk')
fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(
    ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False,
    fontsize=14)

ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

df.loc[:, ['train_mis_clf', 'valid_mis_clf']].plot(
    ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel('Balanced misclassification rate [%]', color='tab:red',
               fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
ax1.set_xlabel('Epoch', fontsize=14)

# where some data has already been plotted to ax
handles = []
handles.append(
    Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
handles.append(
    Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
(Figure: training and validation loss, and balanced misclassification rate, across training epochs)

Finally, we also display the confusion matrix and classification report:

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

y_true = valid_set.datasets[0].windows.metadata['target'].values
y_pred = clf.predict(valid_set)

print(confusion_matrix(y_true, y_pred))

print(classification_report(y_true, y_pred))

Out:

[[128   0  12   1   1]
 [ 77   0  31   1   0]
 [ 70   0 480  11   1]
 [  0   0  19  86   0]
 [ 63   0 105   0   2]]
/home/circleci/.local/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

           0       0.38      0.90      0.53       142
           1       0.00      0.00      0.00       109
           2       0.74      0.85      0.79       562
           3       0.87      0.82      0.84       105
           4       0.50      0.01      0.02       170

    accuracy                           0.64      1088
   macro avg       0.50      0.52      0.44      1088
weighted avg       0.59      0.64      0.56      1088
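As a sanity check, the balanced accuracy reported during training can be recomputed directly from the confusion matrix above (the numbers below are copied from that output):

```python
import numpy as np

# Confusion matrix printed above: rows are true classes, columns predictions
cm = np.array([
    [128,   0,  12,   1,   1],
    [ 77,   0,  31,   1,   0],
    [ 70,   0, 480,  11,   1],
    [  0,   0,  19,  86,   0],
    [ 63,   0, 105,   0,   2],
])

# Balanced accuracy is the mean of per-class recalls
# (diagonal counts divided by row sums)
recall_per_class = cm.diagonal() / cm.sum(axis=1)
bal_acc = recall_per_class.mean()
print(round(bal_acc, 4))  # 0.5173, matching the final valid_bal_acc above
```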

Our model was able to perform reasonably well given the low amount of data available, reaching a balanced accuracy of around 55% in a 5-class classification task (chance-level = 20%) on held-out data.

To further improve performance, more recordings can be included in the training set, and various modifications can be made to the model (e.g., aggregating the representation of multiple consecutive windows [1]).

Total running time of the script: ( 0 minutes 26.820 seconds)

Estimated memory usage: 146 MB
