braindecode.models.ATCNet#
- class braindecode.models.ATCNet(n_chans=None, n_outputs=None, input_window_seconds=None, sfreq=250, conv_block_n_filters=16, conv_block_kernel_length_1=64, conv_block_kernel_length_2=16, conv_block_pool_size_1=8, conv_block_pool_size_2=7, conv_block_depth_mult=2, conv_block_dropout=0.3, n_windows=5, att_head_dim=8, att_num_heads=2, att_drop_prob=0.5, tcn_depth=2, tcn_kernel_size=4, tcn_n_filters=32, tcn_drop_prob=0.3, tcn_activation=torch.nn.ELU, concat=False, max_norm_const=0.25, chs_info=None, n_times=None)[source]#
ATCNet model from Altaheri et al. (2022) [1].
PyTorch implementation based on the official TensorFlow code [2].
- Parameters:
n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
input_window_seconds (float, optional) – Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset.
sfreq (int, optional) – Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset.
conv_block_n_filters (int) – Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1].
conv_block_kernel_length_1 (int) – Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1].
conv_block_kernel_length_2 (int) – Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1].
conv_block_pool_size_1 (int) – Length of first average pooling kernel in the convolutional block. Defaults to 8 as in [1].
conv_block_pool_size_2 (int) – Length of the second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1].
conv_block_depth_mult (int) – Depth multiplier of depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1].
conv_block_dropout (float) – Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1].
n_windows (int) – Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1].
att_head_dim (int) – Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1].
att_num_heads (int) – Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1].
att_drop_prob (float) – Dropout probability used in the attention block. Defaults to 0.5 as in [1].
tcn_depth (int) – Depth of Temporal Convolutional Network block (i.e. number of TCN Residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1].
tcn_kernel_size (int) – Temporal kernel size used in TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1].
tcn_n_filters (int) – Number of filters used in TCN convolutional layers (Ft). Defaults to 32 as in [1].
tcn_drop_prob (float) – Dropout probability used in the TCN block. Defaults to 0.3 as in [1].
tcn_activation (torch.nn.Module) – Nonlinear activation to use. Defaults to nn.ELU().
concat (bool) – When True, concatenates each sliding window embedding before feeding it to a fully-connected layer, as done in [1]. When False, maps each sliding window to n_outputs logits and averages them. Defaults to False, contrary to what is reported in [1], but matching what the official code does [2].
max_norm_const (float) – Maximum L2-norm constraint imposed on the weights of the last fully-connected layer. Defaults to 0.25.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
- Raises:
ValueError – If some input signal-related parameters are not specified and can not be inferred.
FutureWarning – If add_log_softmax is True, since the LogSoftmax final layer will be removed in the future.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
References
[1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed attention temporal convolutional network for EEG-based motor imagery classification," IEEE Transactions on Industrial Informatics, 2022, doi: 10.1109/TII.2022.3197419.
[2] Official ATCNet implementation: https://github.com/Altaheri/EEG-ATCNet
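Examples
A minimal doctest-style sketch of instantiating the model. The channel and class counts below are illustrative, chosen to match the BCI-IV 2a setup mentioned above; they are not defaults of this class:
>>> from braindecode.models import ATCNet
>>> model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)
With these values, n_times is not passed explicitly and is inferred as 4.5 s * 250 Hz = 1125 samples, as described in the Notes.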
Methods
- forward(X)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
X (torch.Tensor) – Input EEG window of shape (batch_size, n_chans, n_times).
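A minimal sketch of a forward pass on a random batch, reusing the illustrative configuration from the Examples section above (22 channels, 4 outputs, 1125 inferred time samples):
>>> import torch
>>> from braindecode.models import ATCNet
>>> model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)
>>> X = torch.randn(8, 22, 1125)  # (batch_size, n_chans, n_times)
>>> model(X).shape
torch.Size([8, 4])
Note that the model is called as model(X) rather than model.forward(X), so that registered hooks are run.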