braindecode.models.DeepSleepNet#
- class braindecode.models.DeepSleepNet(n_outputs=5, return_feats=False, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, activation_large=<class 'torch.nn.modules.activation.ELU'>, activation_small=<class 'torch.nn.modules.activation.ReLU'>, drop_prob=0.5)[source]#
DeepSleepNet from Supratak et al. (2017) [Supratak2017].
Convolutional-recurrent architecture.
Architectural Overview
DeepSleepNet couples dual-path convolutional neural network representation learning with sequence residual learning via bidirectional LSTMs.
The network:
(i) learns complementary time-frequency features from each 30-s epoch using two parallel CNNs (small vs. large first-layer filters), then
(ii) models temporal dependencies across epochs using two-layer BiLSTMs with a residual shortcut from the CNN features, and finally
(iii) outputs per-epoch sleep stages. This design encodes both epoch-local patterns and longer-range transition rules used by human scorers.
In terms of implementation:
(i) _RepresentationLearning: two CNNs extract epoch-wise features (small-filter path for temporal precision; large-filter path for frequency precision);
(ii) _SequenceResidualLearning: stacked BiLSTMs with peepholes plus a residual shortcut inject temporal context while preserving CNN evidence;
(iii) _Classifier: linear readout (softmax) for the five sleep stages.
Macro Components
_RepresentationLearning (dual-path CNN → epoch feature)
Operations.
- Small-filter CNN: four convolutional blocks interleaved with MaxPool1d downsampling; the first conv uses filter length ≈ Fs/2 and stride ≈ Fs/16 to emphasize the timing of graphoelements.
- Large-filter CNN: the same stack, but the first conv uses filter length ≈ 4·Fs and stride ≈ Fs/2 to emphasize frequency content.

Outputs from both paths are concatenated into the epoch embedding a_t.
Rationale. Two first-layer scales provide a learned, dual-scale filter bank that trades temporal vs. frequency precision without hand-crafted features.
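A minimal PyTorch sketch of the dual-path idea. Layer counts, filter numbers, and pooling sizes here are illustrative assumptions, not braindecode's exact configuration:

```python
import torch
from torch import nn

def make_path(first_kernel: int, first_stride: int) -> nn.Sequential:
    """One CNN path: a scale-setting first conv, then small-kernel refinement."""
    return nn.Sequential(
        nn.Conv1d(1, 64, kernel_size=first_kernel, stride=first_stride, bias=False),
        nn.BatchNorm1d(64),
        nn.ReLU(),
        nn.MaxPool1d(8),
        nn.Conv1d(64, 128, kernel_size=8, padding=4),
        nn.ReLU(),
        nn.MaxPool1d(4),
        nn.AdaptiveAvgPool1d(1),  # collapse time so the sketch yields a fixed-size vector
        nn.Flatten(),
    )

fs = 100  # Hz, assumed sampling rate
small = make_path(first_kernel=fs // 2, first_stride=fs // 16)  # timing-sensitive
large = make_path(first_kernel=4 * fs, first_stride=fs // 2)    # frequency-sensitive

x = torch.randn(8, 1, 30 * fs)                # a batch of 30-s single-channel epochs
a_t = torch.cat([small(x), large(x)], dim=1)  # concatenated epoch embedding a_t
print(a_t.shape)                              # torch.Size([8, 256])
```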
_SequenceResidualLearning (BiLSTM context + residual fusion)
Operations.
- A two-layer BiLSTM with peephole connections processes the sequence of epoch embeddings {a_t} forward and backward; hidden states from both directions are concatenated.
- A shortcut MLP (fully connected + BatchNorm1d + ReLU) projects a_t to the BiLSTM output dimension and is added (residually) to the BiLSTM output at each time step.

Role. Encodes stage-transition rules and smooths predictions over time while preserving salient CNN features via the residual path.
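A shape-level sketch of the residual fusion. Note that PyTorch's nn.LSTM has no peephole connections, so this only approximates the paper's cell; D and H are assumed sizes:

```python
import torch
from torch import nn

D, H = 256, 512  # epoch-embedding size and LSTM hidden size (assumed)

bilstm = nn.LSTM(input_size=D, hidden_size=H, num_layers=2,
                 batch_first=True, bidirectional=True)
shortcut = nn.Sequential(      # fully connected + BatchNorm1d + ReLU
    nn.Linear(D, 2 * H),
    nn.BatchNorm1d(2 * H),
    nn.ReLU(),
)

a = torch.randn(4, 25, D)      # (batch, sequence of 25 epochs, features)
h, _ = bilstm(a)               # (4, 25, 2H): both directions concatenated
res = shortcut(a.reshape(-1, D)).reshape(4, 25, 2 * H)
out = h + res                  # residual fusion at each time step
```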
_Classifier (epoch-wise prediction)
Operations. A final Linear layer produces per-epoch class probabilities.
Original training uses two-step optimization: CNN pretraining on class-balanced data, then end-to-end fine-tuning with sequential batches.
Convolutional Details
Temporal (where time-domain patterns are learned).
Both CNN paths use 1-D temporal convolutions. The small-filter path (first kernel ≈ Fs/2, stride ≈ Fs/16) captures when characteristic transients occur; the large-filter path (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures which frequency components dominate over the epoch. Deeper layers use small kernels to refine features with fewer parameters, interleaved with max pooling for downsampling.
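As a worked example of the sizing rule, assuming Fs = 100 Hz (as in Sleep-EDF):

```python
Fs = 100
small_kernel, small_stride = Fs // 2, Fs // 16  # 50 samples (0.5 s), stride 6
large_kernel, large_stride = 4 * Fs, Fs // 2    # 400 samples (4.0 s), stride 50
# A 0.5-s kernel resolves *when* a transient occurs; a 4-s kernel spans
# several cycles of slow activity, resolving *which* frequencies dominate.
```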
Spatial (how channels are processed).
The original model operates on single-channel raw EEG; convolutions therefore mix only along time (no spatial convolution across electrodes).
Spectral (how frequency information emerges).
No explicit Fourier/wavelet transform is used. The large-filter path serves as a frequency-sensitive analyzer, while the small-filter path remains time-sensitive, together functioning as a two-band learned filter bank at the first layer.
Attention / Sequential Modules
Type. Bidirectional LSTM (two layers) with peephole connections; forward and backward streams are independent and concatenated.
Shapes. For a sequence of N epochs, the CNN produces a_t ∈ R^D; the BiLSTM outputs h_t ∈ R^{2H}; the shortcut MLP maps a_t → R^{2H} to enable element-wise residual addition.
Role. Models long-range temporal dependencies (e.g., persisting in N2 without visible K-complexes/spindles), stabilizing per-epoch predictions.
Additional Mechanisms
- Residual shortcut over the sequence encoder. Adds projected CNN features to the BiLSTM outputs, improving gradient flow and retaining discriminative content from representation learning.
- Two-step training. Pretrain the CNN paths with class-balanced sampling; then fine-tune the full network with sequential batches, using a lower learning rate for the CNNs and a higher one for the sequence encoder (see the sketch after this list).
- State handling. BiLSTM states are reinitialized per subject so that temporal context does not leak across recordings.
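A sketch of the fine-tuning stage using optimizer parameter groups; the stand-in modules and learning-rate values are illustrative:

```python
import torch
from torch import nn

cnn_paths = nn.Conv1d(1, 64, kernel_size=50, stride=6)      # stand-in for the pretrained CNN paths
seq_encoder = nn.LSTM(256, 512, num_layers=2,
                      batch_first=True, bidirectional=True)  # stand-in for BiLSTM + shortcut

# Step 1 (not shown): pretrain cnn_paths alone on class-balanced batches.
# Step 2: fine-tune end-to-end on sequential batches with split learning rates.
optimizer = torch.optim.Adam([
    {"params": cnn_paths.parameters(), "lr": 1e-6},    # pretrained CNNs: small steps
    {"params": seq_encoder.parameters(), "lr": 1e-4},  # sequence encoder: larger steps
])
```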
Usage and Configuration
Epoch pipeline. Use two parallel CNNs with the first conv sized to Fs/2 (small path) and 4·Fs (large path), with strides of Fs/16 and Fs/2, respectively; stack three more conv blocks with small kernels, plus max pooling, in each path. Concatenate the path outputs to form epoch embeddings.
Sequence encoder. Apply a two-layer BiLSTM (with peepholes) over the sequence of embeddings; add a projection MLP on the CNN features and sum it with the BiLSTM outputs (residual). Finish with a Linear readout per epoch.
Reference implementation. See the official repository for a faithful implementation and training scripts.
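A hypothetical end-to-end usage, letting the model infer n_times from sfreq and input_window_seconds (see Notes below):

```python
import torch
from braindecode.models import DeepSleepNet

# 30-s single-channel epochs at 100 Hz; n_times (3000) is inferred.
model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100, input_window_seconds=30)
x = torch.randn(16, 1, 3000)  # (batch_size, n_channels, n_times)
y = model(x)
print(y.shape)                # expected: torch.Size([16, 5])
```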
- Parameters:
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
return_feats (bool) – If True, return the features, i.e. the output of the feature extractor (before the final linear layer). If False, pass the features through the final linear layer.
n_chans (int) – Number of EEG channels.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
sfreq (float) – Sampling frequency of the EEG recordings.
activation_large (nn.Module, default=nn.ELU) – Activation function class for the large-filter path. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.
activation_small (nn.Module, default=nn.ReLU) – Activation function class for the small-filter path. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ReLU.
drop_prob (float, default=0.5) – The dropout rate for regularization. Values should be between 0 and 1.
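An illustrative configuration exercising the documented keyword arguments, e.g. to reuse the network as a feature extractor:

```python
from torch import nn
from braindecode.models import DeepSleepNet

feature_extractor = DeepSleepNet(
    n_outputs=5,
    n_chans=1,
    n_times=3000,             # 30 s at 100 Hz
    sfreq=100,
    activation_large=nn.ELU,  # pass activation classes, not instances
    activation_small=nn.ReLU,
    drop_prob=0.5,
    return_feats=True,        # return pre-classifier features
)
```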
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
References
[Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017). DeepSleepNet: A model for automatic sleep stage scoring based on raw single-channel EEG. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 25(11), 1998-2008.
Methods
- forward(x)[source]#
Forward pass.
- Parameters:
x (torch.Tensor) – Batch of EEG windows of shape (batch_size, n_channels, n_times).