braindecode.models.MSVTNet#

class braindecode.models.MSVTNet(n_chans=None, n_outputs=None, n_times=None, input_window_seconds=None, sfreq=None, chs_info=None, n_filters_list=(9, 9, 9, 9), conv1_kernels_size=(15, 31, 63, 125), conv2_kernel_size=15, depth_multiplier=2, pool1_size=8, pool2_size=7, drop_prob=0.3, num_heads=8, ffn_expansion_factor=1, att_drop_prob=0.5, num_layers=2, activation=<class 'torch.nn.modules.activation.ELU'>, return_features=False)[source]#

MSVTNet model from Liu et al. (2024) [msvt2024].

Convolution + Attention/Transformer

This model implements a multi-scale convolutional transformer network for EEG signal classification, as described in [msvt2024].

MSVTNet Architecture
Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_times (int) – Number of time samples of the input window.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_filters_list (tuple[int, ...]) – List of filter numbers for each TSConv block, by default (9, 9, 9, 9).

  • conv1_kernels_size (tuple[int, ...]) – List of kernel sizes for the first convolution in each TSConv block, by default (15, 31, 63, 125).

  • conv2_kernel_size (int) – Kernel size for the second convolution in TSConv blocks, by default 15.

  • depth_multiplier (int) – Depth multiplier for depthwise convolution, by default 2.

  • pool1_size (int) – Pooling size for the first pooling layer in TSConv blocks, by default 8.

  • pool2_size (int) – Pooling size for the second pooling layer in TSConv blocks, by default 7.

  • drop_prob (float) – Dropout probability for convolutional layers, by default 0.3.

  • num_heads (int) – Number of attention heads in the transformer encoder, by default 8.

  • ffn_expansion_factor (float) – Ratio to compute feedforward dimension in the transformer, by default 1.

  • att_drop_prob (float) – Dropout probability for the transformer, by default 0.5.

  • num_layers (int) – Number of transformer encoder layers, by default 2.

  • activation (Type[Module]) – Activation function class to use, by default nn.ELU.

  • return_features (bool) – Whether to return predictions from branch classifiers, by default False.

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

This implementation is not guaranteed to be correct and has not been checked by the original authors; it was reimplemented based on the original code [msvt2024code].
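The core multi-scale idea behind `n_filters_list` and `conv1_kernels_size` is a set of parallel temporal convolution branches with different kernel sizes whose outputs are concatenated along the channel dimension. The following is a minimal torch-only sketch of that idea, not braindecode's actual TSConv implementation (which also includes depthwise convolution, pooling, and dropout); the class name is hypothetical.

```python
import torch
import torch.nn as nn


class MultiScaleTemporalConv(nn.Module):
    """Hypothetical sketch: parallel temporal convolutions at several
    kernel sizes, concatenated along the filter dimension."""

    def __init__(self, n_filters_list=(9, 9, 9, 9),
                 kernel_sizes=(15, 31, 63, 125)):
        super().__init__()
        # One branch per (n_filters, kernel_size) pair; "same" padding
        # keeps the time dimension unchanged for odd kernel sizes.
        self.branches = nn.ModuleList(
            nn.Conv2d(1, n_filters, kernel_size=(1, k), padding=(0, k // 2))
            for n_filters, k in zip(n_filters_list, kernel_sizes)
        )

    def forward(self, x):
        # x: (batch, 1, n_chans, n_times)
        return torch.cat([branch(x) for branch in self.branches], dim=1)


x = torch.randn(2, 1, 22, 1000)
out = MultiScaleTemporalConv()(x)
print(out.shape)  # torch.Size([2, 36, 22, 1000]): 9 + 9 + 9 + 9 filters
```

Each branch sees the same input but at a different temporal receptive field (here 15 to 125 samples), which is what lets the network capture EEG dynamics at multiple time scales.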

References

[msvt2024]

Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding. IEEE Journal of Biomedical and Health Informatics.

[msvt2024code]

Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding. Source Code: https://github.com/SheepTAO/MSVTNet

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]

Pushing a model to the Hub:

from braindecode.models import MSVTNet

# Train your model
model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-msvtnet-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import MSVTNet

# Load pretrained model
model = MSVTNet.from_pretrained("username/my-msvtnet-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = MSVTNet.from_pretrained("username/my-msvtnet-model", n_outputs=4)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = MSVTNet.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor) – Input EEG tensor of shape (batch_size, n_chans, n_times).

Return type:

Tensor