braindecode.models.EEGInceptionMI

class braindecode.models.EEGInceptionMI(n_chans=None, n_outputs=None, input_window_seconds=None, sfreq=250, n_convs=5, n_filters=48, kernel_unit_s=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, chs_info=None, n_times=None)

EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1].

Convolution

EEGInceptionMI Architecture

The model is strongly based on the original InceptionNet for computer vision. Its main goal is to extract features in parallel at several temporal scales. The network is made of two blocks, each consisting of three inception modules wrapped by a skip (residual) connection.

The model is fully described in [1].
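To make the multi-scale idea concrete, here is a toy NumPy sketch; it is not the braindecode implementation. The hypothetical `inception_module` first mixes channels with a random 1x1 "bottleneck" (an assumption borrowed from the general Inception design), then filters the result in parallel with kernels of several lengths and concatenates the branches. The hypothetical `residual_block` stacks three such modules and adds a projected skip connection. Random weights stand in for learned ones, each branch applies one shared kernel to every row (a simplification of a full convolution), only 8 filters are used to keep it small, and normalization, pooling and the classification head are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)


def inception_module(x, n_filters, kernel_lengths):
    """Toy inception module on x of shape (channels, n_times).

    Assumed structure: 1x1 bottleneck, then parallel temporal filters of
    different lengths, concatenated along the channel axis.
    """
    # Random 1x1 "bottleneck": mixes input channels down to n_filters rows
    w_bottleneck = rng.standard_normal((n_filters, x.shape[0])) / x.shape[0]
    z = w_bottleneck @ x  # (n_filters, n_times)
    branches = []
    for k in kernel_lengths:
        kernel = rng.standard_normal(k) / k  # stands in for learned weights
        # mode="same" keeps n_times unchanged so the branches can be stacked
        branches.append(np.stack([np.convolve(row, kernel, mode="same") for row in z]))
    return np.maximum(np.concatenate(branches, axis=0), 0.0)  # ReLU


def residual_block(x, n_modules, n_filters, kernel_lengths):
    """Stack n_modules toy inception modules and add a 1x1-projected skip."""
    out = x
    for _ in range(n_modules):
        out = inception_module(out, n_filters, kernel_lengths)
    # Project the block input to the output channel count before adding it
    w_skip = rng.standard_normal((out.shape[0], x.shape[0])) / x.shape[0]
    return np.maximum(out + w_skip @ x, 0.0)


x = rng.standard_normal((22, 1000))  # 22 channels, 4 s at 250 Hz
kernels = [25, 75, 125, 175, 225]    # odd multiples of 25 samples (0.1 s)
h = residual_block(x, 3, 8, kernels)  # first block of three modules
h = residual_block(h, 3, 8, kernels)  # second block of three modules
print(h.shape)  # (40, 1000)
```

Because every branch keeps the time axis intact, features extracted at 0.1 s and 0.9 s scales line up sample by sample and can simply be concatenated; the output has one group of `n_filters` rows per kernel length.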

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • input_window_seconds (float, optional) – Size of the input window, in seconds. Set to 4.5 s in [1] for the BCI IV 2a dataset.

  • sfreq (float, optional) – EEG sampling frequency in Hz. Defaults to 250 Hz, as in [1] for the BCI IV 2a dataset.

  • n_convs (int) – Number of parallel convolutions in each inception module. Defaults to 5, as in [1] for the BCI IV 2a dataset.

  • n_filters (int) – Number of filters of each convolutional layer in the inception modules. Defaults to 48, as in [1] for the BCI IV 2a dataset.

  • kernel_unit_s (float) – Size in seconds of the basic 1D convolutional kernel used in the inception modules. The convolutional layers in each module have kernels of increasing size, set to odd multiples of this value (e.g. 0.1, 0.3, 0.5, 0.7 and 0.9 s for n_convs=5). Defaults to 0.1 s.

  • activation (type[Module]) – Activation function class. Defaults to torch.nn.ReLU.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_times (int) – Number of time samples of the input window.
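As a worked example of kernel_unit_s, the short sketch below converts the odd multiples of the basic kernel size into kernel lengths in samples. The helper `kernel_lengths_in_samples` is hypothetical, and rounding to the nearest whole sample is our assumption; the library may convert seconds to samples differently.

```python
def kernel_lengths_in_samples(kernel_unit_s, n_convs, sfreq):
    """Odd multiples (1, 3, 5, ...) of kernel_unit_s, converted to samples.

    Assumes the basic kernel is rounded to the nearest whole sample.
    """
    unit = round(kernel_unit_s * sfreq)  # basic kernel length in samples
    return [(2 * i + 1) * unit for i in range(n_convs)]


# Defaults of EEGInceptionMI: kernel_unit_s=0.1, n_convs=5, sfreq=250
print(kernel_lengths_in_samples(0.1, 5, 250))  # [25, 75, 125, 175, 225]
```

At 250 Hz the shortest kernel covers 25 samples (0.1 s) and the longest 225 samples (0.9 s), so raising n_convs quickly produces kernels that span a large part of a 4.5 s window.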

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.
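The inference behind this error can be pictured with the relation n_times = input_window_seconds × sfreq: any two of the three values determine the third. The sketch below uses a hypothetical helper, `infer_signal_params`, that is not braindecode's code; it only illustrates when inference is possible and when a ValueError is the natural outcome. Since sfreq defaults to 250 Hz in EEGInceptionMI, supplying either n_times or input_window_seconds alongside it is enough in this sketch.

```python
def infer_signal_params(n_times=None, sfreq=None, input_window_seconds=None):
    """Hypothetical helper: derive the missing value from the other two.

    Uses n_times = input_window_seconds * sfreq, rounded to whole samples.
    """
    if n_times is None and sfreq is not None and input_window_seconds is not None:
        n_times = round(input_window_seconds * sfreq)
    elif input_window_seconds is None and n_times is not None and sfreq is not None:
        input_window_seconds = n_times / sfreq
    elif sfreq is None and n_times is not None and input_window_seconds is not None:
        sfreq = n_times / input_window_seconds
    if None in (n_times, sfreq, input_window_seconds):
        # Fewer than two values were given, so nothing can be inferred
        raise ValueError(
            "Specify at least two of n_times, sfreq and input_window_seconds."
        )
    return n_times, sfreq, input_window_seconds


print(infer_signal_params(sfreq=250, input_window_seconds=4.5))  # (1125, 250, 4.5)
print(infer_signal_params(n_times=1000, sfreq=250))  # (1000, 250, 4.0)
```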

Notes

This implementation is not guaranteed to be correct: it has not been checked by the original authors and was reimplemented based only on the paper [1].

References

[1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021). EEG-inception: an accurate and robust end-to-end neural network for EEG-based motor imagery classification. Journal of Neural Engineering, 18(4), 046014.

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]

Pushing a model to the Hub:

from braindecode.models import EEGInceptionMI

# Train your model
model = EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-eeginceptionmi-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import EEGInceptionMI

# Load pretrained model
model = EEGInceptionMI.from_pretrained("username/my-eeginceptionmi-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = EEGInceptionMI.from_pretrained("username/my-eeginceptionmi-model", n_outputs=2)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = EEGInceptionMI.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

forward(X)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

X (Tensor) – Batch of EEG windows of shape (batch_size, n_chans, n_times).

Return type:

Tensor