braindecode.models.ATCNet#

class braindecode.models.ATCNet(n_chans=None, n_outputs=None, input_window_seconds=None, sfreq=250.0, conv_block_n_filters=16, conv_block_kernel_length_1=64, conv_block_kernel_length_2=16, conv_block_pool_size_1=8, conv_block_pool_size_2=7, conv_block_depth_mult=2, conv_block_dropout=0.3, n_windows=5, att_head_dim=8, att_num_heads=2, att_drop_prob=0.5, tcn_depth=2, tcn_kernel_size=4, tcn_drop_prob=0.3, tcn_activation=<class 'torch.nn.modules.activation.ELU'>, concat=False, max_norm_const=0.25, chs_info=None, n_times=None)[source]#

ATCNet from Altaheri et al. (2022) [1].

[Figure: ATCNet architecture]

Architectural Overview

ATCNet is a convolution-first architecture augmented with a lightweight attention–TCN sequence module. The end-to-end flow is:

  • (i) _ConvBlock learns temporal filter-banks and spatial projections (EEGNet-style), downsampling time to a compact feature map;

  • (ii) sliding windows carve overlapping temporal windows from this map;

  • (iii) for each window, _AttentionBlock applies small multi-head self-attention over time, followed by a _TCNResidualBlock stack (causal, dilated);

  • (iv) window-level features are aggregated (mean of window logits or concatenation) and mapped via a max-norm–constrained linear layer.

Relative to ViT, ATCNet replaces linear patch projection with learned temporal–spatial convolutions, processes parallel window encoders (attention → TCN) instead of a single deep encoder stack, and swaps the MLP head for a TCN suited to 1-D EEG sequences.
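
Before the component-by-component walk-through, a minimal usage sketch may help; the 22-channel, 4-class configuration below mirrors the BCI-IV 2a setting and is purely illustrative:

    import torch
    from braindecode.models import ATCNet

    # 22 channels, 4 classes, 4.5 s at 250 Hz (1125 samples), as in BCI-IV 2a.
    model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)

    x = torch.randn(8, 22, 1125)   # (batch_size, n_chans, n_times)
    logits = model(x)
    print(logits.shape)            # torch.Size([8, 4])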

Macro Components

  • _ConvBlock (Shallow conv stem → feature map)

    • Operations.

    • Temporal conv (torch.nn.Conv2d) with kernel (L_t, 1) builds a FIR-like filter bank (F1 maps).

    • Depthwise spatial conv (torch.nn.Conv2d, groups=F1) with kernel (1, n_chans) learns per-filter spatial projections (akin to EEGNet’s CSP-like step).

    • BN → ELU → AvgPool → Dropout to stabilize and condense activations.

    • Refining temporal conv (torch.nn.Conv2d) with kernel (L_r, 1) + BN → ELU → AvgPool → Dropout.

The output shape is (B, F2, T_c, 1) with F2 = F1·D and T_c = T/(P1·P2). Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific topographies. Pooling acts as a local integrator, reducing variance and imposing a useful inductive bias on short EEG windows.
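
Plugging in the default hyperparameters and a 4.5 s window at 250 Hz, the stem's output dimensions work out as below (a rough check; average pooling floors the division):

    F1, D = 16, 2           # conv_block_n_filters, conv_block_depth_mult
    P1, P2 = 8, 7           # conv_block_pool_size_1, conv_block_pool_size_2
    T = int(4.5 * 250)      # 1125 input samples

    F2 = F1 * D             # 32 feature maps
    T_c = (T // P1) // P2   # 1125 -> 140 -> 20 condensed time steps
    print(F2, T_c)          # 32 20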

  • Sliding-Window Sequencer

    From the condensed time axis (length T_c), ATCNet forms n overlapping windows of width T_w = T_c - n + 1 (one start per index). Each window produces a sequence (B, F2, T_w) forwarded to its own attention–TCN branch. This creates parallel encoders over shifted contexts and is key to robustness on nonstationary EEG.
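
The windowing arithmetic in a few lines (the per-window attention–TCN branches are omitted; this only illustrates how the condensed feature map is sliced):

    import torch

    B, F2, T_c, n = 8, 32, 20, 5              # defaults: n_windows = 5
    feature_map = torch.randn(B, F2, T_c)     # condensed _ConvBlock output

    T_w = T_c - n + 1                         # window width: 16
    windows = [feature_map[..., i:i + T_w] for i in range(n)]
    print(len(windows), windows[0].shape)     # 5 torch.Size([8, 32, 16])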

  • _AttentionBlock (small MHA on temporal positions)

    • Operations.

    • Rearrange to (B, T_w, F2),

    • LayerNorm (torch.nn.LayerNorm),

    • Custom MultiHeadAttention _MHA (num_heads=H, per-head dim d_h) + residual add,

    • Dropout (torch.nn.Dropout),

    • Rearrange back to (B, F2, T_w).

Note: Attention is local to a window and purely temporal.

Role. Re-weights evidence across the window, letting the model emphasize informative segments (onsets, bursts) before causal convolutions aggregate history.
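
A simplified stand-in for the attention block is sketched below. It uses torch.nn.MultiheadAttention instead of the custom _MHA (whose embedding dim H·d_h can differ from F2), and the exact ordering of dropout and the residual add may differ from the actual implementation:

    import torch
    from torch import nn

    class WindowAttentionSketch(nn.Module):
        """Pre-LayerNorm self-attention over temporal positions, residual add, dropout."""

        def __init__(self, f2=32, num_heads=2, drop_prob=0.5):
            super().__init__()
            self.norm = nn.LayerNorm(f2)
            self.mha = nn.MultiheadAttention(f2, num_heads, batch_first=True)
            self.drop = nn.Dropout(drop_prob)

        def forward(self, x):                    # x: (B, F2, T_w)
            seq = x.permute(0, 2, 1)             # -> (B, T_w, F2)
            normed = self.norm(seq)
            attn_out, _ = self.mha(normed, normed, normed)
            seq = seq + self.drop(attn_out)      # residual add
            return seq.permute(0, 2, 1)          # -> (B, F2, T_w)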

  • _TCNResidualBlock (causal dilated temporal CNN)

    • Operations.

    • Two braindecode.modules.CausalConv1d layers per block, with dilation increasing across blocks (1, 2, 4, …),

    • interleaved torch.nn.ELU + torch.nn.BatchNorm1d + torch.nn.Dropout, plus a residual connection (identity or 1×1 mapping).

    • The final feature used per window is the last causal step [..., -1] (forecast-style).

Role. Efficient long-range temporal integration with stable gradients; the dilated receptive field complements attention’s soft selection.
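
The causal-convolution idea (left-padding so that output step t only sees inputs at or before t) can be illustrated as follows; this is a sketch of the principle, not the braindecode.modules.CausalConv1d module itself:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class CausalConv1dSketch(nn.Module):
        """Left-pads by (kernel_size - 1) * dilation so outputs never see the future."""

        def __init__(self, channels, kernel_size=4, dilation=1):
            super().__init__()
            self.pad = (kernel_size - 1) * dilation
            self.conv = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)

        def forward(self, x):                          # x: (B, F2, T_w)
            return self.conv(F.pad(x, (self.pad, 0)))  # pad the past side only

    y = CausalConv1dSketch(32, kernel_size=4, dilation=2)(torch.randn(8, 32, 16))
    print(y.shape)                                     # torch.Size([8, 32, 16])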

  • Aggregation & Classifier

    Two modes: (a) map each window feature to n_outputs logits and average the logits across windows (default, matching the official code), or (b) concatenate all window features (B, n·F2) and apply a single MaxNormLinear. The max-norm constraint regularizes the readout.
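
Roughly, the two aggregation modes look like this (the max-norm constraint of the actual MaxNormLinear readout is omitted, and whether the per-window heads share weights is glossed over):

    import torch
    from torch import nn

    B, F2, n, n_outputs = 8, 32, 5, 4
    window_feats = [torch.randn(B, F2) for _ in range(n)]    # last causal step per window

    # (a) concat=False (default): per-window logits, averaged.
    heads = nn.ModuleList([nn.Linear(F2, n_outputs) for _ in range(n)])
    logits_avg = torch.stack([h(f) for h, f in zip(heads, window_feats)]).mean(dim=0)

    # (b) concat=True: concatenate features, single linear readout.
    head = nn.Linear(n * F2, n_outputs)
    logits_cat = head(torch.cat(window_feats, dim=1))

    print(logits_avg.shape, logits_cat.shape)                # both torch.Size([8, 4])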

Convolutional Details

  • Temporal. Temporal structure is learned in three places:
      1. the stem’s wide (L_t, 1) conv (learned filter bank),

      2. the refining (L_r, 1) conv after pooling (short-term dynamics), and

      3. the TCN’s causal 1-D convolutions with exponentially increasing dilation (long-range dependencies). The minimum sequence length required by the TCN stack is (K_t - 1)·2^{L-1} + 1; the implementation auto-scales kernels/pools/windows when inputs are shorter to preserve feasibility (see the quick check after this list).

  • Spatial. A depthwise spatial conv spans the full montage (kernel (1, n_chans)), producing per-temporal-filter spatial projections (no cross-filter mixing at this step). This mirrors EEGNet’s interpretability: each temporal filter has its own spatial pattern.
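
Plugging the defaults (K_t = 4, L = 2) into the minimum-length expression from the Temporal bullet above:

    K_t, L = 4, 2                          # tcn_kernel_size, tcn_depth
    min_len = (K_t - 1) * 2 ** (L - 1) + 1
    print(min_len)                         # 7; each window (T_w = 16 by default) is long enough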

Attention / Sequential Modules

  • Type. Multi-head self-attention with H heads and per-head dim d_h implemented in _MHA, allowing embed_dim = H·d_h independent of input and output dims.

  • Shapes. (B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w). Attention operates along the temporal axis within a window; channels/features stay in the embedding dim F2.

  • Role. Highlights salient temporal positions prior to causal convolution; small attention keeps compute modest while improving context modeling over pooled features.

Additional Mechanisms

  • Parallel encoders over shifted windows. Improves montage/phase robustness by ensembling nearby contexts rather than committing to a single segmentation.

  • Max-norm classifier. Enforces weight norm constraints at the readout, a common stabilization trick in EEG decoding (a rough sketch follows this list).

  • ViT vs. ATCNet (design choices). Convolutional nonlinear projection rather than linear patchification; attention followed by TCN (not MLP); parallel window encoders rather than stacked encoders.
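
The max-norm readout can be approximated as below; this is a hypothetical stand-in, not braindecode's MaxNormLinear, and it renorms weight rows, which may differ in detail from the library module:

    import torch
    from torch import nn

    class MaxNormLinearSketch(nn.Linear):
        """Linear layer whose weight rows are clipped to an L2 norm <= max_norm."""

        def __init__(self, in_features, out_features, max_norm=0.25):
            super().__init__(in_features, out_features)
            self.max_norm = max_norm

        def forward(self, x):
            with torch.no_grad():
                self.weight.data = torch.renorm(self.weight.data, p=2, dim=0,
                                                maxnorm=self.max_norm)
            return super().forward(x)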

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • input_window_seconds (float, optional) – Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset.

  • sfreq (int, optional) – Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset.

  • conv_block_n_filters (int) – Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1].

  • conv_block_kernel_length_1 (int) – Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1].

  • conv_block_kernel_length_2 (int) – Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1].

  • conv_block_pool_size_1 (int) – Length of first average pooling kernel in the convolutional block. Defaults to 8 as in [1].

  • conv_block_pool_size_2 (int) – Length of the second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1].

  • conv_block_depth_mult (int) – Depth multiplier of depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1].

  • conv_block_dropout (float) – Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1].

  • n_windows (int) – Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1].

  • att_head_dim (int) – Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1].

  • att_num_heads (int) – Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1].

  • att_drop_prob (float) – Dropout probability used in the attention block. Defaults to 0.5.

  • tcn_depth (int) – Depth of Temporal Convolutional Network block (i.e. number of TCN Residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1].

  • tcn_kernel_size (int) – Temporal kernel size used in TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1].

  • tcn_drop_prob (float) – Dropout probability used in the TCN block. Defaults to 0.3.

  • tcn_activation (torch.nn.Module) – Nonlinear activation to use. Defaults to nn.ELU().

  • concat (bool) – When True, concatenates each sliding-window embedding before feeding it to a fully-connected layer, as done in [1]. When False, maps each sliding window to n_outputs logits and averages them. Defaults to False, contrary to what is reported in [1], but matching what the official code does [2].

  • max_norm_const (float) – Maximum L2-norm constraint imposed on weights of the last fully-connected layer. Defaults to 0.25.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_times (int) – Number of time samples of the input window.

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

  • Inputs substantially shorter than the implied minimum length trigger automatic downscaling of kernels, pools, windows, and TCN kernel size to maintain validity.

  • The attention–TCN sequence operates per window; the last causal step is used as the window feature, aligning the temporal semantics across windows.

Added in version 1.1:

  • More detailed documentation of the model.

References

[1]

H. Altaheri, G. Muhammad, M. Alsulaiman (2022). Physics-informed attention temporal convolutional network for EEG-based motor imagery classification. IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.

[2]

Official EEG-ATCNet implementation (TensorFlow): Altaheri/EEG-ATCNet

Methods

forward(X)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

X – The input EEG data, a tensor of shape (batch_size, n_chans, n_times).