
class braindecode.models.AttentionBaseNet(n_times=None, n_chans=None, n_outputs=None, chs_info=None, sfreq=None, input_window_seconds=None, n_temporal_filters: int = 40, temp_filter_length_inp: int = 25, spatial_expansion: int = 1, pool_length_inp: int = 75, pool_stride_inp: int = 15, dropout_inp: float = 0.5, ch_dim: int = 16, temp_filter_length: int = 15, pool_length: int = 8, pool_stride: int = 8, dropout: float = 0.5, attention_mode: str | None = None, reduction_rate: int = 4, use_mlp: bool = False, freq_idx: int = 0, n_codewords: int = 4, kernel_size: int = 9, activation: torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.ELU'>, extra_params: bool = False)

AttentionBaseNet from Wimpff M et al. (2023) [Martin2023].

Neural Network from the paper: EEG motor imagery decoding: A framework for comparative analysis with channel attention mechanisms

The paper and the original code, with more details about the methodological choices, are available at [Martin2023] and [MartinCode].

The AttentionBaseNet architecture is composed of four modules:

  • Input Block, which performs a temporal convolution and a spatial convolution.

  • Channel Expansion, which modifies the number of channels.

  • Attention Block, which performs channel attention with several options.

  • ClassificationHead.

Added in version 0.9.
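To give a feel for what channel attention does, the following is a minimal pure-Python sketch of the generic squeeze-and-excitation idea that the “se” option is named after. It is an illustration under stated assumptions, not the braindecode implementation: the function name squeeze_excite, the bias-free weights, and the toy input are all hypothetical, and reduction_rate is used here only to size the hidden layer.

```python
import math


def squeeze_excite(x, w1, w2):
    """Generic squeeze-and-excitation over feature channels (illustrative only).

    x  : list of C channels, each a list of T samples
    w1 : (C // r) x C weight matrix (biases omitted for brevity)
    w2 : C x (C // r) weight matrix (biases omitted for brevity)
    """
    # Squeeze: summarise each channel by its average over time.
    z = [sum(ch) / len(ch) for ch in x]
    # Excitation, step 1: project down to C // r units, then ReLU.
    h = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in w1]
    # Excitation, step 2: project back to C units, then a sigmoid gate.
    s = [1.0 / (1.0 + math.exp(-sum(w * v for w, v in zip(row, h)))) for row in w2]
    # Scale: multiply every channel by its gate, a value in (0, 1).
    return [[g * v for v in ch] for g, ch in zip(s, x)]


# Toy input: 4 feature channels, 3 time samples each.
x = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [4.0, 4.0, 4.0], [-1.0, 1.0, 0.0]]
reduction_rate = 2                      # hypothetical value for the toy example
hidden = len(x) // reduction_rate       # size of the bottleneck layer
w1 = [[0.1] * len(x) for _ in range(hidden)]
w2 = [[0.0] * hidden for _ in range(len(x))]  # zero weights give a gate of 0.5
y = squeeze_excite(x, w1, w2)
```

A larger reduction rate shrinks the bottleneck, which lowers the parameter count of the gate at the cost of capacity.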

  • n_times (int) – Number of time samples of the input window.

  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • input_window_seconds (float) – Length of the input window in seconds.

  • n_temporal_filters (int, optional) – Number of temporal convolutional filters in the first layer. This defines the number of output channels after the temporal convolution. Default is 40.

  • temp_filter_length_inp (int, optional) – Length of the temporal filters in the input block. Default is 25.

  • spatial_expansion (int, optional) – Multiplicative factor to expand the spatial dimensions. Used to increase the capacity of the model by expanding spatial features. Default is 1.

  • pool_length_inp (int, optional) – Length of the pooling window in the input layer. Determines how much temporal information is aggregated during pooling. Default is 75.

  • pool_stride_inp (int, optional) – Stride of the pooling operation in the input layer. Controls the downsampling factor in the temporal dimension. Default is 15.

  • dropout_inp (float, optional) – Dropout rate applied after the input layer. This is the probability of zeroing out elements during training to prevent overfitting. Default is 0.5.

  • ch_dim (int, optional) – Number of channels in the subsequent convolutional layers. This controls the depth of the network after the initial layer. Default is 16.

  • temp_filter_length (int, default=15) – The length of the temporal filters in the convolutional layers.

  • pool_length (int, default=8) – The length of the window for the average pooling operation.

  • pool_stride (int, default=8) – The stride of the average pooling operation.

  • dropout (float, default=0.5) – The dropout rate for regularization. Values should be between 0 and 1.

  • attention_mode (str, optional) – The type of attention mechanism to apply. If None, no attention is applied.

      - “se” for Squeeze-and-excitation network

      - “gsop” for Global Second-Order Pooling

      - “fca” for Frequency Channel Attention Network

      - “encnet” for context encoding module

      - “eca” for Efficient channel attention for deep convolutional neural networks

      - “ge” for Gather-Excite

      - “gct” for Gated Channel Transformation

      - “srm” for Style-based Recalibration Module

      - “cbam” for Convolutional Block Attention Module

      - “cat” for Learning to collaborate channel and temporal attention from multi-information fusion

      - “catlite” for Learning to collaborate channel attention from multi-information fusion (lite version, cat w/o temporal attention)

  • reduction_rate (int, default=4) – The reduction rate used in the attention mechanism to reduce dimensionality and computational complexity.

  • use_mlp (bool, default=False) – Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within the attention mechanism for further processing.

  • freq_idx (int, default=0) – DCT index used in fca attention mechanism.

  • n_codewords (int, default=4) – The number of codewords (clusters) used in attention mechanisms that employ quantization or clustering strategies.

  • kernel_size (int, default=9) – The kernel size used in certain types of attention mechanisms for convolution operations.

  • activation (nn.Module, default=nn.ELU) – Activation function class to apply. Should be a PyTorch activation module class such as nn.ReLU or nn.ELU.

  • extra_params (bool, default=False) – Flag to indicate whether additional, custom parameters should be passed to the attention mechanism.

  • ValueError – If some input signal-related parameters are not specified and cannot be inferred.

  • FutureWarning – If add_log_softmax is True, since the LogSoftmax final layer will be removed in the future.
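To make the pooling parameters above concrete, the sketch below computes how many time steps survive a pooling stage. It assumes an unpadded window with floor rounding, which is the default behaviour of PyTorch pooling layers; whether AttentionBaseNet uses exactly these settings is an assumption. The helper pooled_length and the input lengths are hypothetical.

```python
def pooled_length(n_samples: int, pool_length: int, pool_stride: int) -> int:
    """Output length of an unpadded pooling window moved with the given stride."""
    if n_samples < pool_length:
        raise ValueError("input is shorter than the pooling window")
    return (n_samples - pool_length) // pool_stride + 1


# Defaults of the input block: pool_length_inp=75, pool_stride_inp=15.
# 1000 samples entering the pooling stage is an assumed example length.
after_input_pool = pooled_length(1000, pool_length=75, pool_stride=15)

# Defaults of the later block: pool_length=8, pool_stride=8, applied here to
# the previous result (assumes the layers in between preserve the length).
after_second_pool = pooled_length(after_input_pool, pool_length=8, pool_stride=8)

print(after_input_pool, after_second_pool)  # 62 7
```

Under these assumptions, a short input window can fall below the pooling length, so it is worth checking the arithmetic before changing pool_length_inp or pool_stride_inp.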


If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
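The kind of inference described here can be sketched as follows. This is a hypothetical helper written for illustration, not the code braindecode runs: the name infer_signal_params, the rounding choice, and the exact error conditions are assumptions. It relies only on the relations n_times = sfreq × input_window_seconds and n_chans = len(chs_info).

```python
def infer_signal_params(n_times=None, sfreq=None, input_window_seconds=None,
                        n_chans=None, chs_info=None):
    """Fill in missing signal parameters from the ones that were given."""
    # The channel count follows from the per-channel info list.
    if n_chans is None and chs_info is not None:
        n_chans = len(chs_info)
    # Any two of (n_times, sfreq, input_window_seconds) determine the third.
    if n_times is None and sfreq is not None and input_window_seconds is not None:
        n_times = int(round(sfreq * input_window_seconds))
    if sfreq is None and n_times is not None and input_window_seconds is not None:
        sfreq = n_times / input_window_seconds
    if input_window_seconds is None and n_times is not None and sfreq is not None:
        input_window_seconds = n_times / sfreq
    # The network shape needs at least the input dimensions.
    if n_times is None or n_chans is None:
        raise ValueError("n_times and n_chans were not given and cannot be inferred")
    return {"n_times": n_times, "n_chans": n_chans, "sfreq": sfreq,
            "input_window_seconds": input_window_seconds}


# Example: 4 s windows sampled at 250 Hz with two channels.
params = infer_signal_params(
    sfreq=250.0,
    input_window_seconds=4.0,
    chs_info=[{"ch_name": "C3"}, {"ch_name": "C4"}],
)
```

In practice, passing n_times and n_chans explicitly avoids relying on inference at all.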


[Martin2023]

Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor imagery decoding: A framework for comparative analysis with channel attention mechanisms. arXiv preprint arXiv:2310.11198.


[MartinCode]

Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub martinwimpff/channel-attention (accessed 2024-03-28).



x – The description is missing.