braindecode.models.EEGPT#
- class braindecode.models.EEGPT(n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, patch_size=64, patch_stride=32, embed_num=4, embed_dim=512, depth=8, num_heads=8, mlp_ratio=4.0, drop_prob=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, init_std=0.02, qkv_bias=True, patch_module=None, norm_layer=None, layer_norm_eps=1e-06, return_encoder_output=False, chan_proj_type='conv1d_constraint', n_chans_target=19, chan_conv_max_norm=1.0, final_layer=None)[source]#
EEGPT: Pretrained Transformer for Universal and Reliable Representation of EEG Signals from Wang et al. (2024) [eegpt].
Foundation Model Attention/Transformer
a) The EEGPT structure involves patching the input EEG signal as \(p_{i,j}\) through masking (50% time and 80% channel patches), creating masked part \(\mathcal{M}\) and unmasked part \(\bar{\mathcal{M}}\). b) Local spatio-temporal embedding maps patches to tokens. c) Use of dual self-supervised learning with Spatio-Temporal Representation Alignment and Mask-based Reconstruction.#
EEGPT is a pretrained transformer model designed for universal EEG feature extraction. It addresses challenges like low SNR and inter-subject variability by employing a dual self-supervised learning method that combines Spatio-Temporal Representation Alignment and Mask-based Reconstruction [eegpt].
Model Overview (Layer-by-layer)
- Patch embedding (_PatchEmbed or _PatchNormEmbed): split each channel into time patches of length patch_size and project each patch to embed_dim, yielding tokens with shape (batch, n_patches, n_chans, embed_dim).
- Channel embedding (chan_embed): add a learned embedding for each channel to preserve spatial identity before attention.
- Transformer encoder blocks (_EEGTransformer.blocks): for each patch group, append embed_num learned summary tokens and process the sequence with multi-head self-attention and MLP layers.
- Summary extraction: keep only the summary tokens, apply norm if set, and reshape back to (batch, n_patches, embed_num, embed_dim).
- Task head (final_layer): flatten summary tokens across patches and map to n_outputs; if return_encoder_output=True, return the encoder features instead.
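The token shapes above follow from simple arithmetic. The sketch below is illustrative only, assuming a standard sliding-window patch count with the default patch_size=64 and patch_stride=32; the exact patching in _PatchEmbed may differ:

```python
# Illustrative shape arithmetic for EEGPT tokens (assumed sliding-window patching).
n_chans, n_times = 22, 1000
patch_size, patch_stride, embed_dim = 64, 32, 512

# Number of temporal patches per channel for a strided window.
n_patches = (n_times - patch_size) // patch_stride + 1

# Each patch is projected to embed_dim, giving tokens of shape
# (batch, n_patches, n_chans, embed_dim).
token_shape = (1, n_patches, n_chans, embed_dim)
```

With these defaults the 1000-sample window yields 30 overlapping patches per channel.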
Dual Self-Supervised Learning
EEGPT moves beyond simple masked reconstruction by introducing a representation alignment objective. The pretraining loss \(\mathcal{L}\) is the sum of alignment loss \(\mathcal{L}_A\) and reconstruction loss \(\mathcal{L}_R\):
\[\mathcal{L} = \mathcal{L}_A + \mathcal{L}_R\]

Spatio-Temporal Representation Alignment (\(\mathcal{L}_A\)): Aligns the predicted features of masked regions with global features extracted by a Momentum Encoder. This forces the model to learn semantic, high-level representations rather than just signal waveform details.

\[\mathcal{L}_A = \frac{1}{N} \sum_{j=1}^{N} \|pred_j - LN(menc_j)\|_2^2\]

where \(pred_j\) is the predictor output and \(menc_j\) is the momentum encoder output.

Mask-based Reconstruction (\(\mathcal{L}_R\)): Standard masked-autoencoder objective that reconstructs the raw EEG patches, ensuring local temporal fidelity.

\[\mathcal{L}_R = \frac{1}{|\mathcal{M}|} \sum_{(i,j) \in \mathcal{M}} \|rec_{i,j} - LN(p_{i,j})\|_2^2\]

where \(rec_{i,j}\) is the reconstructed patch and \(p_{i,j}\) is the original patch.
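A toy NumPy sketch of the two objectives, written as plain mean-squared-error terms over random stand-in tensors (the real pretraining operates on predictor, momentum-encoder, and reconstructor outputs, which are not part of this encoder-only class):

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-6):
    # Normalize each feature vector to zero mean / unit variance (the LN in the loss).
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

# Stand-ins for predictor and momentum-encoder outputs: N tokens, D dims.
N, D = 8, 16
pred = rng.normal(size=(N, D))
menc = rng.normal(size=(N, D))

# Alignment loss: squared L2 distance to LayerNormed momentum features, averaged over tokens.
loss_a = np.mean(np.sum((pred - layer_norm(menc)) ** 2, axis=-1))

# Reconstruction loss over |M| masked patches of length P.
M, P = 5, 64
rec = rng.normal(size=(M, P))
patches = rng.normal(size=(M, P))
loss_r = np.mean(np.sum((rec - layer_norm(patches)) ** 2, axis=-1))

# Total pretraining loss is the sum of the two objectives.
loss = loss_a + loss_r
```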
Macro Components
- EEGPT.target_encoder (Universal Encoder)
Operations. A hierarchical backbone that consists of Local Spatio-Temporal Embedding followed by a standard Transformer encoder [eegpt].
Role. Maps raw spatio-temporal EEG patches into a sequence of latent tokens \(z\).
- EEGPT.chans_id (Channel Identification)
Operations. A buffer containing channel indices mapped from the standard channel names provided in chs_info [eegpt].
Role. Provides the spatial identity for each input channel, allowing the model to look up the correct channel embedding vector \(\varsigma_i\).
- Local Spatio-Temporal Embedding (Input Processing)
Operations. The input signal \(X\) is chunked into patches \(p_{i,j}\). Each patch is linearly projected and summed with a specific channel embedding: \(token_{i,j} = \text{Embed}(p_{i,j}) + \varsigma_i\) [eegpt].
Role. Converts the 2D EEG grid (Channels \(\times\) Time) into a unified sequence of tokens that preserves both channel identity and temporal order.
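The embedding step above can be sketched in a few lines of NumPy. This is a simplified stand-in for the real module (shapes and the projection are hypothetical), showing only the formula \(token_{i,j} = \text{Embed}(p_{i,j}) + \varsigma_i\):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical sizes: C channels, T patches per channel, patch length P, embed dim D.
C, T, P, D = 4, 6, 64, 32

patches = rng.normal(size=(C, T, P))      # p_{i,j}: patch j of channel i
W = rng.normal(size=(P, D)) / np.sqrt(P)  # linear projection standing in for Embed(.)
chan_embed = rng.normal(size=(C, D))      # channel embeddings varsigma_i

# token_{i,j} = Embed(p_{i,j}) + varsigma_i, broadcast over the T patches of each channel.
tokens = patches @ W + chan_embed[:, None, :]
```

Every patch of channel i receives the same \(\varsigma_i\), which is what lets the model identify channels independently of their order in the montage.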
How the information is encoded temporally, spatially, and spectrally
Temporal. The model segments continuous EEG signals into small, non-overlapping patches (e.g., 250 ms windows with patch_size=64) [eegpt]. This patching mechanism captures short-term local temporal structure, while the subsequent Transformer encoder captures long-range temporal dependencies across the entire window.
Spatial. Unlike convolutional models that may rely on a fixed spatial order, EEGPT uses channel embeddings \(\varsigma_i\) [eegpt]. Each channel's data is treated as a distinct sequence of tokens tagged with its spatial identity. This allows the model to flexibly handle different montages and missing channels by simply mapping channel names to their corresponding learnable embeddings.
Spectral. Spectral information is implicitly learned through the Mask-based Reconstruction objective (\(\mathcal{L}_R\)) [eegpt]. By forcing the model to reconstruct raw waveforms (including phase and amplitude) from masked inputs, the model learns to encode frequency-specific patterns. The alignment objective (\(\mathcal{L}_A\)) refines this by encouraging these spectral features to align with robust, high-level semantic representations.
Pretrained Weights
Weights are available on HuggingFace.
Important

Pre-trained Weights Available

This model has pre-trained weights available on the Hugging Face Hub. You can load them using:

```python
from braindecode.models import EEGPT

# Load pre-trained model from Hugging Face Hub
model = EEGPT.from_pretrained("braindecode/eegpt-pretrained")
```

To push your own trained model to the Hub:

```python
# After training your model
model.push_to_hub(
    repo_id="username/my-eegpt-model",
    commit_message="Upload trained EEGPT model",
)
```

Requires installing braindecode[hub] for Hub integration.

Usage
The model can be initialized for specific downstream tasks (e.g., classification) by specifying n_outputs, chs_info, n_times.
```python
from braindecode.models import EEGPT

model = EEGPT(
    n_chans=22,
    n_times=1000,
    chs_info=chs_info,
    n_outputs=4,  # For classification tasks
    patch_size=64,
    depth=8,
    embed_dim=512,
)

# Forward pass
# Input shape: (batch_size, n_chans, n_times)
y = model(x)
```
- Parameters:
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
n_chans (int) – Number of EEG channels.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
sfreq (float) – Sampling frequency of the EEG recordings.
patch_size (int) – Size of the patches for the transformer.
patch_stride (int) – Stride of the patches for the transformer.
embed_num (int) – Number of summary tokens used for the global representation.
embed_dim (int) – Dimension of the embeddings.
depth (int) – Number of transformer layers.
num_heads (int) – Number of attention heads.
mlp_ratio (float) – Ratio of the MLP hidden dimension to the embedding dimension.
drop_prob (float) – Dropout probability.
attn_drop_rate (float) – Attention dropout rate.
drop_path_rate (float) – Drop path rate.
init_std (float) – Standard deviation for weight initialization.
qkv_bias (bool) – Whether to use bias in the QKV projection.
patch_module (Optional[Module]) – Custom patch-embedding module; if None, the default patch embedding is used.
norm_layer (Optional[Module]) – Normalization layer. If None, defaults to nn.LayerNorm with epsilon layer_norm_eps.
layer_norm_eps (float) – Epsilon value for the normalization layer.
return_encoder_output (bool) – Whether to return the encoder output or the classifier output.
chan_proj_type (Literal['conv1d_constraint', 'linear', 'none']) – Type of projection used to map the input channels to the n_chans_target channels expected by the encoder.
n_chans_target (int) – Number of channels expected by the pretrained encoder.
chan_conv_max_norm (float) – Max-norm constraint applied to the channel projection when chan_proj_type='conv1d_constraint'.
final_layer (type[Module] | None) – Optional module class used as the task head; if None, a default head mapping to n_outputs is used.
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
When loading pretrained weights from the original EEGPT checkpoint (e.g., for fine-tuning), you may encounter “unexpected keys” related to the predictor and reconstructor modules (e.g., predictor.mask_token, reconstructor.time_embed). These components are used only during the self-supervised pre-training phase (Masked Auto-Encoder) and are not part of this encoder-only model used for downstream tasks. It is safe to ignore them.
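Filtering such keys out of a checkpoint before loading can be sketched as below. The key names are taken from the note above; the state-dict contents are placeholders:

```python
# Hypothetical sketch: dropping pre-training-only keys before loading a checkpoint.
state_dict = {
    "target_encoder.blocks.0.attn.qkv.weight": "...",
    "predictor.mask_token": "...",
    "reconstructor.time_embed": "...",
}

# Keep only encoder weights; predictor/reconstructor exist only during pre-training.
encoder_state = {
    k: v for k, v in state_dict.items()
    if not k.startswith(("predictor.", "reconstructor."))
}
```

Alternatively, passing strict=False to load_state_dict achieves the same effect by tolerating the unexpected keys.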
Hugging Face Hub integration
When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

```shell
pip install braindecode[hub]
```
Pushing a model to the Hub:
```python
from braindecode.models import EEGPT

# Train your model
model = EEGPT(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-eegpt-model",
    commit_message="Initial model upload",
)
```
Loading a model from the Hub:
```python
from braindecode.models import EEGPT

# Load pretrained model
model = EEGPT.from_pretrained("username/my-eegpt-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = EEGPT.from_pretrained("username/my-eegpt-model", n_outputs=4)
```
Extracting features and replacing the head:
```python
import torch

x = torch.randn(1, model.n_chans, model.n_times)

# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)
```
Saving and restoring full configuration:
```python
import json

config = model.get_config()  # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = EEGPT.from_config(config)  # reconstruct (no weights)
```
All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
See Loading and Adapting Pretrained Foundation Models for a complete tutorial.
References
[eegpt] Wang, G., Liu, W., He, Y., Xu, C., Ma, L., & Li, H. (2024). EEGPT: Pretrained transformer for universal and reliable representation of EEG signals. Advances in Neural Information Processing Systems, 37, 39249-39280. Online: https://proceedings.neurips.cc/paper_files/paper/2024/file/4540d267eeec4e5dbd9dae9448f0b739-Paper-Conference.pdf
Methods
- forward(x, return_features=False)[source]#
Forward pass.
- Parameters:
x (torch.Tensor) – EEG data of shape (batch, n_chans, n_times).
return_features (bool) – If True, return a dict with "features" and "cls_token" instead of the classification output.
- Returns:
Model output. Shape depends on n_outputs and return_encoder_output.
- Return type:
torch.Tensor or dict
- get_probe_params()[source]#
Get parameters needed to create a _LinearConstraintProbe.
- Returns:
Parameters dict with n_patches, embed_num, embed_dim.
- Return type:
dict
- reset_head(n_outputs)[source]#
Replace the classification head for a new number of outputs.
This is called automatically by from_pretrained() when the user passes an n_outputs that differs from the saved config. Override in subclasses that need a model-specific head structure.
- Parameters:
n_outputs (int) – New number of output classes.
Examples
```python
>>> from braindecode.models import BENDR
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
>>> model.reset_head(10)
>>> model.n_outputs
10
```
Added in version 1.4.