braindecode.models.ATCNet#
- class braindecode.models.ATCNet(n_chans=None, n_outputs=None, input_window_seconds=None, sfreq=250.0, conv_block_n_filters=16, conv_block_kernel_length_1=64, conv_block_kernel_length_2=16, conv_block_pool_size_1=8, conv_block_pool_size_2=7, conv_block_depth_mult=2, conv_block_dropout=0.3, n_windows=5, att_head_dim=8, att_num_heads=2, att_drop_prob=0.5, tcn_depth=2, tcn_kernel_size=4, tcn_drop_prob=0.3, tcn_activation=<class 'torch.nn.modules.activation.ELU'>, concat=False, max_norm_const=0.25, chs_info=None, n_times=None)[source]#
ATCNet from Altaheri et al. (2022) [1].
Categories: Convolution, Small Attention
Architectural Overview
ATCNet is a convolution-first architecture augmented with a lightweight attention–TCN sequence module. The end-to-end flow is: (i) _ConvBlock learns temporal filter banks and spatial projections (EEGNet-style), downsampling time to a compact feature map; (ii) sliding windows carve overlapping temporal windows from this map; (iii) for each window, _AttentionBlock applies small multi-head self-attention over time, followed by a _TCNResidualBlock stack (causal, dilated); (iv) window-level features are aggregated (mean of window logits or concatenation) and mapped via a max-norm–constrained linear layer.
Relative to ViT, ATCNet replaces linear patch projection with learned temporal–spatial convolutions; it processes parallel window encoders (attention→TCN) instead of a deep stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.
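For orientation, a minimal usage sketch with BCI-IV 2a-like settings (22 channels, 4 classes, 4.5 s at 250 Hz; the channel and class counts here are illustrative assumptions, not requirements):

import torch
from braindecode.models import ATCNet

# Illustrative BCI-IV 2a-like configuration: 22 channels, 4 classes, 4.5 s at 250 Hz.
model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)

# Batch of 8 EEG windows: (batch_size, n_chans, n_times), with n_times = 4.5 * 250 = 1125.
X = torch.randn(8, 22, 1125)
logits = model(X)
print(logits.shape)  # torch.Size([8, 4])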
Macro Components
_ConvBlock (shallow conv stem → feature map)
Operations.
- Temporal conv (torch.nn.Conv2d) with kernel (L_t, 1) builds a FIR-like filter bank (F1 maps).
- Depthwise spatial conv (torch.nn.Conv2d, groups=F1) with kernel (1, n_chans) learns per-filter spatial projections (akin to EEGNet’s CSP-like step); BN → ELU → AvgPool → Dropout stabilize and condense activations.
- Refining temporal conv (torch.nn.Conv2d) with kernel (L_r, 1) + BN → ELU → AvgPool → Dropout.
The output shape is (B, F2, T_c, 1) with F2 = F1·D and T_c = T/(P1·P2). Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific topographies. Pooling acts as a local integrator, reducing variance and imposing a useful inductive bias on short EEG windows.
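As a back-of-the-envelope check of these shapes under the default hyperparameters (a sketch; the implementation's exact rounding of the pooled length may differ slightly):

# Defaults from the signature above.
F1, D = 16, 2        # conv_block_n_filters, conv_block_depth_mult
P1, P2 = 8, 7        # conv_block_pool_size_1, conv_block_pool_size_2
T = 1125             # n_times for 4.5 s at 250 Hz

F2 = F1 * D          # 32 feature maps after the depthwise step
T_c = T // P1 // P2  # condensed time axis: 1125 // 8 // 7 = 20
print(F2, T_c)       # 32 20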
Sliding-Window Sequencer
From the condensed time axis (length T_c), ATCNet forms n overlapping windows of width T_w = T_c − n + 1 (one start per index). Each window produces a sequence (B, F2, T_w) forwarded to its own attention–TCN branch. This creates parallel encoders over shifted contexts and is key to robustness on nonstationary EEG.
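The windowing itself is plain tensor slicing; a minimal sketch of how the n overlapping windows could be carved from the condensed feature map (illustrative, not the library's exact code):

import torch

B, F2, T_c, n = 8, 32, 20, 5
feature_map = torch.randn(B, F2, T_c)  # condensed output of the conv stem

T_w = T_c - n + 1                      # window width: 16
windows = [feature_map[:, :, i:i + T_w] for i in range(n)]
print(len(windows), windows[0].shape)  # 5 torch.Size([8, 32, 16])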
_AttentionBlock (small MHA on temporal positions)
Operations.
- Rearrange to (B, T_w, F2).
- Normalization with torch.nn.LayerNorm.
- Custom multi-head attention _MHA (num_heads=H, per-head dim d_h) + residual add.
- Dropout (torch.nn.Dropout).
- Rearrange back to (B, F2, T_w).
Note: Attention is local to a window and purely temporal.
Role. Re-weights evidence across the window, letting the model emphasize informative segments (onsets, bursts) before causal convolutions aggregate history.
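A minimal per-window attention sketch in this spirit, using torch.nn.MultiheadAttention as a stand-in for the internal _MHA (so embed_dim equals F2 here, an assumption of the sketch rather than the library's exact module):

import torch
from torch import nn

class WindowAttentionSketch(nn.Module):
    """LayerNorm -> multi-head self-attention -> dropout -> residual, over temporal positions."""

    def __init__(self, f2=32, num_heads=2, drop_prob=0.5):
        super().__init__()
        self.norm = nn.LayerNorm(f2)
        self.mha = nn.MultiheadAttention(embed_dim=f2, num_heads=num_heads, batch_first=True)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, x):                    # x: (B, F2, T_w)
        x = x.permute(0, 2, 1)               # -> (B, T_w, F2): attend over time
        y = self.norm(x)
        out, _ = self.mha(y, y, y)
        x = x + self.drop(out)               # residual add
        return x.permute(0, 2, 1)            # -> (B, F2, T_w)

att = WindowAttentionSketch()
print(att(torch.randn(8, 32, 16)).shape)     # torch.Size([8, 32, 16])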
_TCNResidualBlock (causal dilated temporal CNN)
Operations.
- Two braindecode.modules.CausalConv1d layers per block, with dilation 1, 2, 4, … across blocks, each followed by torch.nn.ELU + torch.nn.BatchNorm1d + torch.nn.Dropout.
- A residual connection (identity or 1×1 mapping).
The final feature used per window is the last causal step [..., -1] (forecast-style).
Role. Efficient long-range temporal integration with stable gradients; the dilated receptive field complements attention’s soft selection.
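Causal convolution can be emulated with a left-padded torch.nn.Conv1d; the standalone sketch below (not braindecode.modules.CausalConv1d itself) shows how dilation grows the receptive field while each output step only sees past samples, and how the last causal step is taken as the window feature:

import torch
from torch import nn

class CausalConv1dSketch(nn.Module):
    """1-D convolution padded on the left so the output at time t only sees inputs <= t."""

    def __init__(self, channels, kernel_size=4, dilation=1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)

    def forward(self, x):                              # x: (B, F2, T_w)
        x = nn.functional.pad(x, (self.left_pad, 0))   # pad the past (left) only
        return self.conv(x)                            # same temporal length as the input

x = torch.randn(8, 32, 16)
stack = nn.Sequential(CausalConv1dSketch(32, dilation=1), CausalConv1dSketch(32, dilation=2))
window_feature = stack(x)[..., -1]                     # last causal step, shape (8, 32)
print(window_feature.shape)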
Aggregation & Classifier
Operations.
Either (a) map each window feature (B, F2) to logits via braindecode.modules.MaxNormLinear and average across windows (default, matching the official code), or (b) concatenate all window features (B, n·F2) and apply a single MaxNormLinear. The max-norm constraint regularizes the readout.
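Both aggregation modes reduce to a few lines; a sketch with a plain torch.nn.Linear standing in for MaxNormLinear (the max-norm constraint is omitted here for brevity):

import torch
from torch import nn

B, F2, n, n_outputs = 8, 32, 5, 4
window_feats = [torch.randn(B, F2) for _ in range(n)]  # one feature vector per window

# (a) default: per-window logits, averaged across windows
head = nn.Linear(F2, n_outputs)
logits_avg = torch.stack([head(f) for f in window_feats]).mean(dim=0)  # (B, n_outputs)

# (b) concat=True: concatenate window features, single readout
head_concat = nn.Linear(n * F2, n_outputs)
logits_concat = head_concat(torch.cat(window_feats, dim=1))            # (B, n_outputs)

print(logits_avg.shape, logits_concat.shape)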
Convolutional Details
- Temporal. Temporal structure is learned in three places: (1) the stem's wide (L_t, 1) conv (learned filter bank), (2) the refining (L_r, 1) conv after pooling (short-term dynamics), and (3) the TCN's causal 1-D convolutions with exponentially increasing dilation (long-range dependencies). The minimum sequence length required by the TCN stack is (K_t − 1)·2^{L−1} + 1; the implementation auto-scales kernels/pools/windows when inputs are shorter to preserve feasibility (see the sketch after this list).
- Spatial. A depthwise spatial conv spans the full montage (kernel (1, n_chans)), producing per-temporal-filter spatial projections (no cross-filter mixing at this step). This mirrors EEGNet’s interpretability: each temporal filter has its own spatial pattern.
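A sketch of the depthwise spatial step and of the minimum-length formula above (shapes and hyperparameters follow the defaults on this page; the auto-scaling logic itself is internal to the implementation):

import torch
from torch import nn

B, F1, D, T, n_chans = 8, 16, 2, 1125, 22

# Depthwise spatial conv: one (1, n_chans) kernel group per temporal filter, no cross-filter mixing.
spatial = nn.Conv2d(F1, F1 * D, kernel_size=(1, n_chans), groups=F1, bias=False)
x = torch.randn(B, F1, T, n_chans)   # output of the temporal filter bank
print(spatial(x).shape)              # torch.Size([8, 32, 1125, 1])

# Minimum sequence length required by the TCN stack, per the formula above.
K_t, L = 4, 2                        # tcn_kernel_size, tcn_depth
print((K_t - 1) * 2 ** (L - 1) + 1)  # 7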
Attention / Sequential Modules
Type. Multi-head self-attention with H heads and per-head dim d_h, implemented in _MHA, allowing embed_dim = H·d_h independent of the input and output dims.
Shapes. (B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w). Attention operates along the temporal axis within a window; channels/features stay in the embedding dim F2.
Role. Highlights salient temporal positions prior to causal convolution; small attention keeps compute modest while improving context modeling over pooled features.
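To make the dimension bookkeeping concrete, a hypothetical projection sketch showing how embed_dim = H·d_h can be chosen independently of F2 (the real _MHA internals may differ):

import torch
from torch import nn

B, F2, T_w = 8, 32, 16
H, d_h = 2, 8                          # att_num_heads, att_head_dim -> embed_dim = 16

to_qkv = nn.Linear(F2, 3 * H * d_h)    # project F2 features to queries/keys/values of size H*d_h
to_out = nn.Linear(H * d_h, F2)        # project the attention output back to F2

x = torch.randn(B, T_w, F2)
q, k, v = (t.reshape(B, T_w, H, d_h).transpose(1, 2) for t in to_qkv(x).chunk(3, dim=-1))
attn = torch.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1)   # (B, H, T_w, T_w)
out = (attn @ v).transpose(1, 2).reshape(B, T_w, H * d_h)            # (B, T_w, H*d_h)
print(to_out(out).shape)                                             # torch.Size([8, 16, 32])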
Additional Mechanisms
Parallel encoders over shifted windows. Improves montage/phase robustness by ensembling nearby contexts rather than committing to a single segmentation.
Max-norm classifier. Enforces weight norm constraints at the readout, a common stabilization trick in EEG decoding (see the sketch after this list).
ViT vs. ATCNet (design choices). Convolutional nonlinear projection rather than linear patchification; attention followed by TCN (not MLP); parallel window encoders rather than stacked encoders.
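A common way to realize such a constraint is to renorm the readout weights before each forward pass; the class below is a hypothetical stand-in for braindecode.modules.MaxNormLinear (the library module may implement it differently):

import torch
from torch import nn

class MaxNormLinearSketch(nn.Linear):
    """Linear layer whose per-output weight vectors are clipped to a maximum L2 norm before each call."""

    def __init__(self, in_features, out_features, max_norm=0.25):
        super().__init__(in_features, out_features)
        self.max_norm = max_norm

    def forward(self, x):
        with torch.no_grad():
            # Renorm each output unit's weight vector to at most max_norm (L2).
            self.weight.data = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        return super().forward(x)

head = MaxNormLinearSketch(32, 4, max_norm=0.25)
print(head(torch.randn(8, 32)).shape)                  # torch.Size([8, 4])
print(bool(head.weight.norm(dim=1).max() <= 0.2501))   # True after the renorm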
- Parameters:
n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
input_window_seconds (float, optional) – Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset.
sfreq (int, optional) – Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset.
conv_block_n_filters (int) – Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1].
conv_block_kernel_length_1 (int) – Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1].
conv_block_kernel_length_2 (int) – Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1].
conv_block_pool_size_1 (int) – Length of first average pooling kernel in the convolutional block. Defaults to 8 as in [1].
conv_block_pool_size_2 (int) – Length of the second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1].
conv_block_depth_mult (int) – Depth multiplier of depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1].
conv_block_dropout (float) – Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1].
n_windows (int) – Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1].
att_head_dim (int) – Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1].
att_num_heads (int) – Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1].
att_drop_prob (float) – Dropout probability used in the attention block. Defaults to 0.5.
tcn_depth (int) – Depth of Temporal Convolutional Network block (i.e. number of TCN Residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1].
tcn_kernel_size (int) – Temporal kernel size used in TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1].
tcn_drop_prob (float) – Dropout probability used in the TCN residual blocks. Defaults to 0.3.
tcn_activation (torch.nn.Module) – Nonlinear activation to use. Defaults to nn.ELU().
concat (bool) – When True, concatenates each sliding window embedding before feeding it to a fully-connected layer, as done in [1]. When False, maps each sliding window to n_outputs logits and averages them. Defaults to False, contrary to what is reported in [1], but matching what the official code does [2].
max_norm_const (float) – Maximum L2-norm constraint imposed on weights of the last fully-connected layer. Defaults to 0.25.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
Inputs substantially shorter than the implied minimum length trigger automatic downscaling of kernels, pools, windows, and TCN kernel size to maintain validity.
The attention–TCN sequence operates per window; the last causal step is used as the window feature, aligning the temporal semantics across windows.
Added in version 1.1:
More detailed documentation of the model.
References
[1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022). Physics-informed attention temporal convolutional network for EEG-based motor imagery classification. IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
[2] Official EEG-ATCNet implementation (TensorFlow): Altaheri/EEG-ATCNet
Methods
- forward(X)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
X (torch.Tensor) – Batch of EEG windows of shape (batch_size, n_chans, n_times).