braindecode.models.DeepSleepNet

class braindecode.models.DeepSleepNet(n_outputs=5, return_feats=False, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, activation_large=<class 'torch.nn.modules.activation.ELU'>, activation_small=<class 'torch.nn.modules.activation.ReLU'>, drop_prob=0.5)

DeepSleepNet from Supratak et al. (2017) [Supratak2017].

Model type: convolutional, recurrent.

Figure: DeepSleepNet architecture.

Architectural Overview

DeepSleepNet couples dual-path convolutional neural network (CNN) representation learning with sequence residual learning via bidirectional LSTMs.

The network:

  • (i) learns complementary time-frequency features from each 30-s epoch using two parallel CNNs (small vs. large first-layer filters), then

  • (ii) models temporal dependencies across epochs using two-layer BiLSTMs with a residual shortcut from the CNN features, and finally

  • (iii) outputs per-epoch sleep stages. This design encodes both epoch-local patterns and longer-range transition rules used by human scorers.

In terms of implementation:

  • (i) _RepresentationLearning: two CNNs extract epoch-wise features (small-filter path for temporal precision; large-filter path for frequency precision);

  • (ii) _SequenceResidualLearning: stacked BiLSTMs with peepholes + a residual shortcut inject temporal context while preserving CNN evidence;

  • (iii) _Classifier: linear readout (softmax) for the five sleep stages.

Macro Components

  • _RepresentationLearning (dual-path CNN → epoch feature)

    • Small-filter CNN: first conv uses filter length ≈ Fs/2 and stride ≈ Fs/16 to emphasize the timing of graphoelements.

    • Large-filter CNN: same stack, but the first conv uses filter length ≈ 4·Fs and stride ≈ Fs/2 to emphasize frequency content.

    • Outputs from both paths are concatenated into the epoch embedding a_t.

    • Rationale. Two first-layer scales provide a learned, dual-scale filter bank that trades temporal vs. frequency precision without hand-crafted features (see the sketch after this list).

  • _SequenceResidualLearning (BiLSTM context + residual fusion)

    • Operations.

    • Two-layer BiLSTM with peephole connections processes the sequence of epoch embeddings {a_t} forward and backward; hidden states from both directions are concatenated.

    • A shortcut MLP (fully connected + BatchNorm1d + ReLU) projects a_t to the BiLSTM output dimension and is added (residual) to the BiLSTM output at each time step.

    • Role. Encodes stage-transition rules and smooths predictions over time while preserving salient CNN features via the residual path.

  • _Classifier (epoch-wise prediction)

    • Operations.

    • A linear layer produces the per-epoch class probabilities (softmax over the five stages).
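
The following is a minimal, hypothetical sketch of the dual-path representation learning idea in PyTorch (not braindecode's actual implementation): layer counts, filter counts, and pooling sizes are simplified assumptions, and only the first convolution of each path follows the Fs-derived sizing described above. In the real model each path stacks several additional small-kernel convolution blocks with max pooling and dropout before concatenation.

    import torch
    from torch import nn


    class DualPathEpochEncoder(nn.Module):
        """Illustrative sketch of dual-path epoch feature extraction."""

        def __init__(self, sfreq: int = 100, n_filters: int = 64):
            super().__init__()
            fs = int(sfreq)
            # Small-filter path: kernel ~ Fs/2, stride ~ Fs/16 (temporal precision).
            self.small = nn.Sequential(
                nn.Conv1d(1, n_filters, kernel_size=fs // 2, stride=max(fs // 16, 1)),
                nn.BatchNorm1d(n_filters),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=8, stride=8),
            )
            # Large-filter path: kernel ~ 4*Fs, stride ~ Fs/2 (frequency precision).
            self.large = nn.Sequential(
                nn.Conv1d(1, n_filters, kernel_size=4 * fs, stride=fs // 2),
                nn.BatchNorm1d(n_filters),
                nn.ELU(),
                nn.MaxPool1d(kernel_size=4, stride=4),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, 1, n_times), single-channel 30-s epochs.
            feats_small = self.small(x).flatten(start_dim=1)
            feats_large = self.large(x).flatten(start_dim=1)
            # Concatenate both paths into the epoch embedding a_t.
            return torch.cat([feats_small, feats_large], dim=1)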

Original training uses two-step optimization: CNN pretraining on class-balanced data, then end-to-end fine-tuning with sequential batches.

Convolutional Details

  • Temporal (where time-domain patterns are learned).

Both CNN paths use 1-D temporal convolutions. The small-filter path (first kernel ≈ Fs/2, stride ≈ Fs/16) captures when characteristic transients occur; the large-filter path (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures which frequency components dominate over the epoch. Deeper layers use small kernels to refine features with fewer parameters, interleaved with max pooling for downsampling.

  • Spatial (how channels are processed).

The original model operates on single-channel raw EEG; convolutions therefore mix only along time (no spatial convolution across electrodes).

  • Spectral (how frequency information emerges).

No explicit Fourier/wavelet transform is used. The large-filter path serves as a frequency-sensitive analyzer, while the small-filter path remains time-sensitive, together functioning as a two-band learned filter bank at the first layer.
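
For concreteness, with the commonly used 100 Hz sampling rate the first-layer sizes above work out to the following sample counts (the exact rounding is an implementation detail):

    sfreq = 100                  # Hz, assumed example sampling rate
    small_kernel = sfreq // 2    # 50 samples (~0.5 s): temporal precision
    small_stride = sfreq // 16   # ~6 samples
    large_kernel = 4 * sfreq     # 400 samples (4 s): frequency precision
    large_stride = sfreq // 2    # 50 samples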

Attention / Sequential Modules

  • Type. Bidirectional LSTM (two layers) with peephole connections; forward and backward streams are independent and concatenated.

  • Shapes. For a sequence of N epochs, the CNN produces embeddings a_t ∈ R^D; the BiLSTM outputs h_t ∈ R^{2H}; the shortcut MLP maps a_t into R^{2H} so that element-wise residual addition is possible (see the sketch after this list).

  • Role. Models long-range temporal dependencies (e.g., persisting N2 without visible K-complex/spindles), stabilizing per-epoch predictions.
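
A minimal sketch of the sequence encoder with the residual shortcut is shown below. Note that torch.nn.LSTM has no peephole connections, so a standard LSTM stands in for the peephole variant; the dimensions D and H (and the default values used here) are placeholders.

    import torch
    from torch import nn


    class SequenceResidualSketch(nn.Module):
        """Sketch: BiLSTM context plus residual shortcut from CNN features."""

        def __init__(self, d_in: int = 3072, hidden: int = 512, drop_prob: float = 0.5):
            super().__init__()
            # Two-layer bidirectional LSTM (standard, i.e. without peepholes).
            self.bilstm = nn.LSTM(
                input_size=d_in, hidden_size=hidden, num_layers=2,
                batch_first=True, bidirectional=True, dropout=drop_prob,
            )
            # Shortcut MLP: project a_t from R^D to R^{2H} for the residual addition.
            self.shortcut = nn.Sequential(
                nn.Linear(d_in, 2 * hidden), nn.BatchNorm1d(2 * hidden), nn.ReLU()
            )
            self.dropout = nn.Dropout(drop_prob)

        def forward(self, a: torch.Tensor) -> torch.Tensor:
            # a: (batch, n_epochs, D), the sequence of epoch embeddings {a_t}.
            h, _ = self.bilstm(a)                    # (batch, n_epochs, 2H)
            b, t, d = a.shape
            s = self.shortcut(a.reshape(b * t, d))   # BatchNorm1d expects (N, C)
            s = s.reshape(b, t, -1)                  # (batch, n_epochs, 2H)
            return self.dropout(h + s)               # residual fusion at each time step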

Additional Mechanisms

  • Residual shortcut over sequence encoder. Adds projected CNN features to BiLSTM outputs, improving gradient flow and retaining discriminative content from representation learning.

  • Two-step training.
      1. Pretrain the CNN paths with class-balanced sampling;

      2. Fine-tune the full network with sequential batches, using a lower LR for the CNNs and a higher LR for the sequence encoder (see the sketch after this list).

  • State handling. BiLSTM states are reinitialized per subject so that temporal context does not leak across recordings.
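
The two learning rates of the fine-tuning step can be expressed with optimizer parameter groups. The sub-modules and learning-rate values below are illustrative stand-ins, not braindecode's API.

    import torch
    from torch import nn

    # Stand-ins for the pretrained CNN paths and the sequence encoder.
    cnn_encoder = nn.Sequential(nn.Conv1d(1, 64, kernel_size=50, stride=6), nn.ReLU())
    sequence_encoder = nn.LSTM(64, 512, num_layers=2, batch_first=True, bidirectional=True)

    # Fine-tuning: small LR for the already-pretrained CNNs,
    # larger LR for the freshly initialized sequence encoder.
    optimizer = torch.optim.Adam(
        [
            {"params": cnn_encoder.parameters(), "lr": 1e-6},
            {"params": sequence_encoder.parameters(), "lr": 1e-4},
        ]
    )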

Usage and Configuration

  • Epoch pipeline. Use two parallel CNNs with the first conv sized to Fs/2 (small path) and 4·Fs (large path), with strides Fs/16 and Fs/2, respectively; stack three more conv blocks with small kernels, plus max pooling in each path. Concatenate path outputs to form epoch embeddings.

  • Sequence encoder. Apply two-layer BiLSTM (peepholes) over the sequence of embeddings; add a projection MLP on the CNN features and sum with BiLSTM outputs (residual). Finish with Linear per epoch.

  • Reference implementation. See the official repository for a faithful implementation and training scripts.

Parameters:
  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • return_feats (bool) – If True, return the features, i.e. the output of the feature extractor (before the final linear layer). If False, pass the features through the final linear layer.

  • n_chans (int) – Number of EEG channels.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_times (int) – Number of time samples of the input window.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • activation_large (nn.Module, default=nn.ELU) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.

  • activation_small (nn.Module, default=nn.ReLU) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ReLU.

  • drop_prob (float, default=0.5) – The dropout rate for regularization. Values should be between 0 and 1.
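
A minimal instantiation sketch using the parameters above, assuming the typical single-channel, 30-s, 100 Hz setup (n_chans=1, sfreq=100, n_times=3000); the activation and dropout arguments repeat their defaults for illustration.

    from torch import nn
    from braindecode.models import DeepSleepNet

    model = DeepSleepNet(
        n_outputs=5,               # five sleep stages
        n_chans=1,                 # single-channel EEG, as in the original paper
        sfreq=100.0,               # sampling frequency in Hz
        n_times=3000,              # 30-s epochs at 100 Hz
        activation_large=nn.ELU,   # large-filter path activation (default)
        activation_small=nn.ReLU,  # small-filter path activation (default)
        drop_prob=0.5,
    )

    # With return_feats=True the model would instead return the feature vector
    # produced before the final linear layer.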

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
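
For example, a sketch relying on this inference: if input_window_seconds and sfreq are given, n_times should not need to be passed explicitly, since it can be derived as sfreq × input_window_seconds.

    from braindecode.models import DeepSleepNet

    # n_times is left unspecified and is expected to be inferred as
    # 100 Hz * 30 s = 3000 samples.
    model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100.0, input_window_seconds=30.0)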

References

[Supratak2017]

Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017). DeepSleepNet: A model for automatic sleep stage scoring based on raw single-channel EEG. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 25(11), 1998-2008.

Methods

forward(x)

Forward pass.

Parameters:

x (torch.Tensor) – Batch of EEG windows of shape (batch_size, n_channels, n_times).
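
A short sketch of the expected call, with a dummy batch in the documented shape; with return_feats=False the output is expected to be per-epoch class scores of shape (batch_size, n_outputs).

    import torch
    from braindecode.models import DeepSleepNet

    model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100.0, n_times=3000)
    x = torch.randn(8, 1, 3000)   # (batch_size, n_channels, n_times)
    out = model(x)                # expected shape: (8, 5)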