braindecode.models.MSVTNet

class braindecode.models.MSVTNet(n_chans: int | None = None, n_outputs: int | None = None, n_times: int | None = None, input_window_seconds: float | None = None, sfreq: float | None = None, chs_info: List[Dict] | None = None, n_filters_list: Tuple[int, ...] = (9, 9, 9, 9), conv1_kernels_size: Tuple[int, ...] = (15, 31, 63, 125), conv2_kernel_size: int = 15, depth_multiplier: int = 2, pool1_size: int = 8, pool2_size: int = 7, drop_prob: float = 0.3, num_heads: int = 8, feedforward_ratio: float = 1, drop_prob_trans: float = 0.5, num_layers: int = 2, activation: Type[torch.nn.Module] = torch.nn.ELU, return_features: bool = False)

MSVTNet model from Liu, K., et al. (2024) [msvt2024].

This model implements a multi-scale convolutional transformer network for EEG signal classification, as described in [msvt2024].

Figure: MSVTNet architecture.
Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_times (int) – Number of time samples of the input window.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_filters_list (Tuple[int, ...], optional) – Number of filters for each TSConv block, one entry per block, by default (9, 9, 9, 9).

  • conv1_kernels_size (Tuple[int, ...], optional) – Kernel sizes for the first convolution in each TSConv block, one entry per block, by default (15, 31, 63, 125).

  • conv2_kernel_size (int, optional) – Kernel size for the second convolution in TSConv blocks, by default 15.

  • depth_multiplier (int, optional) – Depth multiplier for depthwise convolution, by default 2.

  • pool1_size (int, optional) – Pooling size for the first pooling layer in TSConv blocks, by default 8.

  • pool2_size (int, optional) – Pooling size for the second pooling layer in TSConv blocks, by default 7.

  • drop_prob (float, optional) – Dropout probability for convolutional layers, by default 0.3.

  • num_heads (int, optional) – Number of attention heads in the transformer encoder, by default 8.

  • feedforward_ratio (float, optional) – Ratio of the transformer feedforward dimension to the model (embedding) dimension, by default 1.

  • drop_prob_trans (float, optional) – Dropout probability for the transformer, by default 0.5.

  • num_layers (int, optional) – Number of transformer encoder layers, by default 2.

  • activation (Type[nn.Module], optional) – Activation function class to use, by default nn.ELU.

  • return_features (bool, optional) – Whether to also return the predictions from the branch classifiers alongside the final output, by default False.
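To make the interplay of the sizing parameters concrete, the sketch below estimates the shape of the token sequence handed to the transformer. It assumes 'same'-padded convolutions, non-overlapping pooling with floor division, and concatenation of all branch outputs along the feature axis. These are assumptions about the implementation rather than documented behaviour, and `token_shape` is a hypothetical helper that is not part of braindecode.

```python
def token_shape(
    n_times,
    n_filters_list=(9, 9, 9, 9),
    depth_multiplier=2,
    pool1_size=8,
    pool2_size=7,
):
    """Estimate (seq_len, d_model) of the transformer input (hypothetical helper)."""
    # Assumption: each pooling layer shrinks the time axis by floor division.
    seq_len = n_times // pool1_size // pool2_size
    # Assumption: every branch yields n_filters * depth_multiplier features,
    # and the branches are concatenated along the feature axis.
    d_model = sum(n_filters_list) * depth_multiplier
    return seq_len, d_model


seq_len, d_model = token_shape(n_times=1000)
print(seq_len, d_model)  # 17 72

# Multi-head attention in PyTorch requires the embedding size to be
# divisible by the number of heads.
num_heads = 8
assert d_model % num_heads == 0

# Assumption: the feedforward width is d_model scaled by feedforward_ratio.
feedforward_ratio = 1
dim_feedforward = int(d_model * feedforward_ratio)
print(dim_feedforward)  # 72
```

Under these assumptions, changing n_filters_list or depth_multiplier changes the embedding size, so num_heads may need adjusting to keep the division exact.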

Raises:
  • ValueError – If some input signal-related parameters are not specified and cannot be inferred.

  • FutureWarning – If add_log_softmax is True, since the LogSoftmax final layer will be removed in the future.
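The ValueError case involves signal-related parameters such as n_times, input_window_seconds and sfreq, where any two determine the third. A minimal sketch of that relationship follows. `resolve_n_times` is a hypothetical helper written for illustration and is not braindecode's actual code.

```python
def resolve_n_times(n_times=None, input_window_seconds=None, sfreq=None):
    """Return the window length in samples, inferring it when possible."""
    if n_times is not None:
        return n_times
    if input_window_seconds is not None and sfreq is not None:
        # samples = seconds * samples per second
        return int(input_window_seconds * sfreq)
    raise ValueError(
        "n_times is not specified and cannot be inferred: "
        "pass n_times, or both input_window_seconds and sfreq."
    )


print(resolve_n_times(input_window_seconds=4.0, sfreq=250.0))  # 1000

try:
    resolve_n_times(sfreq=250.0)  # window length unknown
except ValueError as err:
    print(f"ValueError: {err}")
```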

Notes

This implementation is not guaranteed to be correct. It has not been checked by the original authors and was reimplemented from the original code [msvt2024code].

References

[msvt2024]

Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding. IEEE Journal of Biomedical and Health Informatics.

[msvt2024code]

Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding. Source Code: SheepTAO/MSVTNet

Methods

forward(x: Tensor) → Tensor | Tuple[Tensor, List[Tensor]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor) – Input EEG tensor. braindecode models conventionally expect shape (batch_size, n_chans, n_times).