braindecode.models.EEGPT

class braindecode.models.EEGPT(n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, patch_size=64, patch_stride=32, embed_num=4, embed_dim=512, depth=8, num_heads=8, mlp_ratio=4.0, drop_prob=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, init_std=0.02, qkv_bias=True, patch_module=None, norm_layer=None, layer_norm_eps=1e-06, return_encoder_output=False)

EEGPT: Pretrained Transformer for Universal and Reliable Representation of EEG Signals from Wang et al. (2024) [eegpt].

Categories: Foundation Model, Attention/Transformer

EEGPT Architecture

a) The EEGPT structure splits the input EEG signal into patches \(p_{i,j}\) and masks them (50% of time patches and 80% of channel patches), creating a masked part \(\mathcal{M}\) and an unmasked part \(\bar{\mathcal{M}}\). b) Local spatio-temporal embedding maps patches to tokens. c) Dual self-supervised learning with Spatio-Temporal Representation Alignment and Mask-based Reconstruction.
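The masking step in panel a) can be illustrated with a small NumPy sketch. It is not EEGPT's implementation: the helper make_patch_mask is hypothetical, and it assumes that a patch \(p_{i,j}\) stays visible only when both its channel i and its time index j are kept, so masking a channel or a time step masks every patch in that row or column.

```python
import numpy as np

def make_patch_mask(n_chans, n_patches, time_ratio=0.5, chan_ratio=0.8, seed=0):
    """Boolean (n_chans, n_patches) grid; True marks a masked patch p_{i,j}.

    Assumption: a patch is masked if its channel OR its time index is masked.
    """
    rng = np.random.default_rng(seed)
    masked_t = np.zeros(n_patches, dtype=bool)
    masked_t[rng.choice(n_patches, size=round(time_ratio * n_patches), replace=False)] = True
    masked_c = np.zeros(n_chans, dtype=bool)
    masked_c[rng.choice(n_chans, size=round(chan_ratio * n_chans), replace=False)] = True
    # Broadcast the channel mask along rows and the time mask along columns
    return masked_c[:, None] | masked_t[None, :]

mask = make_patch_mask(n_chans=22, n_patches=16)  # masked part M
visible = ~mask                                    # unmasked part M-bar
print(int(visible.sum()))  # 32
```

With 22 channels and 16 time patches, 18 channels and 8 time steps are masked, leaving 4 × 8 = 32 visible patches (about 9% of the grid) for the encoder.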

EEGPT is a pretrained transformer model designed for universal EEG feature extraction. It addresses challenges like low SNR and inter-subject variability by employing a dual self-supervised learning method that combines Spatio-Temporal Representation Alignment and Mask-based Reconstruction [eegpt].

Model Overview (Layer-by-layer)

  1. Patch embedding (_PatchEmbed or _PatchNormEmbed): split each channel into time patches of patch_size samples and project each patch to embed_dim, yielding tokens with shape (batch, n_patches, n_chans, embed_dim).

  2. Channel embedding (chan_embed): add a learned embedding for each channel to preserve spatial identity before attention.

  3. Transformer encoder blocks (_EEGTransformer.blocks): for each time patch, concatenate the n_chans channel tokens with embed_num learned summary tokens and process the resulting sequence with multi-head self-attention and MLP layers.

  4. Summary extraction: keep only the summary tokens, apply the final normalization layer (norm) if one is set, and reshape back to (batch, n_patches, embed_num, embed_dim).

  5. Task head (final_layer): flatten summary tokens across patches and map to n_outputs; if return_encoder_output=True, return the encoder features instead.
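The five steps above can be traced with a NumPy shape walkthrough. This is a sketch under stated assumptions: random matrices stand in for the learned weights, patches are taken as non-overlapping, the transformer blocks and the optional norm are omitted (so the summary tokens pass through unchanged), and embed_dim is reduced to 32 to keep the arrays small.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_chans, n_times = 2, 22, 1000
patch_size, embed_dim, embed_num, n_outputs = 64, 32, 4, 4

x = rng.standard_normal((batch, n_chans, n_times))

# 1. Patch embedding: cut each channel into non-overlapping patches of
#    patch_size samples (trailing samples dropped) and project to embed_dim.
n_patches = n_times // patch_size
patches = x[:, :, : n_patches * patch_size].reshape(batch, n_chans, n_patches, patch_size)
w_patch = rng.standard_normal((patch_size, embed_dim)) * 0.02
tokens = (patches @ w_patch).transpose(0, 2, 1, 3)  # (batch, n_patches, n_chans, embed_dim)

# 2. Channel embedding: one vector per channel, broadcast over batch and patches.
chan_embed = rng.standard_normal((n_chans, embed_dim)) * 0.02
tokens = tokens + chan_embed[None, None, :, :]

# 3. Per time patch, append embed_num summary tokens to the n_chans channel tokens.
summary = rng.standard_normal((embed_num, embed_dim)) * 0.02
seq = tokens.reshape(batch * n_patches, n_chans, embed_dim)
seq = np.concatenate(
    [seq, np.broadcast_to(summary, (seq.shape[0], embed_num, embed_dim))], axis=1
)  # (batch * n_patches, n_chans + embed_num, embed_dim)
# The transformer blocks would update seq here; they are omitted in this sketch.

# 4. Keep only the summary tokens and restore the patch axis.
feats = seq[:, -embed_num:, :].reshape(batch, n_patches, embed_num, embed_dim)

# 5. Task head: flatten summary tokens across patches, then a linear map.
flat = feats.reshape(batch, -1)  # (batch, n_patches * embed_num * embed_dim)
w_head = rng.standard_normal((flat.shape[1], n_outputs)) * 0.02
logits = flat @ w_head  # (batch, n_outputs)
```

Attention in step 3 runs over the channel and summary tokens of a single time patch, which is why the batch and patch axes are merged before the blocks and split again in step 4.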

Dual Self-Supervised Learning

EEGPT moves beyond simple masked reconstruction by introducing a representation alignment objective. The pretraining loss \(\mathcal{L}\) is the sum of alignment loss \(\mathcal{L}_A\) and reconstruction loss \(\mathcal{L}_R\):

\[\mathcal{L} = \mathcal{L}_A + \mathcal{L}_R\]

  1. Spatio-Temporal Representation Alignment (\(\mathcal{L}_A\)): aligns the predicted features of the masked regions with global features extracted by a Momentum Encoder. This pushes the model to learn semantic, high-level representations rather than only waveform details.

    \[\mathcal{L}_A = \frac{1}{N} \sum_{j=1}^{N} \left\| pred_j - LN(menc_j) \right\|_2^2\]

    where \(pred_j\) is the predictor output, \(menc_j\) is the momentum encoder output, and \(LN(\cdot)\) denotes layer normalization.

  2. Mask-based Reconstruction (\(\mathcal{L}_R\)): a standard masked-autoencoder objective that reconstructs the raw EEG patches, ensuring local temporal fidelity.

    \[\mathcal{L}_R = \frac{1}{|\mathcal{M}|} \sum_{(i,j) \in \mathcal{M}} \left\| rec_{i,j} - LN(p_{i,j}) \right\|_2^2\]

    where \(rec_{i,j}\) is the reconstructed patch and \(p_{i,j}\) is the original patch.
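A minimal NumPy sketch of the two objectives, treating the predictor, momentum encoder, and reconstructor outputs as given arrays. The helper names are hypothetical, LN is implemented as layer normalization without a learnable scale and shift (an assumption), and each term is a mean squared L2 distance to a layer-normalized target.

```python
import numpy as np

def layer_norm(z, eps=1e-6):
    """LN without learnable parameters, normalizing over the last axis."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)

def alignment_loss(pred, menc):
    """L_A: average over the N rows of ||pred_j - LN(menc_j)||_2^2."""
    return float(np.mean(np.sum((pred - layer_norm(menc)) ** 2, axis=-1)))

def reconstruction_loss(rec, patches, masked):
    """L_R: average over masked (i, j) of ||rec_ij - LN(p_ij)||_2^2.

    rec, patches: (n_chans, n_patches, patch_size); masked: bool (n_chans, n_patches).
    """
    err = np.sum((rec - layer_norm(patches)) ** 2, axis=-1)  # (n_chans, n_patches)
    return float(err[masked].mean())

rng = np.random.default_rng(0)
menc = rng.standard_normal((8, 16))        # N = 8 momentum-encoder features
patches = rng.standard_normal((4, 6, 64))  # 4 channels x 6 patches x 64 samples
masked = np.zeros((4, 6), dtype=bool)
masked[:, :3] = True                       # mask the first 3 time patches

# Total loss L = L_A + L_R for random (untrained) outputs
loss = alignment_loss(rng.standard_normal((8, 16)), menc) + reconstruction_loss(
    rng.standard_normal((4, 6, 64)), patches, masked
)
# Outputs that exactly match the normalized targets give zero loss
perfect = alignment_loss(layer_norm(menc), menc) + reconstruction_loss(
    layer_norm(patches), patches, masked
)
```

Because L_R averages only over the masked set, errors on visible patches do not contribute to the reconstruction term.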

Macro Components

  • EEGPT.target_encoder (Universal Encoder)
    • Operations. A hierarchical backbone that consists of Local Spatio-Temporal Embedding followed by a standard Transformer encoder [eegpt].

    • Role. Maps raw spatio-temporal EEG patches into a sequence of latent tokens \(z\).

  • EEGPT.chans_id (Channel Identification)
    • Operations. A buffer containing channel indices mapped from the standard channel names provided in chs_info [eegpt].

    • Role. Provides the spatial identity for each input channel, allowing the model to look up the correct channel embedding vector \(\varsigma_i\).

  • Local Spatio-Temporal Embedding (Input Processing)
    • Operations. The input signal \(X\) is chunked into patches \(p_{i,j}\). Each patch is linearly projected and summed with the embedding of its channel: \(token_{i,j} = \text{Embed}(p_{i,j}) + \varsigma_i\) [eegpt].

    • Role. Converts the 2D EEG grid (Channels \(\times\) Time) into a unified sequence of tokens that preserves both channel identity and temporal order.
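The following NumPy sketch shows how channel names could be mapped to rows of a shared embedding table and added to the projected patches, following \(token_{i,j} = \text{Embed}(p_{i,j}) + \varsigma_i\). The vocabulary CHANNEL_VOCAB and the helpers channel_ids and embed_tokens are hypothetical; EEGPT's actual channel list and name normalization may differ.

```python
import numpy as np

# Hypothetical channel vocabulary; EEGPT's real list of names may differ.
CHANNEL_VOCAB = ["FP1", "FP2", "F3", "FZ", "F4", "C3", "CZ", "C4", "P3", "PZ", "P4", "O1", "O2"]
NAME_TO_ID = {name: idx for idx, name in enumerate(CHANNEL_VOCAB)}

rng = np.random.default_rng(0)
embed_dim = 8
# One vector varsigma_i per vocabulary channel (random stand-ins for learned weights)
chan_table = rng.standard_normal((len(CHANNEL_VOCAB), embed_dim))

def channel_ids(ch_names):
    """Map montage channel names to vocabulary indices (case-insensitive)."""
    return np.array([NAME_TO_ID[name.upper()] for name in ch_names])

def embed_tokens(patch_emb, ch_names):
    """token_{i,j} = Embed(p_{i,j}) + varsigma_i.

    patch_emb: (n_chans, n_patches, embed_dim), already-projected patches.
    """
    ids = channel_ids(ch_names)
    return patch_emb + chan_table[ids][:, None, :]

# Two montages with different channel subsets share one embedding table
tok_a = embed_tokens(np.zeros((3, 5, embed_dim)), ["C3", "Cz", "C4"])
tok_b = embed_tokens(np.zeros((4, 5, embed_dim)), ["Fz", "Cz", "Pz", "O1"])
```

Both montages look up "Cz" in the same table, so its tokens receive the same channel vector regardless of the channel's position in the input array.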

How the information is encoded temporally, spatially, and spectrally

  • Temporal. The model segments continuous EEG signals into small, non-overlapping patches [eegpt]; for example, patch_size=64 samples spans 250 ms at a 256 Hz sampling rate. This patching captures short-term local temporal structure within each patch; the encoder then condenses each patch into summary tokens, and the task head combines these across all patches of the window.

  • Spatial. Unlike convolutional models that may rely on fixed spatial order, EEGPT uses Channel Embeddings \(\varsigma_i\) [eegpt]. Each channel’s data is treated as a distinct sequence of tokens tagged with its spatial identity. This allows the model to flexibly handle different montages and missing channels by simply mapping channel names to their corresponding learnable embeddings.

  • Spectral. Spectral information is learned implicitly through the Mask-based Reconstruction objective (\(\mathcal{L}_R\)) [eegpt]. Because the model must reconstruct raw waveforms (including phase and amplitude) from masked inputs, it learns to encode the frequency-specific patterns needed for reconstruction. The alignment objective (\(\mathcal{L}_A\)) refines this by encouraging these spectral features to align with robust, high-level semantic representations.
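The arithmetic behind the temporal bullet can be checked with a small sketch. The helper functions are hypothetical; n_patches uses the usual sliding-window count, and whether patch_stride is applied depends on the patch embedding module in use, which this sketch does not model.

```python
def n_patches(n_times, patch_size, patch_stride=None):
    """Number of complete patches in n_times samples (sliding-window count).

    patch_stride=None means non-overlapping patches (stride = patch_size).
    """
    stride = patch_size if patch_stride is None else patch_stride
    return (n_times - patch_size) // stride + 1

def patch_duration_s(patch_size, sfreq):
    """Duration of one patch in seconds."""
    return patch_size / sfreq

print(patch_duration_s(64, 256.0))           # 0.25
print(n_patches(1000, 64))                   # 15
print(n_patches(1000, 64, patch_stride=32))  # 30
```

With non-overlapping patches, the last 40 of 1000 samples do not fill a complete patch; a stride of half the patch size roughly doubles the number of patches.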

Pretrained Weights

Weights are available on the Hugging Face Hub.

Important

Pre-trained Weights Available

This model has pre-trained weights available on the Hugging Face Hub.

You can load them using:

from braindecode.models import EEGPT

# Load pre-trained model from Hugging Face Hub
model = EEGPT.from_pretrained("braindecode/eegpt-pretrained")

To push your own trained model to the Hub:

# After training your model
model.push_to_hub(
    repo_id="username/my-eegpt-model", commit_message="Upload trained EEGPT model"
)

Requires installing braindecode[hug] for Hub integration.

Usage

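The model can be initialized for specific downstream tasks (e.g., classification) by specifying n_outputs, chs_info, and n_times.

import torch
from braindecode.models import EEGPT

# chs_info: per-channel metadata for your montage, e.g. raw.info["chs"]
# from an MNE Raw object (assumed to be defined already)
model = EEGPT(
    n_chans=22,
    n_times=1000,
    chs_info=chs_info,
    n_outputs=4,  # For classification tasks
    patch_size=64,
    depth=8,
    embed_dim=512,
)

# Forward pass on a dummy batch of 8 trials
# Input shape: (batch_size, n_chans, n_times)
x = torch.randn(8, 22, 1000)
y = model(x)

Parameters: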
  • return_encoder_output (bool, default=False) – Whether to return the encoder output or the classifier output.

  • patch_size (int, default=64) – Length of each patch, in time samples.

  • patch_stride (int, default=32) – Stride between the starts of consecutive patches, in time samples.

  • embed_num (int, default=4) – Number of learned summary tokens appended to each patch's channel tokens and used for the global representation.

  • embed_dim (int, default=512) – Dimension of the embeddings.

  • depth (int, default=8) – Number of transformer layers.

  • num_heads (int, default=8) – Number of attention heads.

  • mlp_ratio (float, default=4.0) – Ratio of the MLP hidden dimension to the embedding dimension.

  • drop_prob (float, default=0.0) – Dropout probability.

  • attn_drop_rate (float, default=0.0) – Dropout rate applied to the attention weights.

  • drop_path_rate (float, default=0.0) – Stochastic depth (drop path) rate.

  • init_std (float, default=0.02) – Standard deviation for weight initialization.

  • qkv_bias (bool, default=True) – Whether to use bias in the QKV projection.

  • norm_layer (torch.nn.Module, default=None) – Normalization layer. If None, defaults to nn.LayerNorm with epsilon layer_norm_eps.

  • layer_norm_eps (float, default=1e-6) – Epsilon value for the normalization layer.
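For the default hyperparameters, a few derived sizes follow directly from the parameter definitions. The helper below is hypothetical; head_in_features assumes non-overlapping patches and a task head that flattens all embed_num summary tokens of every patch, as described in the layer-by-layer overview.

```python
def derived_sizes(n_times=1000, patch_size=64, embed_num=4, embed_dim=512,
                  num_heads=8, mlp_ratio=4.0):
    # Standard multi-head attention splits embed_dim evenly across heads
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    n_patches = n_times // patch_size  # assumption: non-overlapping patches
    return {
        "n_patches": n_patches,
        "head_dim": embed_dim // num_heads,        # width of each attention head
        "mlp_hidden": int(embed_dim * mlp_ratio),  # hidden size of each MLP
        "head_in_features": n_patches * embed_num * embed_dim,  # input to final_layer
    }

sizes = derived_sizes()
print(sizes)
```

With the defaults this gives 64-dimensional heads, 2048-unit MLP hidden layers, and 15 × 4 × 512 = 30720 features entering the task head.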

References

[eegpt]

Wang, G., Liu, W., He, Y., Xu, C., Ma, L., & Li, H. (2024). EEGPT: Pretrained transformer for universal and reliable representation of EEG signals. Advances in Neural Information Processing Systems, 37, 39249-39280. Online: https://proceedings.neurips.cc/paper_files/paper/2024/file/4540d267eeec4e5dbd9dae9448f0b739-Paper-Conference.pdf

Notes

When loading pretrained weights from the original EEGPT checkpoint (e.g., for fine-tuning), you may encounter “unexpected keys” related to the predictor and reconstructor modules (e.g., predictor.mask_token, reconstructor.time_embed). These components are used only during the self-supervised pre-training phase (Masked Auto-Encoder) and are not part of this encoder-only model used for downstream tasks. It is safe to ignore them.
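A minimal sketch of dropping the pre-training-only entries before loading, in plain Python. The prefixes predictor. and reconstructor. come from the note above; the other key names are hypothetical, and a real checkpoint may also need its keys renamed to match this model. With PyTorch, the filtered dictionary could then be passed to model.load_state_dict(kept, strict=False), which reports any remaining missing and unexpected keys.

```python
DROP_PREFIXES = ("predictor.", "reconstructor.")

def split_pretraining_keys(state_dict, drop_prefixes=DROP_PREFIXES):
    """Separate encoder weights from pre-training-only predictor/reconstructor entries."""
    kept = {k: v for k, v in state_dict.items() if not k.startswith(drop_prefixes)}
    dropped = sorted(k for k in state_dict if k.startswith(drop_prefixes))
    return kept, dropped

# Hypothetical checkpoint keys; values stand in for tensors
checkpoint = {
    "target_encoder.summary_token": 0,
    "target_encoder.blocks.0.attn.qkv.weight": 1,
    "predictor.mask_token": 2,
    "reconstructor.time_embed": 3,
}
kept, dropped = split_pretraining_keys(checkpoint)
print(dropped)  # ['predictor.mask_token', 'reconstructor.time_embed']
```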

Methods

forward(x)

Forward pass.

Parameters:

x (torch.Tensor) – EEG data of shape (batch, n_chans, n_times).

Returns:

Model output. Shape depends on n_outputs and return_encoder_output.

Return type:

torch.Tensor