braindecode.models.EEGConformer

class braindecode.models.EEGConformer(n_outputs=None, n_chans=None, n_filters_time=40, filter_time_length=25, pool_time_length=75, pool_time_stride=15, drop_prob=0.5, att_depth=6, att_heads=10, att_drop_prob=0.5, final_fc_length='auto', return_features=False, activation=<class 'torch.nn.modules.activation.ELU'>, activation_transfor=<class 'torch.nn.modules.activation.GELU'>, n_times=None, chs_info=None, input_window_seconds=None, sfreq=None)[source]

EEG Conformer from Song et al. (2022) [song2022].

Figure: EEGConformer architecture (convolutional patch embedding followed by a small self-attention encoder).

Architectural Overview

EEG-Conformer is a convolution-first model augmented with a lightweight transformer encoder. The end-to-end flow is:

  • (i) _PatchEmbedding converts the continuous EEG into a compact sequence of tokens via a ShallowFBCSPNet-style temporal–spatial conv stem and temporal pooling;

  • (ii) _TransformerEncoder applies small multi-head self-attention to integrate longer-range temporal context across tokens;

  • (iii) _ClassificationHead aggregates the sequence and performs a linear readout. This preserves the strong inductive biases of shallow CNN filter banks while adding just enough attention to capture dependencies beyond the pooling horizon [song2022].

Macro Components

  • _PatchEmbedding (Shallow conv stem → tokens)

    • Operations:

    • A temporal convolution (torch.nn.Conv2d, kernel (1 x L_t)) forms a data-driven “filter bank”;

    • A spatial convolution (torch.nn.Conv2d, kernel (n_chans x 1)) projects across electrodes, collapsing the channel axis into a virtual channel;

    • Batch normalization (torch.nn.BatchNorm2d);

    • Activation (torch.nn.ELU);

    • Average pooling along time (torch.nn.AvgPool2d, kernel (1, P) with stride (1, S));

    • A final 1x1 projection (torch.nn.Linear).

The result is rearranged to a token sequence (B, S_tokens, D), where D = n_filters_time.

Interpretability/robustness. Temporal kernels can be inspected as FIR filters; the spatial conv yields channel projections analogous to ShallowFBCSPNet’s learned spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.

  • _TransformerEncoder (context over temporal tokens)

    • Operations:

    • A stack of att_depth encoder blocks (_TransformerEncoderBlock);

    • Each block applies layer normalization (torch.nn.LayerNorm);

    • Multi-head self-attention with att_heads heads, followed by dropout (torch.nn.Dropout) and a residual connection;

    • A second layer normalization (torch.nn.LayerNorm);

    • A two-layer feed-forward block (≈4x expansion, torch.nn.GELU) with dropout and a residual connection.

Shapes remain (B, S_tokens, D) throughout.

Role. Small attention focuses on interactions among temporal patches (not channels), extending effective receptive fields at modest cost.

  • _ClassificationHead (aggregation → readout)

    The token sequence is flattened (to final_fc_length features, inferred automatically when set to "auto") and passed through a small fully connected readout that produces the class scores. With return_features=True, the features before the last Linear can be exported for linear probing or downstream tasks.
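
A minimal sketch of exporting those pre-classifier features for linear probing; the channel and time dimensions are placeholders, and braindecode models expect input shaped (batch, n_chans, n_times):

    import torch
    from braindecode.models import EEGConformer

    # Placeholder dimensions: 22 channels, 4 s at 250 Hz -> 1000 samples.
    model = EEGConformer(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        sfreq=250,
        return_features=True,  # forward() then returns the pre-classifier features
    )
    model.eval()

    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    with torch.no_grad():
        features = model(x)
    print(features.shape)  # features taken before the final Linear layer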

Convolutional Details

  • Temporal (where time-domain patterns are learned).

    The initial (1 x L_t) conv per channel acts as a learned filter bank for oscillatory bands and transients. Subsequent AvgPool along time performs local integration, converting activations into “patches” (tokens). Pool length/stride control the token rate and set the lower bound on temporal context within each token.

  • Spatial (how electrodes are processed).

    A single conv with kernel (n_chans x 1) spans the full montage to learn spatial projections for each temporal feature map, collapsing the channel axis into a virtual channel before tokenization. This mirrors the shallow spatial step in ShallowFBCSPNet (temporal filters → spatial projection → temporal condensation).

  • Spectral (how frequency content is captured).

    No explicit Fourier/wavelet stage is used. Spectral selectivity emerges implicitly from the learned temporal kernels; pooling further smooths high-frequency noise. The effective spectral resolution is thus governed by L_t and the pooling configuration.
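
As noted under Interpretability above, the temporal kernels can be read as FIR filters. The sketch below is one way to inspect their magnitude responses; it locates the temporal-conv weight by its shape rather than by internal attribute names (an assumption about the implementation) and uses the default filter_time_length=25:

    import torch
    from braindecode.models import EEGConformer

    sfreq = 250.0
    model = EEGConformer(
        n_outputs=4, n_chans=22, n_times=1000, sfreq=sfreq,
        n_filters_time=40, filter_time_length=25,
    )

    # Look for a 4-D weight shaped (n_filters_time, 1, 1, filter_time_length),
    # i.e. the temporal convolution of the patch embedding.
    temporal_w = None
    for name, param in model.named_parameters():
        if param.dim() == 4 and tuple(param.shape[1:3]) == (1, 1) and param.shape[-1] == 25:
            temporal_w = param.detach()
            break

    if temporal_w is not None:
        kernels = temporal_w.reshape(temporal_w.shape[0], -1)  # (n_filters_time, L_t)
        spectrum = torch.fft.rfft(kernels, n=256).abs()        # FIR magnitude responses
        freqs = torch.fft.rfftfreq(256, d=1.0 / sfreq)          # frequency axis in Hz
        print(freqs[spectrum.argmax(dim=1)])                    # peak frequency per filter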

Attention / Sequential Modules

  • Type. Standard multi-head self-attention (MHA) with att_heads heads over the token sequence.

  • Shapes. Input/Output: (B, S_tokens, D); attention operates along the S_tokens axis.

  • Role. Re-weights and integrates evidence across pooled windows, capturing dependencies longer than any single token while leaving channel relationships to the convolutional stem. The design is intentionally small—attention refines rather than replaces convolutional feature extraction.
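
For intuition about the shapes only (this is plain torch.nn.MultiheadAttention, not the model's internal block), self-attention over the token axis preserves (B, S_tokens, D):

    import torch
    import torch.nn as nn

    B, S_tokens, D, att_heads = 8, 61, 40, 10  # D = n_filters_time; att_heads must divide D
    tokens = torch.randn(B, S_tokens, D)

    mha = nn.MultiheadAttention(embed_dim=D, num_heads=att_heads, batch_first=True)
    out, attn = mha(tokens, tokens, tokens)  # self-attention along the S_tokens axis
    print(out.shape)   # torch.Size([8, 61, 40]) -- token sequence shape is preserved
    print(attn.shape)  # torch.Size([8, 61, 61]) -- token-to-token attention map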

Additional Mechanisms

  • Parallel with ShallowFBCSPNet. Both begin with a learned temporal filter bank,

    spatial projection across electrodes, and early temporal condensation. ShallowFBCSPNet then computes band-power (via squaring/log-variance), whereas EEG-Conformer applies BN/ELU and continues with attention over tokens to refine temporal context before classification.

  • Tokenization knob. pool_time_length and especially pool_time_stride set

    the number of tokens S_tokens. Smaller strides → more tokens and higher attention capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias (see the token-count sketch after this list).

  • Embedding dimension = filters. n_filters_time serves double duty as both the

    number of temporal filters in the stem and the transformer’s embedding size D, simplifying dimensional alignment.
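
A back-of-the-envelope sketch of the tokenization knob, assuming unpadded ("valid") temporal convolution and pooling, which is the usual shallow-stem arithmetic:

    def n_tokens(n_times, filter_time_length=25, pool_time_length=75, pool_time_stride=15):
        """Sliding-window estimate of S_tokens, assuming valid (unpadded) conv and pooling."""
        t_after_conv = n_times - filter_time_length + 1  # temporal conv, stride 1
        return (t_after_conv - pool_time_length) // pool_time_stride + 1

    print(n_tokens(1000))                      # 61 tokens for 4 s at 250 Hz with the defaults
    print(n_tokens(1000, pool_time_stride=5))  # 181 tokens: smaller stride, more attention capacity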

Usage and Configuration

  • Instantiation. Choose n_filters_time (embedding size D) and

    filter_time_length to match the rhythms of interest. Tune pool_time_length/stride to trade temporal resolution for sequence length. Keep att_depth modest (e.g., 4–6) and set att_heads to divide D. final_fc_length="auto" infers the flattened size from PatchEmbedding.
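
A minimal instantiation sketch following these recommendations; the dataset dimensions are placeholders:

    import torch
    from braindecode.models import EEGConformer

    model = EEGConformer(
        n_outputs=4,             # e.g. four motor-imagery classes
        n_chans=22,
        n_times=1000,            # 4 s at 250 Hz
        sfreq=250,
        n_filters_time=40,       # embedding size D
        filter_time_length=25,   # ~100 ms temporal kernels at 250 Hz
        pool_time_length=75,
        pool_time_stride=15,
        att_depth=6,
        att_heads=10,            # divides D = 40
        final_fc_length="auto",  # flattened size inferred from the patch embedding
    )

    logits = model(torch.randn(8, 22, 1000))  # (batch, n_chans, n_times) -> (batch, n_outputs)
    print(logits.shape)                       # expected: torch.Size([8, 4])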

Parameters:
  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_chans (int) – Number of EEG channels.

  • n_filters_time (int) – Number of temporal filters; also defines the embedding size.

  • filter_time_length (int) – Length of the temporal filter.

  • pool_time_length (int) – Length of temporal pooling filter.

  • pool_time_stride (int) – Stride between consecutive temporal pooling windows.

  • drop_prob (float) – Dropout rate of the convolutional layer.

  • att_depth (int) – Number of self-attention layers.

  • att_heads (int) – Number of attention heads.

  • att_drop_prob (float) – Dropout rate of the self-attention layer.

  • final_fc_length (int | str) – Dimension of the fully connected layer. If "auto", it is inferred from the patch embedding output.

  • return_features (bool) – If True, the forward method returns the features before the last classification layer. Defaults to False.

  • activation (nn.Module) – Activation function to use. Default is nn.ELU.

  • activation_transfor (nn.Module) – Activation function applied in the feed-forward block inside the transformer encoder. Default is nn.GELU.

  • n_times (int) – Number of time samples of the input window.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

The authors recommend using data augmentation before using Conformer, e.g. segmentation and recombination. Please refer to the original paper and code for more details [ConformerCode].

The model was initially tuned on 4 seconds of 250 Hz data. Please adjust the scale of the temporal convolutional layer and the pooling layer accordingly for better performance at other sampling rates or window lengths.
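
One heuristic way to follow this advice is to scale the time-domain hyperparameters with the sampling rate so they cover the same durations as in the 250 Hz reference setting; the scaling below is an assumption, not a recommendation from the paper:

    from braindecode.models import EEGConformer

    def conformer_for(sfreq, n_chans, n_outputs, window_seconds=4.0):
        """Scale the 250 Hz reference configuration to another sampling rate (heuristic)."""
        scale = sfreq / 250.0
        return EEGConformer(
            n_outputs=n_outputs,
            n_chans=n_chans,
            n_times=int(round(window_seconds * sfreq)),
            sfreq=sfreq,
            filter_time_length=max(1, int(round(25 * scale))),  # keep ~100 ms kernels
            pool_time_length=max(1, int(round(75 * scale))),    # keep ~300 ms pooling windows
            pool_time_stride=max(1, int(round(15 * scale))),    # keep ~60 ms hop between windows
        )

    model_500hz = conformer_for(sfreq=500, n_chans=64, n_outputs=2)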

Added in version 0.8.

The parameters are grouped by the part of the model where they are used, or by where they are first used, e.g. n_filters_time.

Added in version 1.1.

References

[song2022]

Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG conformer: Convolutional transformer for EEG decoding and visualization. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 31, pp.710-719. https://ieeexplore.ieee.org/document/9991178

[ConformerCode]

Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG conformer: Convolutional transformer for EEG decoding and visualization. Code repository: eeyhsong/EEG-Conformer.

Methods

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor) – Input batch of EEG windows, shaped (batch, n_chans, n_times).

Return type:

Tensor

get_fc_size()[source]

Return the flattened feature size produced by the patch embedding, used to build the final fully connected layer when final_fc_length="auto".