braindecode.models.ATCNet
- class braindecode.models.ATCNet(n_chans=None, n_outputs=None, input_window_seconds=4.5, sfreq=250.0, conv_block_n_filters=16, conv_block_kernel_length_1=64, conv_block_kernel_length_2=16, conv_block_pool_size_1=8, conv_block_pool_size_2=7, conv_block_depth_mult=2, conv_block_dropout=0.3, n_windows=5, att_head_dim=8, att_num_heads=2, att_dropout=0.5, tcn_depth=2, tcn_kernel_size=4, tcn_n_filters=32, tcn_dropout=0.3, tcn_activation=ELU(alpha=1.0), concat=False, max_norm_const=0.25, chs_info=None, n_times=None, n_channels=None, n_classes=None, input_size_s=None, add_log_softmax=True)
ATCNet model from [1].
PyTorch implementation based on the official TensorFlow code [2].
- Parameters:
n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
input_window_seconds (float, optional) – Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset.
sfreq (int, optional) – Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset.
conv_block_n_filters (int) – Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1].
conv_block_kernel_length_1 (int) – Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1].
conv_block_kernel_length_2 (int) – Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1].
conv_block_pool_size_1 (int) – Length of first average pooling kernel in the convolutional block. Defaults to 8 as in [1].
conv_block_pool_size_2 (int) – Length of second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1].
conv_block_depth_mult (int) – Depth multiplier of depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1].
conv_block_dropout (float) – Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1].
n_windows (int) – Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1].
att_head_dim (int) – Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1].
att_num_heads (int) – Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1].
att_dropout (float) – Dropout probability used in the attention block, denoted pa in table 1 of the paper [1]. Defaults to 0.5 as in [1].
tcn_depth (int) – Depth of Temporal Convolutional Network block (i.e. number of TCN Residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1].
tcn_kernel_size (int) – Temporal kernel size used in TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1].
tcn_n_filters (int) – Number of filters used in TCN convolutional layers (Ft). Defaults to 32 as in [1].
tcn_dropout (float) – Dropout probability used in the TCN block, denoted pt in table 1 of the paper [1]. Defaults to 0.3 as in [1].
tcn_activation (torch.nn.Module) – Nonlinear activation to use. Defaults to nn.ELU().
concat (bool) – When True, concatenates each sliding window embedding before feeding it to a fully-connected layer, as done in [1]. When False, maps each sliding window to n_outputs logits and averages them. Defaults to False, contrary to what is reported in [1] but matching what the official code does [2].
max_norm_const (float) – Maximum L2-norm constraint imposed on weights of the last fully-connected layer. Defaults to 0.25.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
n_channels – Alias for n_chans.
n_classes – Alias for n_outputs.
input_size_s – Alias for input_window_seconds.
add_log_softmax (bool) – Whether to use log-softmax non-linearity as the output function. LogSoftmax final layer will be removed in the future. Please adjust your loss function accordingly (e.g. CrossEntropyLoss)! Check the documentation of the torch.nn loss functions: https://pytorch.org/docs/stable/nn.html#loss-functions.
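The add_log_softmax note above determines which loss to pair with the model. The following is a minimal sketch using plain PyTorch tensors as stand-ins for model outputs (the random logits and labels are illustrative assumptions, not ATCNet outputs). It shows that NLLLoss on log-softmax outputs and CrossEntropyLoss on raw logits give the same value, so switching to add_log_softmax=False should only require swapping the loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for raw model outputs (logits): 8 trials, 4 classes.
# Random illustrative values, not produced by ATCNet.
logits = torch.randn(8, 4)
targets = torch.randint(0, 4, (8,))

# add_log_softmax=True: the model already returns log-probabilities,
# so the matching loss is the negative log-likelihood.
log_probs = F.log_softmax(logits, dim=1)
loss_nll = F.nll_loss(log_probs, targets)

# add_log_softmax=False: the model returns logits, and cross-entropy
# applies log-softmax internally.
loss_ce = F.cross_entropy(logits, targets)

print(torch.allclose(loss_nll, loss_ce))
```

Applying CrossEntropyLoss to outputs that have already passed through log-softmax would apply the normalization twice, which is the mismatch the note warns about.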
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
FutureWarning – If add_log_softmax is True, since the LogSoftmax final layer will be removed in the future.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
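As a rough illustration of the inference described above, the sketch below assumes the usual relation n_times = input_window_seconds × sfreq. The helper names are hypothetical and not part of braindecode, and the library's exact rounding behaviour may differ.

```python
def infer_n_times(input_window_seconds: float, sfreq: float) -> int:
    """Hypothetical helper: number of samples in a window of the given duration."""
    return int(input_window_seconds * sfreq)


def infer_input_window_seconds(n_times: int, sfreq: float) -> float:
    """Hypothetical helper: window duration in seconds for a given sample count."""
    return n_times / sfreq


# Defaults from the signature: 4.5 s windows sampled at 250 Hz.
n_times = infer_n_times(4.5, 250.0)
window = infer_input_window_seconds(n_times, 250.0)
print(n_times, window)
```

With the default values, any two of n_times, input_window_seconds and sfreq are enough to recover the third under this assumption.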
References
[1] H. Altaheri, G. Muhammad and M. Alsulaiman, “Physics-informed attention temporal convolutional network for EEG-based motor imagery classification,” in IEEE Transactions on Industrial Informatics, 2022, doi: 10.1109/TII.2022.3197419.
Methods
- forward(X)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
X – The description is missing.
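The note about hooks can be illustrated with a small sketch. It uses nn.Linear as a stand-in module (an assumption made to keep the example light; the same torch.nn.Module behaviour should apply to ATCNet) and a hypothetical hook named record_output_shape.

```python
import torch
from torch import nn

# Small stand-in module; chosen only to keep the example lightweight.
model = nn.Linear(3, 2)

calls = []


def record_output_shape(module, inputs, output):
    # Forward hook: runs after forward() when the module instance is called.
    calls.append(tuple(output.shape))


handle = model.register_forward_hook(record_output_shape)

x = torch.randn(5, 3)

model.forward(x)  # direct call: registered hooks are skipped
n_after_forward = len(calls)

model(x)  # instance call: registered hooks run
n_after_call = len(calls)

handle.remove()
print(n_after_forward, n_after_call)
```

Only the instance call triggers the hook, which is why model(X) is preferred over model.forward(X).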