braindecode.models.ATCNet#
- class braindecode.models.ATCNet(n_chans=None, n_outputs=None, input_window_seconds=None, sfreq=250.0, conv_block_n_filters=16, conv_block_kernel_length_1=64, conv_block_kernel_length_2=16, conv_block_pool_size_1=8, conv_block_pool_size_2=7, conv_block_depth_mult=2, conv_block_dropout=0.3, n_windows=5, att_head_dim=8, att_num_heads=2, att_drop_prob=0.5, tcn_depth=2, tcn_kernel_size=4, tcn_drop_prob=0.3, tcn_activation=<class 'torch.nn.modules.activation.ELU'>, concat=False, max_norm_const=0.25, chs_info=None, n_times=None)[source]#
 ATCNet from Altaheri et al. (2022) [1].
Architectural Overview
ATCNet is a convolution-first architecture augmented with a lightweight attention–TCN sequence module. The end-to-end flow is:
(i) _ConvBlock learns temporal filter banks and spatial projections (EEGNet-style), downsampling time to a compact feature map;
(ii) Sliding Windows carve overlapping temporal windows from this map;
(iii) for each window, _AttentionBlock applies small multi-head self-attention over time, followed by a _TCNResidualBlock stack (causal, dilated);
(iv) window-level features are aggregated (mean of window logits or concatenation) and mapped via a max-norm–constrained linear layer.
Relative to ViT, ATCNet replaces linear patch projection with learned temporal–spatial convolutions; it processes parallel window encoders (attention→TCN) instead of a deep stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.
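As orientation, here is a hedged shape walk-through under the default configuration (F1 = 16, D = 2, P1 = 8, P2 = 7, n = 5) and a hypothetical 22-channel, 1125-sample input; the names and intermediate shapes are illustrative, not the model's internals:

```python
# Illustrative shape walk-through (not the actual implementation).
B, C, T = 8, 22, 1125          # batch, EEG channels, time samples (4.5 s at 250 Hz)
F1, D, P1, P2, n = 16, 2, 8, 7, 5

F2 = F1 * D                    # 32 feature maps after the depthwise step
T_c = T // (P1 * P2)           # 1125 // 56 = 20 condensed time steps
T_w = T_c - n + 1              # 16 samples per sliding window

# _ConvBlock:        (B, n_chans, T)      -> (B, F2, T_c, 1)
# sliding windows:   n shifted views of      (B, F2, T_w)
# attention + TCN:   each window          -> (B, F2)  (last causal step)
# classifier:        mean of n per-window logit vectors -> (B, n_outputs)
```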
Macro Components
_ConvBlock (shallow conv stem → feature map)
Operations.
- Temporal conv (torch.nn.Conv2d) with kernel (L_t, 1) builds a FIR-like filter bank (F1 maps).
- Depthwise spatial conv (torch.nn.Conv2d, groups=F1) with kernel (1, n_chans) learns per-filter spatial projections (akin to EEGNet’s CSP-like step).
- BN → ELU → AvgPool → Dropout to stabilize and condense activations.
- Refining temporal conv (torch.nn.Conv2d) with kernel (L_r, 1) + BN → ELU → AvgPool → Dropout.
The output shape is (B, F2, T_c, 1) with F2 = F1·D and T_c = T/(P1·P2). Temporal kernels behave as FIR filters; the depthwise spatial conv yields frequency-specific topographies. Pooling acts as a local integrator, reducing variance and imposing a useful inductive bias on short EEG windows.
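A simplified PyTorch sketch of a stem of this kind; the class, argument names, padding, and exact layer ordering are illustrative approximations, not braindecode's _ConvBlock:

```python
import torch
from torch import nn

class ConvStemSketch(nn.Module):
    """EEGNet-style stem: temporal filter bank -> depthwise spatial conv -> refine."""

    def __init__(self, n_chans=22, F1=16, D=2, L_t=64, L_r=16, P1=8, P2=7, drop=0.3):
        super().__init__()
        F2 = F1 * D
        # Temporal conv: FIR-like filter bank over the time axis.
        self.temporal = nn.Conv2d(1, F1, (L_t, 1), padding=(L_t // 2, 0), bias=False)
        # Depthwise spatial conv: one spatial pattern per temporal filter (groups=F1).
        self.spatial = nn.Conv2d(F1, F2, (1, n_chans), groups=F1, bias=False)
        self.bn1, self.bn2 = nn.BatchNorm2d(F2), nn.BatchNorm2d(F2)
        self.pool1, self.pool2 = nn.AvgPool2d((P1, 1)), nn.AvgPool2d((P2, 1))
        # Refining temporal conv on the condensed feature map.
        self.refine = nn.Conv2d(F2, F2, (L_r, 1), padding=(L_r // 2, 0), bias=False)
        self.act, self.drop = nn.ELU(), nn.Dropout(drop)

    def forward(self, x):                        # x: (B, n_chans, T)
        x = x.permute(0, 2, 1).unsqueeze(1)      # (B, 1, T, n_chans)
        x = self.spatial(self.temporal(x))       # (B, F2, ~T, 1)
        x = self.drop(self.pool1(self.act(self.bn1(x))))
        x = self.refine(x)
        x = self.drop(self.pool2(self.act(self.bn2(x))))
        return x                                 # (B, F2, T_c, 1)


stem = ConvStemSketch()
print(stem(torch.randn(8, 22, 1125)).shape)      # torch.Size([8, 32, 20, 1])
```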
Sliding-Window Sequencer
From the condensed time axis (length T_c), ATCNet forms n overlapping windows of width T_w = T_c - n + 1 (one start per index). Each window produces a sequence (B, F2, T_w) forwarded to its own attention–TCN branch. This creates parallel encoders over shifted contexts and is key to robustness on nonstationary EEG.
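A minimal sketch of the windowing arithmetic, assuming a condensed feature map of the default size; the tensors and names are illustrative:

```python
import torch

# Hypothetical condensed feature map from the conv stem: (B, F2, T_c)
feats = torch.randn(8, 32, 20)
n_windows = 5
T_c = feats.shape[-1]
T_w = T_c - n_windows + 1          # 16: one window start per index

# Each window is a shifted view of the same feature map.
windows = [feats[..., i:i + T_w] for i in range(n_windows)]
print([tuple(w.shape) for w in windows])   # n_windows entries of (8, 32, 16)
```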
_AttentionBlock (small MHA on temporal positions)
Attention here is local to a window and purely temporal.
Operations.
- Rearrange to (B, T_w, F2),
- Normalization with torch.nn.LayerNorm,
- Custom multi-head attention _MHA (num_heads=H, per-head dim d_h) + residual add,
- Dropout (torch.nn.Dropout),
- Rearrange back to (B, F2, T_w).
Role. Re-weights evidence across the window, letting the model emphasize informative segments (onsets, bursts) before causal convolutions aggregate history.
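A hedged sketch of the idea using torch.nn.MultiheadAttention in place of the custom _MHA (the real module also decouples the per-head dimension d_h from F2); names are illustrative:

```python
import torch
from torch import nn

class TemporalAttentionSketch(nn.Module):
    """LayerNorm -> multi-head self-attention over time -> dropout -> residual."""

    def __init__(self, F2=32, num_heads=2, drop=0.5):
        super().__init__()
        self.norm = nn.LayerNorm(F2)
        self.mha = nn.MultiheadAttention(F2, num_heads, batch_first=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):                       # x: (B, F2, T_w)
        x = x.permute(0, 2, 1)                  # (B, T_w, F2): attend over time
        y = self.norm(x)
        y, _ = self.mha(y, y, y, need_weights=False)
        x = x + self.drop(y)                    # residual add
        return x.permute(0, 2, 1)               # back to (B, F2, T_w)


print(TemporalAttentionSketch()(torch.randn(8, 32, 16)).shape)  # torch.Size([8, 32, 16])
```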
_TCNResidualBlock (causal dilated temporal CNN)
Operations.
- Two braindecode.modules.CausalConv1d layers per block (each followed by torch.nn.ELU + torch.nn.BatchNorm1d + torch.nn.Dropout), with dilation 1, 2, 4, … across blocks, plus a residual connection (identity or 1×1 mapping).
- The final feature used per window is the last causal step [..., -1] (forecast-style).
Role. Efficient long-range temporal integration with stable gradients; the dilated receptive field complements attention’s soft selection.
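A minimal sketch of a causal dilated residual block of this kind, emulating braindecode.modules.CausalConv1d with left-padded torch.nn.Conv1d; the class name and exact layer ordering are illustrative:

```python
import torch
from torch import nn
import torch.nn.functional as F

class CausalTCNBlockSketch(nn.Module):
    """Two causal dilated 1-D convs (+ ELU/BN/Dropout) with a residual connection."""

    def __init__(self, channels=32, kernel_size=4, dilation=1, drop=0.3):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation          # left padding => causality
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        self.bn1, self.bn2 = nn.BatchNorm1d(channels), nn.BatchNorm1d(channels)
        self.act, self.drop = nn.ELU(), nn.Dropout(drop)

    def forward(self, x):                                # x: (B, F2, T_w)
        y = F.pad(x, (self.pad, 0))                      # pad only on the left
        y = self.drop(self.bn1(self.act(self.conv1(y))))
        y = F.pad(y, (self.pad, 0))
        y = self.drop(self.bn2(self.act(self.conv2(y))))
        return self.act(y + x)                           # identity residual


# Stacking blocks with dilation 1, 2, 4, ... and reading the last causal step:
x = torch.randn(8, 32, 16)
blocks = nn.Sequential(*[CausalTCNBlockSketch(dilation=2 ** i) for i in range(2)])
window_feature = blocks(x)[..., -1]                      # (8, 32)
```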
Aggregation & Classifier
Operations.
- Either (a) map each window feature (B, F2) to logits via braindecode.modules.MaxNormLinear and average across windows (default, matching the official code), or
- (b) concatenate all window features (B, n·F2) and apply a single MaxNormLinear.
The max-norm constraint regularizes the readout.
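A sketch of the two aggregation modes, using plain torch.nn.Linear in place of MaxNormLinear; whether per-window heads share weights is a detail of the real implementation not shown here:

```python
import torch
from torch import nn

n_windows, F2, n_outputs = 5, 32, 4
window_feats = [torch.randn(8, F2) for _ in range(n_windows)]   # hypothetical features

# (a) Default (concat=False): one readout per window, average the logits.
heads = nn.ModuleList([nn.Linear(F2, n_outputs) for _ in range(n_windows)])
logits_avg = torch.stack([h(f) for h, f in zip(heads, window_feats)]).mean(dim=0)

# (b) concat=True: concatenate features, single readout.
head = nn.Linear(n_windows * F2, n_outputs)
logits_cat = head(torch.cat(window_feats, dim=1))

print(logits_avg.shape, logits_cat.shape)   # torch.Size([8, 4]) torch.Size([8, 4])
```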
Convolutional Details
- Temporal. Temporal structure is learned in three places: (1) the stem’s wide (L_t, 1) conv (learned filter bank), (2) the refining (L_r, 1) conv after pooling (short-term dynamics), and (3) the TCN’s causal 1-D convolutions with exponentially increasing dilation (long-range dependencies). The minimum sequence length required by the TCN stack is (K_t - 1)·2^{L-1} + 1 (a small sketch follows this list); the implementation auto-scales kernels/pools/windows when inputs are shorter to preserve feasibility.
- Spatial. A depthwise spatial conv spans the full montage (kernel (1, n_chans)), producing per-temporal-filter spatial projections (no cross-filter mixing at this step). This mirrors EEGNet’s interpretability: each temporal filter has its own spatial pattern.
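As referenced above, a small sketch of the minimum-length computation (the function name is illustrative):

```python
def tcn_min_length(kernel_size: int, depth: int) -> int:
    """Minimum condensed sequence length the TCN stack can consume,
    following the formula above: (K_t - 1) * 2**(L - 1) + 1."""
    return (kernel_size - 1) * 2 ** (depth - 1) + 1

print(tcn_min_length(4, 2))   # 7 with the defaults K_t = 4, L = 2
```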
Attention / Sequential Modules
Type. Multi-head self-attention with H heads and per-head dim d_h, implemented in _MHA, allowing embed_dim = H·d_h independent of the input and output dims.
Shapes. (B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w). Attention operates along the temporal axis within a window; channels/features stay in the embedding dim F2.
Role. Highlights salient temporal positions prior to causal convolution; small attention keeps compute modest while improving context modeling over pooled features.
Additional Mechanisms
Parallel encoders over shifted windows. Improves montage/phase robustness by ensembling nearby contexts rather than committing to a single segmentation.
Max-norm classifier. Enforces weight-norm constraints at the readout, a common stabilization trick in EEG decoding (a minimal sketch follows this list).
ViT vs. ATCNet (design choices). Convolutional nonlinear projection rather than linear patchification; attention followed by TCN (not MLP); parallel window encoders rather than stacked encoders.
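A minimal sketch of a max-norm-constrained readout in the spirit of braindecode.modules.MaxNormLinear; the class here is illustrative, not the library implementation:

```python
import torch
from torch import nn

class MaxNormLinearSketch(nn.Linear):
    """Linear layer that clamps each output unit's weight vector to a max L2 norm."""

    def __init__(self, in_features, out_features, max_norm=0.25):
        super().__init__(in_features, out_features)
        self.max_norm = max_norm

    def forward(self, x):
        with torch.no_grad():
            # Rescale any row (output unit) whose L2 norm exceeds max_norm.
            self.weight.data = torch.renorm(
                self.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )
        return super().forward(x)


head = MaxNormLinearSketch(32, 4, max_norm=0.25)
print(head(torch.randn(8, 32)).shape)     # torch.Size([8, 4])
```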
Usage and Configuration
- conv_block_n_filters (F1) and conv_block_depth_mult (D) → capacity of the stem (with F2 = F1·D feeding attention/TCN), dimensions aligned to F2, as in EEGNet.
- Pool sizes P1, P2 trade temporal resolution for stability/compute; they set T_c = T/(P1·P2) and thus the window width T_w.
- n_windows controls the ensemble over shifts (compute ∝ windows).
- att_num_heads, att_head_dim set attention capacity; keep H·d_h ≈ F2.
- tcn_depth, tcn_kernel_size govern the receptive field; larger values demand longer inputs (see minimum length above). The implementation warns and rescales kernels/pools/windows if inputs are too short.
- Aggregation choice. concat=False (default, average of per-window logits) matches the official code; concat=True mirrors the paper’s concatenation variant.
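A minimal usage sketch with BCI-IV 2a-like settings (22 channels, 4 classes, 4.5 s at 250 Hz); the batch size and input tensor are illustrative:

```python
import torch
from braindecode.models import ATCNet

model = ATCNet(
    n_chans=22,
    n_outputs=4,
    input_window_seconds=4.5,   # 1125 samples at sfreq=250
    sfreq=250,
)

x = torch.randn(8, 22, 1125)    # (batch, n_chans, n_times)
logits = model(x)
print(logits.shape)             # expected: torch.Size([8, 4])
```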
- Parameters:
 n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
input_window_seconds (float, optional) – Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset.
sfreq (int, optional) – Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset.
conv_block_n_filters (int) – Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1].
conv_block_kernel_length_1 (int) – Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1].
conv_block_kernel_length_2 (int) – Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1].
conv_block_pool_size_1 (int) – Length of first average pooling kernel in the convolutional block. Defaults to 8 as in [1].
conv_block_pool_size_2 (int) – Length of the second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1].
conv_block_depth_mult (int) – Depth multiplier of depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1].
conv_block_dropout (float) – Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1].
n_windows (int) – Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1].
att_head_dim (int) – Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1].
att_num_heads (int) – Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1].
att_drop_prob (float) – Dropout probability used in the attention block. Defaults to 0.5.
tcn_depth (int) – Depth of Temporal Convolutional Network block (i.e. number of TCN Residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1].
tcn_kernel_size (int) – Temporal kernel size used in TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1].
tcn_drop_prob (float) – Dropout probability used in the TCN block. Defaults to 0.3.
tcn_activation (torch.nn.Module) – Nonlinear activation to use. Defaults to nn.ELU().
concat (bool) – When True, concatenates each sliding-window embedding before feeding it to a fully-connected layer, as done in [1]. When False, maps each sliding window to n_outputs logits and averages them. Defaults to False, contrary to what is reported in [1], but matching what the official code does [2].
max_norm_const (float) – Maximum L2-norm constraint imposed on weights of the last fully-connected layer. Defaults to 0.25.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
- Raises:
ValueError – If some input signal-related parameters are not specified and can not be inferred.
Notes
Inputs substantially shorter than the implied minimum length trigger automatic downscaling of kernels, pools, windows, and TCN kernel size to maintain validity.
The attention–TCN sequence operates per window; the last causal step is used as the window feature, aligning the temporal semantics across windows.
Added in version 1.1:
More detailed documentation of the model.
References
[1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022). Physics-informed attention temporal convolutional network for EEG-based motor imagery classification. IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
[2] Official EEG-ATCNet implementation (TensorFlow): Altaheri/EEG-ATCNet
Methods
- forward(X)[source]#
 Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.