braindecode.models.BrainModule

class braindecode.models.BrainModule(n_chans=None, n_outputs=None, n_times=None, sfreq=None, chs_info=None, input_window_seconds=None, hidden_dim=320, depth=10, kernel_size=3, growth=1.0, dilation_growth=2, dilation_period=5, conv_drop_prob=0.0, dropout_input=0.0, batch_norm=True, activation=torch.nn.GELU, n_subjects=200, subject_dim=0, subject_layers=False, subject_layers_dim='input', subject_layers_id=False, embedding_scale=1.0, n_fft=None, fft_complex=True, channel_dropout_prob=0.0, channel_dropout_type=None, glu=2, glu_context=1)

BrainModule from [brainmagick], also known as SimpleConv.

A dilated convolutional encoder for EEG decoding, using residual connections and optional GLU gating for improved expressivity.

Figure: BrainModule architecture (adapted from Extended Data Fig. 4 of [brainmagick], highlighting only the model part). For each layer, the number of output channels is noted first; the number of time steps is constant throughout the layers. The model is composed of a spatial attention layer, then a 1x1 convolution without activation. A "Subject Layer", selected by the subject index s, consists of a 1x1 convolution learnt only for that subject, with no activation. Five convolutional blocks of three convolutions each follow: the first two use residual skip connections and increasing dilation, each followed by a BatchNorm layer and a GELU activation; the third convolution is not residual and uses a GLU activation (which halves the number of channels) and no normalization. Finally, two 1x1 convolutions with a GELU in between are applied.

The BrainModule (also referred to as SimpleConv) is a deep dilated convolutional encoder specifically designed to decode perceived speech from non-invasive brain recordings like EEG and MEG. It is engineered to address the high noise levels and inter-individual variability inherent in non-invasive neuroimaging by using a single architecture trained across large cohorts while accommodating participant-specific differences.

Architecture Overview

The BrainModule integrates three primary mechanisms to align brain activity with deep speech representations (a configuration sketch follows the list):

  1. Spatial-temporal feature extraction. The model uses a dedicated spatial attention layer to remap sensor data based on physical locations, followed by temporal processing through dilated convolutions.

  2. Subject-specific adaptation. To accommodate inter-subject variability, the architecture includes a "Subject Layer", a participant-specific 1x1 convolution that allows the model to share core weights across a cohort while learning individual-specific neural patterns.

  3. Dilated residual blocks with gating. The core encoder employs a stack of convolutional blocks featuring skip connections and increasing dilation to expand the receptive field without losing temporal resolution, supplemented by optional Gated Linear Units (GLU) for increased expressivity.
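
A configuration sketch wiring the three mechanisms together (values such as n_chans=64 and n_subjects=50 are illustrative; parameter names follow the constructor signature above):

    from braindecode.models import BrainModule

    model = BrainModule(
        n_chans=64,
        n_outputs=2,
        n_times=1000,
        sfreq=250.0,
        subject_layers=True,  # mechanism 2: per-subject 1x1 convolution
        subject_dim=16,       # subject embeddings (required for subject_layers)
        n_subjects=50,
        glu=2,                # mechanism 3: GLU gating every 2 conv layers
        dilation_growth=2,    # mechanism 1: exponentially growing dilation
    )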

Macro Components

BrainModule.input_projection (Initial Processing)

Operations. Raw M/EEG input \(\mathbf{X} \in \mathbb{R}^{C \times T}\) is first processed through a spatial attention layer that projects sensor locations onto a 2D plane using Fourier-parameterized functions. This is followed by a subject-specific 1x1 convolution \(\mathbf{M}_s \in \mathbb{R}^{D_1 \times D_1}\) if subject features are enabled. The resulting features are projected to the hidden_dim (default 320) to ensure compatibility with subsequent residual connections.

Role. Converts high-dimensional, subject-dependent sensor data into a standardized latent space while preserving spatial and temporal relationships.
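
A minimal sketch of the subject-specific 1x1 convolution described above (the class name SubjectConv1x1 is illustrative, not braindecode's internal implementation):

    import torch
    import torch.nn as nn

    class SubjectConv1x1(nn.Module):
        """Per-subject 1x1 convolution: each subject s owns a matrix M_s
        applied along the channel axis (equivalent to a 1x1 conv, no bias)."""

        def __init__(self, n_subjects: int, in_channels: int, out_channels: int):
            super().__init__()
            self.weights = nn.Parameter(
                torch.randn(n_subjects, out_channels, in_channels) / in_channels**0.5
            )

        def forward(self, x: torch.Tensor, subject_index: torch.Tensor) -> torch.Tensor:
            # x: (batch, C_in, T); subject_index: (batch,)
            w = self.weights[subject_index]            # (batch, C_out, C_in)
            return torch.einsum("boc,bct->bot", w, x)  # (batch, C_out, T)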

BrainModule.encoder (Convolutional Sequence)

Operations. Implemented via _ConvSequence, this component consists of a stack of depth convolutional blocks. Each block typically contains: (a) Residual dilated convolutions. Two layers with kernel size 3, residual skip connections, and dilation factors that grow exponentially (e.g., powers of two with periodic resets) to capture multi-scale temporal context. (b) GLU gating. Every N layers (N given by the glu parameter), a Gated Linear Unit is applied, which halves the channel dimension and introduces non-linear gating to filter intermediate representations.

Role. Extracts deep hierarchical temporal features from the brain signal, significantly expanding the model’s receptive field to align with the contextual windows of speech modules like wav2vec 2.0.
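
A sketch of one residual dilated layer from step (a), assuming BatchNorm and GELU as described above (illustrative, not braindecode's internal code):

    import torch
    import torch.nn as nn

    class DilatedResidualLayer(nn.Module):
        def __init__(self, channels: int, dilation: int, kernel_size: int = 3):
            super().__init__()
            # "same" padding preserves the time dimension (odd kernel required)
            padding = dilation * (kernel_size - 1) // 2
            self.conv = nn.Conv1d(channels, channels, kernel_size,
                                  dilation=dilation, padding=padding)
            self.norm = nn.BatchNorm1d(channels)
            self.act = nn.GELU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.act(self.norm(self.conv(x)))  # residual skip

The GLU step (b) can similarly be sketched as nn.Conv1d(channels, 2 * channels, kernel_size, padding=kernel_size // 2) followed by nn.GLU(dim=1), which gates the doubled channels and halves them back to channels.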

Temporal, Spatial, and Spectral Encoding

  • Temporal: Increasing dilation factors across layers allow the model to integrate information over large time windows without the computational cost of standard large kernels (see the dilation-schedule sketch after this list), while a 150 ms input shift facilitates alignment between stimulus and brain response.

  • Spatial: The spatial attention layer learns a softmax weighting over input sensors based on their 3D coordinates, allowing the model to focus on regions typically activated during auditory stimulation (e.g., the temporal cortex).

  • Spectral: Through the optional n_fft parameter, the model can apply an STFT transformation, converting time-domain signals into a spectrogram representation before encoding.
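
The dilation schedule implied by dilation_growth and dilation_period can be written out directly; a minimal sketch (the exact reset indexing is an assumption based on the parameter descriptions below):

    def dilation_schedule(depth=10, dilation_growth=2, dilation_period=5):
        """Dilation per layer: multiplied by dilation_growth each layer,
        reset to 1 every dilation_period layers."""
        return [dilation_growth ** (i % dilation_period) for i in range(depth)]

    print(dilation_schedule())  # [1, 2, 4, 8, 16, 1, 2, 4, 8, 16]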

Additional Mechanisms

  • Clamping and scaling: The model relies on clamping input values (e.g., at 20 standard deviations) to prevent outliers and large electromagnetic artifacts from destabilizing the BatchNorm estimates and optimization process.

  • Scaled subject embeddings: When subject_dim is used, the _ScaledEmbedding layer effectively increases the learning rate of the subject-specific features, preventing slow convergence in multi-participant training.

  • _ConvSequence and residual logic: This class handles the actual stacking of layers. It is designed to be flexible with the growth parameter; if the channel size changes between layers (growth != 1.0), it automatically applies a 1x1 skip_projection convolution to the residual path so dimensions match for addition.

  • _ChannelDropout: Unlike standard dropout which zeroes individual neurons, this zeroes entire channels. It includes a rescale feature that multiplies the remaining channels by a factor total_channels / active_channels to maintain the expected value of the signal during training.

  • _ScaledEmbedding: An optimization for multi-subject learning (sketched below): by dividing the initial weights by a scale and then multiplying the output by the same scale, it effectively increases the gradient magnitude for the embedding weights, allowing subject-specific features to learn faster than the shared backbone.
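
A minimal sketch of the _ScaledEmbedding trick (the class shown is illustrative, not braindecode's internal code):

    import torch
    import torch.nn as nn

    class ScaledEmbedding(nn.Module):
        """Divide initial weights by `scale`, multiply the output back:
        forward values are unchanged at initialization, but the gradient
        w.r.t. the embedding weights is multiplied by `scale`."""

        def __init__(self, num_embeddings: int, embedding_dim: int,
                     scale: float = 10.0):
            super().__init__()
            self.embedding = nn.Embedding(num_embeddings, embedding_dim)
            self.embedding.weight.data /= scale
            self.scale = scale

        def forward(self, idx: torch.Tensor) -> torch.Tensor:
            return self.embedding(idx) * self.scale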

Parameters:
  • hidden_dim (int, default=320) – Hidden dimension for convolutional layers. Input is projected to this dimension before the convolutional blocks.

  • depth (int, default=10) – Number of convolutional blocks. Each block contains a dilated convolution with batch normalization and activation, followed by a residual connection.

  • kernel_size (int, default=3) – Convolutional kernel size. Must be odd for proper padding with dilation.

  • growth (float, default=1.0) – Channel size multiplier: hidden_dim * (growth ** layer_index). Values > 1.0 grow the channel count with depth; values < 1.0 shrink it. Note: when growth != 1.0 the channel count changes between consecutive layers, and _ConvSequence inserts a 1x1 skip projection on the residual path so the addition still matches (see the channel-count sketch after this parameter list).

  • dilation_growth (int, default=2) – Dilation multiplier per layer (e.g., 2 means dilation doubles each layer). Improves receptive field exponentially. Requires odd kernel_size.

  • dilation_period (int, default=5) – Reset dilation to 1 every N layers. Prevents dilation from growing too large and maintains local connectivity.

  • conv_drop_prob (float, default=0.0) – Dropout probability for convolutional layers.

  • dropout_input (float, default=0.0) – Dropout probability applied to model input only.

  • batch_norm (bool, default=True) – If True, apply batch normalization after each convolution.

  • activation (type[nn.Module], default=nn.GELU) – Activation function class to use (e.g., nn.GELU, nn.ReLU, nn.ELU).

  • n_subjects (int, default=200) – Number of unique subjects (for subject-specific pathways). Only used if subject_dim > 0.

  • subject_dim (int, default=0) – Dimension of subject embeddings. If 0, no subject-specific features. If > 0, adds subject embeddings to the input before encoding.

  • subject_layers (bool, default=False) – If True, apply subject-specific linear transformations to input channels. Each subject has its own weight matrix. Requires subject_dim > 0.

  • subject_layers_dim (str, default="input") – Where to apply subject layers: “input” or “hidden”.

  • subject_layers_id (bool, default=False) – If True, initialize subject layers as identity matrices.

  • embedding_scale (float, default=1.0) – Scaling factor for the subject embeddings; acts as an effective learning-rate multiplier for the embedding weights (see _ScaledEmbedding above).

  • n_fft (int, optional) – FFT size for STFT processing. If None, no STFT is applied. If specified, applies spectrogram transform before encoding.

  • fft_complex (bool, default=True) – If True, keep complex spectrogram. If False, use power spectrogram. Only used when n_fft is not None.

  • channel_dropout_prob (float, default=0.0) – Probability of dropping each channel during training (0.0 to 1.0). If 0.0, no channel dropout is applied.

  • channel_dropout_type (str, optional) – If specified with chs_info, only drop channels of this type (e.g., ‘eeg’, ‘ref’, ‘eog’). If None with channel_dropout_prob > 0, drops any channel.

  • glu (int, default=2) – If > 0, applies Gated Linear Units (GLU) every N convolutional layers. GLUs gate intermediate representations for more expressivity. If 0, no GLU is applied.

  • glu_context (int, default=1) – Context window size for GLU gates. If > 0, uses contextual information from neighboring time steps for gating. Requires glu > 0.
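
The effect of growth on per-layer channel counts follows directly from the formula given above (rounding to an integer channel count is an assumption):

    hidden_dim, depth, growth = 320, 10, 1.2

    # Per-layer channel counts; growth=1.0 keeps every layer at hidden_dim.
    channels = [round(hidden_dim * growth ** i) for i in range(depth)]
    print(channels)
    # [320, 384, 461, 553, 664, 796, 956, 1147, 1376, 1651]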

References

[brainmagick]

Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., & King, J. R. (2023). Decoding speech perception from non-invasive brain recordings. Nature Machine Intelligence, 5(10), 1097-1107.

Notes

  • Input shape: (batch, n_chans, n_times)

  • Output shape: (batch, n_outputs)

  • The model uses dilated convolutions with stride=1 to maintain temporal resolution while achieving large receptive fields.

  • Residual connections are applied at every layer; where input and output channel counts differ (growth != 1.0), _ConvSequence inserts a 1x1 skip projection so the residual addition still matches.

  • Subject-specific features (subject_dim > 0, subject_layers) require subject indices to be passed to forward via the subject_index argument (or supplied in the batch).

  • STFT processing (when n_fft is set) automatically transforms the input to the spectrogram domain.

Added in version 1.2.

Methods

forward(x, subject_index=None)

Forward pass.

Parameters:
  • x (torch.Tensor) – Input EEG data of shape (batch, n_chans, n_times).

  • subject_index (torch.Tensor, optional) – Subject indices of shape (batch,). Required if subject_dim > 0.

Returns:
  Output logits/predictions of shape (batch, n_outputs).

Return type:
  torch.Tensor
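
Examples

A usage sketch consistent with the documented shapes and forward signature (all values illustrative):

    import torch
    from braindecode.models import BrainModule

    model = BrainModule(n_chans=32, n_outputs=4, n_times=500, sfreq=100.0,
                        subject_dim=8, n_subjects=10)
    x = torch.randn(16, 32, 500)                 # (batch, n_chans, n_times)
    subject_index = torch.randint(0, 10, (16,))  # required when subject_dim > 0
    y = model(x, subject_index=subject_index)
    print(y.shape)                               # expected: torch.Size([16, 4])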