braindecode.models.STEEGFormer

class braindecode.models.STEEGFormer(n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, patch_size=16, embed_dim=512, depth=8, num_heads=8, mlp_ratio=4, drop_prob=0.0, drop_path=0.0, activation=<class 'torch.nn.modules.activation.GELU'>, global_pool='avg', n_chans_pos=145, chan_pos_idx=None)

STEEGFormer from Yang et al. (2026) [Yang2026].

Attention/Transformer Foundation Model

Figure: ST-EEGFormer architecture, reproduced from the official repository (LiuyinYang1101/STEEGFormer). Image source: https://raw.githubusercontent.com/LiuyinYang1101/STEEGFormer/main/assets/graphic_overview.png

Added in version 1.6.1.

Architecture Overview

ViT-based EEG foundation model, pre-trained with a Masked Autoencoder (MAE) objective on raw EEG. Each channel is cut into non-overlapping temporal patches that are linearly embedded into tokens, augmented with temporal and channel positional information, prepended with a learned CLS token, encoded by a stack of pre-norm Transformer blocks, and read out by a linear head.

Macro Components

  • Patch + token embedding (STEEGFormer.patch_embed). Operations: cut each channel into seq = n_times // patch_size non-overlapping patches of patch_size samples and linearly project each to an embed_dim token. Role: turn a C-channel segment into C * seq tokens (one per (channel, time-patch) pair).

  • Positional embeddings (STEEGFormer.temporal_pos, STEEGFormer.channel_pos). Operations: add a fixed sinusoidal temporal encoding over the seq patches and a learned channel embedding drawn from a shared montage vocabulary. Role: mark when (in time) and on which electrode each token sits, so the same electrode shares its embedding across datasets with different channel sets.

  • Transformer encoder (STEEGFormer.encoder). Operations: depth pre-norm ViT blocks (multi-head self-attention + MLP), reusing braindecode’s MultiHeadAttention and FeedForwardBlock. Role: mix information across all (channel, time-patch) tokens.

  • Read-out + head (STEEGFormer.norm, STEEGFormer.final_layer). Operations: "avg" mean-pools the patch tokens (CLS excluded); "cls" layer-normalises the sequence and takes the CLS token; a linear layer maps to n_outputs. Role: produce the class logits.
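The patching step above can be sketched with NumPy. This is an illustration under stated assumptions, not the library's code: `patchify` is a hypothetical helper, the projection weights are random stand-ins for the learned linear layer, samples that do not fill a whole patch are assumed to be dropped, and the channel-major token order is a guess.

```python
import numpy as np

def patchify(x, patch_size=16):
    """Split a (n_chans, n_times) window into (n_chans * seq, patch_size) patches."""
    n_chans, n_times = x.shape
    seq = n_times // patch_size              # patches per channel
    used = seq * patch_size                  # assumption: leftover samples are dropped
    patches = x[:, :used].reshape(n_chans, seq, patch_size)
    # Assumption: tokens are ordered channel by channel.
    return patches.reshape(n_chans * seq, patch_size), seq

rng = np.random.default_rng(0)
x = rng.standard_normal((22, 1000))          # one window: 22 channels, 1000 samples
patches, seq = patchify(x)

embed_dim = 512
# Random stand-in for the learned linear projection (illustration only).
w = rng.standard_normal((16, embed_dim)) * 0.02
tokens = patches @ w                         # (n_chans * seq, embed_dim)
n_tokens_with_cls = tokens.shape[0] + 1      # the encoder also sees one CLS token

print(seq, tokens.shape, n_tokens_with_cls)
```

With 22 channels and 1000 samples, each channel yields 62 patches, so the encoder attends over 1364 patch tokens plus the CLS token. Sequence length grows linearly with both channel count and window length, which is worth keeping in mind when budgeting memory for self-attention.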

Temporal, Spatial, and Spectral Encoding

  • Temporal: non-overlapping temporal patches with a fixed sinusoidal position encoding over the seq patches.

  • Spatial (channels): a learned channel embedding indexed through a shared montage vocabulary of standard electrode positions.

  • Spectral: none explicit; frequency content is learned implicitly by the patch projection and self-attention.
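A minimal NumPy sketch of how the two positional signals could combine. The sine/cosine formula is the standard Transformer one and is assumed here rather than taken from the model's source; `sinusoidal_encoding` is a hypothetical helper and the channel table holds random stand-ins for learned weights.

```python
import numpy as np

def sinusoidal_encoding(seq, dim, base=10000.0):
    """Fixed sine/cosine table of shape (seq, dim); dim must be even."""
    position = np.arange(seq)[:, None]                             # (seq, 1)
    inv_freq = np.exp(-np.log(base) * np.arange(0, dim, 2) / dim)  # (dim // 2,)
    table = np.zeros((seq, dim))
    table[:, 0::2] = np.sin(position * inv_freq)   # even columns
    table[:, 1::2] = np.cos(position * inv_freq)   # odd columns
    return table

n_chans, seq, embed_dim, n_chans_pos = 22, 62, 512, 145
temporal = sinusoidal_encoding(seq, embed_dim)     # fixed, never trained

rng = np.random.default_rng(0)
# Stand-in for the learned channel embedding table (one row per vocabulary slot).
channel_table = rng.standard_normal((n_chans_pos, embed_dim)) * 0.02
chan_pos_idx = np.arange(n_chans)                  # identity mapping: channel i -> slot i

# Broadcast so every (channel, patch) token gets its channel vector plus its time vector.
pos = channel_table[chan_pos_idx][:, None, :] + temporal[None, :, :]

print(pos.shape)
```

Because the channel vector depends only on the vocabulary slot, two datasets that both record a given electrode look up the same row, whatever position that electrode occupies in each dataset's channel list.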

Additional Mechanisms

A learned CLS token (sequence position 0) summarises the sequence for the "cls" read-out. Optional stochastic depth (drop_path) and dropout (drop_prob) regularise training; both default to 0 to match the released checkpoints.
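The parameter description characterises drop_path as the maximum of a linear schedule over depth. A common convention, assumed here and not confirmed against the implementation, ramps the per-block rate from 0 at the first block to drop_path at the last. `drop_path_schedule` is a hypothetical helper name.

```python
def drop_path_schedule(drop_path, depth):
    """Per-block stochastic-depth rates, ramping linearly from 0 to drop_path."""
    if depth == 1:
        return [drop_path]
    # Assumption: block 0 gets rate 0, the last block gets the maximum rate.
    return [drop_path * i / (depth - 1) for i in range(depth)]

rates = drop_path_schedule(0.1, 8)        # e.g. the "small" variant with depth=8
disabled = drop_path_schedule(0.0, 8)     # the default: every block keeps its residual path

print([round(r, 3) for r in rates])
```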

Variants

The released variants differ in width/depth and, for largeV2, the channel-vocabulary size (patch_size=16, mlp_ratio=4 throughout):

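| Variant | embed_dim | depth | num_heads | n_chans_pos |
|---------|-----------|-------|-----------|-------------|
| small   | 512       | 8     | 8         | 145         |
| base    | 768       | 12    | 12        | 145         |
| large   | 1024      | 24    | 16        | 145         |
| largeV2 | 1024      | 24    | 16        | 256         |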
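The variant hyperparameters above can be kept in a small lookup table, for instance to build constructor keyword arguments. The sketch below only restates the listed values. `VARIANTS`, `head_dim`, and `mlp_hidden_dim` are hypothetical names, and the per-head width assumes the usual even split of embed_dim across heads.

```python
PATCH_SIZE = 16   # shared by all released variants
MLP_RATIO = 4     # shared by all released variants

VARIANTS = {
    "small":   dict(embed_dim=512,  depth=8,  num_heads=8,  n_chans_pos=145),
    "base":    dict(embed_dim=768,  depth=12, num_heads=12, n_chans_pos=145),
    "large":   dict(embed_dim=1024, depth=24, num_heads=16, n_chans_pos=145),
    "largeV2": dict(embed_dim=1024, depth=24, num_heads=16, n_chans_pos=256),
}

def head_dim(name):
    """Width of one attention head, assuming embed_dim is split evenly across heads."""
    cfg = VARIANTS[name]
    assert cfg["embed_dim"] % cfg["num_heads"] == 0
    return cfg["embed_dim"] // cfg["num_heads"]

def mlp_hidden_dim(name):
    """Hidden width of the MLP blocks: mlp_ratio times embed_dim."""
    return MLP_RATIO * VARIANTS[name]["embed_dim"]

for name in VARIANTS:
    print(name, head_dim(name), mlp_hidden_dim(name))

# Possible use with the documented constructor arguments (not executed here):
# model = STEEGFormer(n_outputs=4, n_chans=22, n_times=1000, **VARIANTS["base"])
```

Under that assumption every variant works out to 64 dimensions per head, so the variants scale by adding heads and blocks, not by widening individual heads.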

Pre-trained weights

Ready-to-use checkpoints are re-hosted on the Hugging Face Hub under the braindecode organization. These repos convert the official MAE encoder checkpoints to braindecode’s key names and include config.json plus model.safetensors/pytorch_model.bin:

| Variant | Hub repo                        | Notes                           |
|---------|---------------------------------|---------------------------------|
| small   | braindecode/STEEGFormer-small   | 145-slot channel vocabulary     |
| base    | braindecode/STEEGFormer-base    | 145-slot channel vocabulary     |
| large   | braindecode/STEEGFormer-large   | 145-slot channel vocabulary     |
| largeV2 | braindecode/STEEGFormer-largeV2 | 256-slot HBN channel vocabulary |

Use the regular Hub API to load a re-hosted checkpoint:

model = STEEGFormer.from_pretrained(
    "braindecode/STEEGFormer-small", n_outputs=4, n_chans=22
)

The re-hosted repos save complete braindecode model files, so they include a classification head tensor for serialization. Only the encoder weights are from the official MAE pretraining; pass n_outputs for the downstream task so the head is rebuilt as needed.

To regenerate the re-hosted files from the official GitHub checkpoints, run the standalone convert_steegformer_checkpoints.py archived in each Hub repo; the model itself loads braindecode-format state dicts, so from_pretrained needs no conversion.

Note

Numerical equivalence of the encoder features with the reference implementation has been verified on the released checkpoints. The channel-to-vocabulary mapping is resolved from the electrode names in chs_info (looked up in STEEGFORMER_CHANNEL_ORDER, the BENDR/LaBraM convention); when chs_info is absent or a name is unknown, it falls back to the identity mapping (channel i -> slot i) with a warning. Pass chan_pos_idx to override explicitly.
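The name-based resolution and its fallback can be sketched as follows. This is a reading of the note above, not the library's code: `TOY_VOCAB` is a hypothetical stand-in for STEEGFORMER_CHANNEL_ORDER (whose contents are not reproduced here), `resolve_chan_pos_idx` is a hypothetical helper, names are matched exactly, and a single unknown name is assumed to trigger the identity fallback for all channels.

```python
import warnings

# Hypothetical toy vocabulary; the real STEEGFORMER_CHANNEL_ORDER is not reproduced here.
TOY_VOCAB = ["FP1", "FP2", "F3", "F4", "C3", "CZ", "C4", "P3", "PZ", "P4", "O1", "O2"]

def resolve_chan_pos_idx(n_chans, ch_names=None, vocab=TOY_VOCAB):
    """Map electrode names to vocabulary slots, or fall back to channel i -> slot i."""
    slot = {name: i for i, name in enumerate(vocab)}
    if ch_names is None or any(name not in slot for name in ch_names):
        # Assumption: one unknown name sends every channel to the identity mapping.
        warnings.warn("Falling back to identity channel mapping (channel i -> slot i).")
        return list(range(n_chans))
    return [slot[name] for name in ch_names]

mapped = resolve_chan_pos_idx(3, ["C3", "CZ", "C4"])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fallback = resolve_chan_pos_idx(3, ["C3", "XX", "C4"])   # "XX" is not in the vocabulary

print(mapped, fallback, len(caught))
```

The identity fallback keeps the model runnable, but it pairs channels with embeddings that were learned for other electrodes. When electrode names are unavailable, passing chan_pos_idx explicitly is the safer route for pre-trained weights.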

Parameters:
  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_chans (int) – Number of EEG channels.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_times (int) – Number of time samples of the input window.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • patch_size (int) – Temporal patch size (unfold), default 16.

  • embed_dim (int) – Token embedding dimension (512 / 768 / 1024 across variants).

  • depth (int) – Number of Transformer encoder blocks.

  • num_heads (int) – Number of attention heads.

  • mlp_ratio (int) – Hidden-to-embedding ratio of the MLP blocks.

  • drop_prob (float) – Dropout rate.

  • drop_path (float) – Stochastic-depth rate (max of a linear schedule over depth), default 0 (disabled, matching the released checkpoints).

  • activation (type[Module]) – Activation layer class used in the feed-forward blocks, default GELU.

  • global_pool (str) – Token aggregation before the head ("avg" or "cls").

  • n_chans_pos (int) – Size of the shared montage vocabulary the channel embedding is drawn from (145 for small/base/large, 256 for largeV2), default 145.

  • chan_pos_idx (array-like of int, optional) – Montage-vocabulary slot of each input channel, shape (n_chans,). If omitted, it is resolved from chs_info electrode names (falling back to range(n_chans)).

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
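The inference relies on the relation n_times = input_window_seconds * sfreq. A minimal sketch of that logic is given below. `infer_signal_params` is a hypothetical helper, and the rounding and error message are assumptions and may differ from braindecode's own handling.

```python
def infer_signal_params(n_times=None, input_window_seconds=None, sfreq=None):
    """Fill in whichever of the three quantities is missing, if the other two are given."""
    if n_times is None and input_window_seconds is not None and sfreq is not None:
        n_times = int(round(input_window_seconds * sfreq))   # assumption: round to nearest sample
    elif input_window_seconds is None and n_times is not None and sfreq is not None:
        input_window_seconds = n_times / sfreq
    elif sfreq is None and n_times is not None and input_window_seconds is not None:
        sfreq = n_times / input_window_seconds
    if n_times is None:
        # Mirrors the documented ValueError when a required value cannot be inferred.
        raise ValueError("n_times is not specified and cannot be inferred.")
    return n_times, input_window_seconds, sfreq

from_seconds = infer_signal_params(input_window_seconds=4.0, sfreq=250.0)
from_samples = infer_signal_params(n_times=1000, sfreq=250.0)

print(from_seconds, from_samples)
```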

References

[Yang2026]

Yang, L., Sun, Q., Li, A. & Van Hulle, M. M. (2026). Are EEG foundation models worth it? Comparative evaluation with traditional decoders in diverse BCI tasks. The Fourteenth International Conference on Learning Representations (ICLR 2026). https://openreview.net/forum?id=5Xwm8e6vbh

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]

Pushing a model to the Hub:

from braindecode.models import STEEGFormer

# Train your model
model = STEEGFormer(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-steegformer-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import STEEGFormer

# Load pretrained model
model = STEEGFormer.from_pretrained("username/my-steegformer-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = STEEGFormer.from_pretrained("username/my-steegformer-model", n_outputs=4)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = STEEGFormer.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

forward(x, return_features=False)

Encode an EEG batch into class logits (or encoder features).

Parameters:
  • x (Tensor) – EEG input of shape (batch, n_chans, n_times).

  • return_features (bool) – If True, return the layer-normalised encoder tokens as {"features": patch_tokens, "cls_token": cls_token} instead of the class logits (the unified braindecode foundation-model API). The cls_token then matches the feature the "cls" head consumes.

Returns:

Class logits of shape (batch, n_outputs), or the feature dict {"features", "cls_token"} when return_features is set.

Return type:

torch.Tensor | dict
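The two global_pool read-outs can be illustrated on a stand-in token array with the CLS token at position 0. This is a sketch under assumptions: `layer_norm` here has no learned scale or shift, any normalisation the model may apply after mean pooling is omitted because it is not described above, and `read_out` is a hypothetical helper and not a method of the model.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalise the last axis to zero mean and unit variance (no learned affine)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def read_out(tokens, global_pool="avg"):
    """tokens: (batch, 1 + n_patches, embed_dim), with the CLS token at position 0."""
    if global_pool == "avg":
        return tokens[:, 1:].mean(axis=1)    # mean over patch tokens, CLS excluded
    if global_pool == "cls":
        return layer_norm(tokens)[:, 0]      # normalised CLS token
    raise ValueError(f"Unknown global_pool: {global_pool!r}")

rng = np.random.default_rng(0)
tokens = rng.standard_normal((2, 1 + 6, 8))  # toy sizes: batch 2, 6 patches, width 8
avg = read_out(tokens, "avg")
cls = read_out(tokens, "cls")

print(avg.shape, cls.shape)
```

Either read-out yields one vector of width embed_dim per window, which the linear head then maps to n_outputs logits.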

reset_head(n_outputs)

Replace the linear classification head for a new n_outputs.

Called by from_pretrained() when the requested number of outputs differs from the pre-trained checkpoint (whose head is discarded).

Parameters:

n_outputs (int) – New number of output classes.

Examples

>>> from braindecode.models import BENDR
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
>>> model.reset_head(10)
>>> model.n_outputs
10

Added in version 1.4.

Examples using braindecode.models.STEEGFormer

Loading and Adapting Pretrained Foundation Models
