braindecode.models.BENDR#

class braindecode.models.BENDR(*args, **kwargs)[source]#

BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr].

Categories: Convolution, Large Brain Model

Figure: The BENDR architecture.

The BENDR architecture adapts techniques used for language modeling (LM) toward the development of encephalography modeling (EM) [bendr]. It utilizes a self-supervised training objective to learn compressed representations of raw EEG signals [bendr]. The model is capable of modeling completely novel raw EEG sequences recorded with differing hardware and subjects, aiming for transferable performance across a variety of downstream BCI and EEG classification tasks [bendr].

Architectural Overview

BENDR is adapted from wav2vec 2.0 [wav2vec2] and is composed of two main stages: a feature extractor (Convolutional stage) that produces BErt-inspired Neural Data Representations (BENDR), followed by a transformer encoder (Contextualizer) [bendr].

Macro Components

  • BENDR.encoder (Convolutional Stage/Feature Extractor)
    • Operations. A stack of six short-receptive-field 1D convolution blocks [bendr]; each block consists of a 1D convolution, GroupNorm, and a GELU activation (a minimal sketch follows this list).

    • Role. Takes raw data \(X_{raw}\) and dramatically downsamples it to a new sequence of vectors (BENDR) [bendr]. Each resulting vector has a length of 512.

  • BENDR.contextualizer (Transformer Encoder)
    • Operations. A transformer encoder that uses layered, multi-head self-attention [bendr]. It employs T-Fixup weight initialization [tfixup] and uses 8 layers and 8 heads.

    • Role. Maps the sequence of BENDR vectors to a contextualized sequence. The output of a fixed start token is typically used as the aggregate representation for downstream classification [bendr].

  • Contextualizer.position_encoder (Positional Encoding)
    • Operations. An additive (grouped) convolution layer with a receptive field of 25 and 16 groups [bendr].

    • Role. Encodes position information before the input enters the transformer.
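
To make the convolutional stage concrete, below is a minimal sketch of the six-block encoder described above. The conv_block helper and the GroupNorm group count (half the channel width) are illustrative assumptions, not braindecode's exact internals:

    import torch
    from torch import nn

    def conv_block(in_ch, out_ch, kernel, stride):
        # One encoder block: 1D convolution -> GroupNorm -> GELU.
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=stride),
            nn.GroupNorm(out_ch // 2, out_ch),  # assumed group count
            nn.GELU(),
        )

    # Six blocks whose stride matches the kernel size (3, then five 2s),
    # mapping raw EEG (batch, 20, n_times) to 512-dimensional BENDR vectors.
    widths = (3, 2, 2, 2, 2, 2)
    encoder = nn.Sequential(
        *[conv_block(20 if i == 0 else 512, 512, w, w)
          for i, w in enumerate(widths)]
    )

    x = torch.randn(1, 20, 9600)  # illustrative raw window
    print(encoder(x).shape)       # torch.Size([1, 512, 100]); 9600 / 96 = 100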

How the information is encoded temporally, spatially, and spectrally

  • Temporal. The convolutional encoder uses a stack of blocks where the stride matches the receptive field (3 for the first block, 2 for each subsequent block) [bendr]. This downsamples the raw data by a factor of 96, resulting in an effective sampling frequency of approximately 2.67 Hz (see the arithmetic sketch after this list).

  • Spatial. To reduce complexity, the convolutional stage uses 1D convolutions and deliberately does not mix EEG channels in this first stage [bendr]. The input comprises 20 channels (19 EEG channels plus one relative amplitude channel).

  • Spectral. The convolution operations implicitly extract spectral features from the raw EEG signal [bendr]: the BENDR representations are derived from the raw waveform by convolutional feature extraction followed by sequence modeling [wav2vec2].
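
The temporal figures above are simple stride arithmetic; a quick check in Python, assuming a 256 Hz raw sampling rate (the rate implied by the ~2.67 Hz effective frequency):

    from math import prod

    strides = (3, 2, 2, 2, 2, 2)  # enc_downsample defaults
    downsample = prod(strides)    # 3 * 2**5 = 96
    sfreq = 256.0                 # assumed raw sampling rate
    print(downsample)             # 96
    print(sfreq / downsample)     # ~2.67 Hz effective rate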

Additional Mechanisms

  • Self-Supervision (Pre-training). Uses a masked sequence learning approach adapted from wav2vec 2.0 [wav2vec2]: contiguous spans of the BENDR sequence are masked, and the model must identify the original underlying encoded vector from the transformer output, contrasting it against a set of negative distractors [bendr].

  • Regularization. LayerDrop [layerdrop] and Dropout (at probabilities 0.01 and 0.15, respectively) are used during pre-training [bendr]. The implementation also uses T-Fixup scaling for parameter initialization [tfixup].

  • Input Conditioning. A fixed token (a vector filled with the value -5) is prepended to the BENDR sequence before input to the transformer, serving as the aggregate representation token [bendr] (illustrated below).
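
A minimal illustration of that input-conditioning step, with illustrative tensor shapes (the model performs this internally):

    import torch

    bendr = torch.randn(8, 512, 100)               # (batch, features, sequence)
    start = torch.full((8, 512, 1), -5.0)          # fixed start token
    with_token = torch.cat([start, bendr], dim=2)  # (8, 512, 101)

    # After the contextualizer, the output at the start-token position
    # would serve as the aggregate representation for classification:
    # aggregate = contextualizer(with_token)[..., 0]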

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_times (int) – Number of time samples of the input window.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • encoder_h (int, default=512) – Hidden size (number of output channels) of the convolutional encoder. This determines the dimensionality of the BENDR feature vectors produced by the encoder.

  • contextualizer_hidden (int, default=3076) – Hidden size of the feedforward layer within each transformer block. The paper uses approximately 2x the transformer dimension (3076 ~ 2 x 1536).

  • projection_head (bool, default=False) – If True, adds a projection layer at the end of the encoder to project back to the input feature size. This is used during self-supervised pre-training but typically disabled during fine-tuning.

  • drop_prob (float, default=0.1) – Dropout probability applied throughout the model. The paper recommends 0.15 for pre-training and 0.0 for fine-tuning. Default is 0.1 as a compromise.

  • layer_drop (float, default=0.0) – Probability of dropping entire transformer layers during training (LayerDrop regularization [layerdrop]). The paper uses 0.01 for pre-training and 0.0 for fine-tuning.

  • activation (torch.nn.Module, default=torch.nn.GELU) – Activation function used in the encoder convolutional blocks. The paper uses GELU activation throughout.

  • transformer_layers (int, default=8) – Number of transformer encoder layers in the contextualizer. The paper uses 8 layers.

  • transformer_heads (int, default=8) – Number of attention heads in each transformer layer. The paper uses 8 heads with head dimension of 192 (1536 / 8).

  • position_encoder_length (int, default=25) – Kernel size for the convolutional positional encoding layer. The paper uses a receptive field of 25 with 16 groups.

  • enc_width (tuple of int, default=(3, 2, 2, 2, 2, 2)) – Kernel sizes for each of the 6 convolutional blocks in the encoder. Each value corresponds to one block.

  • enc_downsample (tuple of int, default=(3, 2, 2, 2, 2, 2)) – Stride values for each of the 6 convolutional blocks in the encoder. The total downsampling factor is the product of all strides (3 x 2 x 2 x 2 x 2 x 2 = 96).

  • start_token (int or float, default=-5) – Value used to fill the start token embedding that is prepended to the BENDR sequence before input to the transformer. This token’s output is used as the aggregate representation for classification.

  • final_layer (bool, default=True) – If True, includes a final linear classification layer that maps from encoder_h to n_outputs. If False, the model outputs the contextualized features directly.
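
Putting the parameters together, a minimal usage sketch; the window length and sampling rate are illustrative, and the input shape (batch_size, n_chans, n_times) follows the usual braindecode convention:

    import torch
    from braindecode.models import BENDR

    # 60 s windows at 256 Hz over 20 channels, binary classification.
    model = BENDR(n_chans=20, n_outputs=2, n_times=15360, sfreq=256.0)

    x = torch.randn(4, 20, 15360)  # (batch_size, n_chans, n_times)
    logits = model(x)              # expected shape: (4, 2)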

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

  • The full BENDR architecture contains a large number of parameters; the configuration studied in [bendr] involved training over one billion parameters.

  • A randomly initialized full BENDR architecture was generally ineffective at solving downstream tasks without prior self-supervised training [bendr].

  • The pre-training task (contrastive predictive coding via masking) is generalizable, exhibiting strong uniformity of performance across novel subjects, hardware, and tasks [bendr].

Warning

To realize the full potential of BENDR, the model requires self-supervised pre-training on large unlabeled EEG datasets (such as TUEG), followed by fine-tuning on the specific downstream classification task [bendr].

References

[bendr]

Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021). BENDR: Using transformers and a contrastive self-supervised learning task to learn from massive amounts of EEG data. Frontiers in Human Neuroscience, 15, 653659. https://doi.org/10.3389/fnhum.2021.653659

[wav2vec2]

Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020). wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds.), Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460). https://dl.acm.org/doi/10.5555/3495724.3496768

[tfixup]

Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020). Improving Transformer Optimization Through Better Initialization. In International Conference on Machine Learning (pp. 4475-4483). PMLR. https://dl.acm.org/doi/10.5555/3524938.3525354

[layerdrop]

Fan, A., Grave, E., & Joulin, A. (2020). Reducing Transformer Depth on Demand with Structured Dropout. International Conference on Learning Representations. Retrieved from https://openreview.net/forum?id=SylO2yStDr

Methods

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward-pass recipe must be defined within this function, call the Module instance itself rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently ignores them.

Parameters:

x (torch.Tensor) – Input EEG window of shape (batch_size, n_chans, n_times).