braindecode.models.DeepSleepNet
- class braindecode.models.DeepSleepNet(n_outputs=5, return_feats=False, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, activation_large=torch.nn.ELU, activation_small=torch.nn.ReLU, drop_prob=0.5)[source]
 DeepSleepNet from Supratak et al. (2017) [Supratak2017].
Architectural Overview
DeepSleepNet couples dual-path convolutional neural network (CNN) representation learning with sequence residual learning via bidirectional LSTMs.
The network:
(i) learns complementary, time-frequency features from each 30-s epoch using two parallel CNNs (small vs. large first-layer filters), then
(ii) models temporal dependencies across epochs using two-layer BiLSTMs with a residual shortcut from the CNN features, and finally
(iii) outputs per-epoch sleep stages. This design encodes both epoch-local patterns and longer-range transition rules used by human scorers.
In terms of implementation:
(i) _RepresentationLearning: two CNNs extract epoch-wise features (small-filter path for temporal precision; large-filter path for frequency precision);
(ii) _SequenceResidualLearning: stacked BiLSTMs with peepholes plus a residual shortcut inject temporal context while preserving CNN evidence;
(iii) _Classifier: linear readout (softmax) for the five sleep stages.
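For orientation, a minimal forward-pass sketch (the sampling rate, channel count, and batch size below are assumptions for illustration; one 30-s epoch at 100 Hz gives 3000 samples):

```python
import torch
from braindecode.models import DeepSleepNet

# Assumed setup: single-channel EEG at 100 Hz, 30-s epochs (3000 samples)
model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100, n_times=3000)

x = torch.randn(8, 1, 3000)  # (batch_size, n_channels, n_times)
logits = model(x)            # per-epoch class scores, one row per window
print(logits.shape)          # expected: torch.Size([8, 5])
```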
Macro Components
_RepresentationLearning (dual-path CNN → epoch feature)
Operations.
- Small-filter CNN: a stack of four convolutions with MaxPool1d after; the first conv uses filter length ≈ Fs/2 and stride ≈ Fs/16 to emphasize the timing of graphoelements.
- Large-filter CNN: the same stack, but the first conv uses filter length ≈ 4·Fs and stride ≈ Fs/2 to emphasize frequency content.
- Outputs from both paths are concatenated into the epoch embedding a_t.
Rationale. Two first-layer scales provide a learned, dual-scale filter bank that trades temporal vs. frequency precision without hand-crafted features.
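To make the two first-layer scales concrete, here is an illustrative PyTorch sketch (filter counts and the 100 Hz sampling rate are hypothetical; this is not the braindecode implementation):

```python
import torch
from torch import nn

fs = 100  # assumed sampling rate in Hz

# Small-filter path: kernel ≈ Fs/2 = 50, stride ≈ Fs/16 = 6 (temporal precision)
conv_small = nn.Conv1d(1, 64, kernel_size=fs // 2, stride=fs // 16, bias=False)
# Large-filter path: kernel ≈ 4·Fs = 400, stride ≈ Fs/2 = 50 (frequency precision)
conv_large = nn.Conv1d(1, 64, kernel_size=4 * fs, stride=fs // 2, bias=False)

x = torch.randn(8, 1, 30 * fs)  # a batch of 30-s single-channel epochs
f_small = conv_small(x)         # finer time resolution, more time steps
f_large = conv_large(x)         # coarser time resolution, longer filters
# Deeper conv/pool blocks process each path; the flattened outputs are
# concatenated into the epoch embedding a_t.
```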
_SequenceResidualLearning (BiLSTM context + residual fusion)
Operations.
- A two-layer BiLSTM with peephole connections processes the sequence of epoch embeddings {a_t} forward and backward; hidden states from both directions are concatenated.
- A shortcut MLP (fully connected + BatchNorm1d + ReLU) projects a_t to the BiLSTM output dimension and is added (residually) to the BiLSTM output at each time step.
Role. Encodes stage-transition rules and smooths predictions over time while preserving salient CNN features via the residual path.
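A schematic of the residual fusion under assumed dimensions (note that torch.nn.LSTM does not implement peephole connections, so this sketch omits them):

```python
import torch
from torch import nn

D, H, N = 1024, 512, 25  # assumed embedding dim, hidden size, sequence length

bilstm = nn.LSTM(D, H, num_layers=2, bidirectional=True, batch_first=True)
shortcut = nn.Sequential(nn.Linear(D, 2 * H), nn.BatchNorm1d(2 * H), nn.ReLU())

a = torch.randn(1, N, D)    # sequence of epoch embeddings {a_t}
h, _ = bilstm(a)            # (1, N, 2H): both directions concatenated
s = shortcut(a.squeeze(0))  # (N, 2H): BatchNorm1d expects (N, features)
out = h + s.unsqueeze(0)    # element-wise residual addition, (1, N, 2H)
```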
_Classifier (epoch-wise prediction)
Operations. A Linear layer produces per-epoch class probabilities.
The original training uses a two-step optimization: CNN pretraining on class-balanced data, then end-to-end fine-tuning with sequential batches.
Convolutional Details
Temporal (where time-domain patterns are learned).
Both CNN paths use 1-D temporal convolutions. The small-filter path (first kernel ≈ Fs/2, stride ≈ Fs/16) captures when characteristic transients occur; the large-filter path (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures which frequency components dominate over the epoch. Deeper layers use small kernels to refine features with fewer parameters, interleaved with max pooling for downsampling.
Spatial (how channels are processed).
The original model operates on single-channel raw EEG; convolutions therefore mix only along time (no spatial convolution across electrodes).
Spectral (how frequency information emerges).
No explicit Fourier/wavelet transform is used. The large-filter path serves as a frequency-sensitive analyzer, while the small-filter path remains time-sensitive, together functioning as a two-band learned filter bank at the first layer.
Attention / Sequential Modules
Type. Bidirectional LSTM (two layers) with peephole connections; forward and backward streams are independent and concatenated.
Shapes. For a sequence of N epochs, the CNN produces {a_t} ∈ R^D; the BiLSTM outputs h_t ∈ R^{2H}; the shortcut MLP maps a_t → R^{2H} to enable element-wise residual addition.
Role. Models long-range temporal dependencies (e.g., persisting in stage N2 without a visible K-complex or spindles), stabilizing per-epoch predictions.
Additional Mechanisms
Residual shortcut over sequence encoder. Adds projected CNN features to BiLSTM outputs, improving gradient flow and retaining discriminative content from representation learning.
Two-step training.
(i) Pretrain the CNN paths with class-balanced sampling;
(ii) fine-tune the full network with sequential batches, using a lower learning rate for the CNNs and a higher learning rate for the sequence encoder, as sketched below.
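A sketch of the fine-tuning optimizer (the name-based parameter split and the learning rates are illustrative assumptions, not braindecode's or the paper's exact settings):

```python
import torch
from braindecode.models import DeepSleepNet

model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100, n_times=3000)

# Hypothetical split: gentle updates for the (pretrained) CNN paths,
# larger updates for the sequence encoder and classifier.
cnn_params = [p for n, p in model.named_parameters() if "cnn" in n]
other_params = [p for n, p in model.named_parameters() if "cnn" not in n]

optimizer = torch.optim.Adam([
    {"params": cnn_params, "lr": 1e-6},
    {"params": other_params, "lr": 1e-4},
])
```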
State handling. BiLSTM states are reinitialized per subject so that temporal context does not leak across recordings.
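In a generic PyTorch training loop this corresponds to resetting the recurrent state at subject boundaries while carrying (and detaching) it across sequential batches within a recording; a self-contained sketch with dummy data (the loop structure and sizes are assumptions, not braindecode API):

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=2,
               bidirectional=True, batch_first=True)

# Dummy stand-in: two subjects, each with three sequential batches of 25 epochs
all_subjects = [[torch.randn(1, 25, 64) for _ in range(3)] for _ in range(2)]

for subject_batches in all_subjects:
    state = None                                  # reinitialize per subject
    for batch in subject_batches:
        out, state = lstm(batch, state)
        state = tuple(s.detach() for s in state)  # truncate BPTT across batches
```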
Usage and Configuration
Epoch pipeline. Use two parallel CNNs with the first conv sized to Fs/2 (small path) and 4·Fs (large path), with strides Fs/16 and Fs/2, respectively; stack three more conv blocks with small kernels, plus max pooling in each path. Concatenate path outputs to form epoch embeddings.
Sequence encoder. Apply a two-layer BiLSTM (with peepholes) over the sequence of embeddings; add a projection MLP on the CNN features and sum it with the BiLSTM outputs (residual). Finish with a Linear layer per epoch.
Reference implementation. See the official repository for a faithful implementation and training scripts.
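When the learned representation itself is of interest (e.g., for transfer learning), return_feats=True returns the features before the final linear layer; a short sketch with assumed signal parameters:

```python
import torch
from braindecode.models import DeepSleepNet

model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=100, n_times=3000,
                     return_feats=True)
x = torch.randn(8, 1, 3000)
feats = model(x)  # output of the feature extractor, before the classifier
```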
- Parameters:
 n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
return_feats (bool) – If True, return the features, i.e. the output of the feature extractor (before the final linear layer). If False, pass the features through the final linear layer.
n_chans (int) – Number of EEG channels.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
sfreq (float) – Sampling frequency of the EEG recordings.
activation_large (nn.Module, default=nn.ELU) – Activation function class to apply in the large-filter path. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.
activation_small (nn.Module, default=nn.ReLU) – Activation function class to apply in the small-filter path. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ReLU.
drop_prob (float, default=0.5) – The dropout rate for regularization. Values should be between 0 and 1.
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
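For example, n_times can presumably be left unset when input_window_seconds and sfreq are given (assuming the usual relation n_times = input_window_seconds · sfreq):

```python
from braindecode.models import DeepSleepNet

# n_times should be inferred as 30.0 * 100 = 3000 samples
model = DeepSleepNet(n_outputs=5, n_chans=1,
                     input_window_seconds=30.0, sfreq=100)
```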
Hugging Face Hub integration
When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

```
pip install braindecode[hug]
```
Pushing a model to the Hub:
```python
from braindecode.models import EEGNetv4

# Train your model
model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-eegnet-model",
    commit_message="Initial model upload",
)
```
Loading a model from the Hub:
```python
from braindecode.models import EEGNetv4

# Load pretrained model
model = EEGNetv4.from_pretrained("username/my-eegnet-model")
```
The integration automatically handles EEG-specific parameters (n_chans, n_times, sfreq, chs_info, etc.) by saving them in a config file alongside the model weights. This ensures that loaded models are correctly configured for their original data specifications.
Important
Currently, only EEG-specific parameters (n_outputs, n_chans, n_times, input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific parameters (e.g., dropout rates, activation functions, number of filters) are not preserved and will use their default values when loading from the Hub.
To use non-default model parameters, specify them explicitly when calling from_pretrained():

```python
model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')
```
Full parameter serialization will be addressed in a future update.
References
[Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017). DeepSleepNet: A model for automatic sleep stage scoring based on raw single-channel EEG. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 25(11), 1998-2008.
Methods
- forward(x)[source]
 Forward pass.
- Parameters:
 x (torch.Tensor) – Batch of EEG windows of shape (batch_size, n_channels, n_times).