braindecode.models.ContraWR

class braindecode.models.ContraWR(n_chans=None, n_outputs=None, sfreq=None, emb_size=256, res_channels=[32, 64, 128], steps=20, activation=<class 'torch.nn.modules.activation.ELU'>, drop_prob=0.5, stride_res=2, kernel_size_res=3, padding_res=1, chs_info=None, n_times=None, input_window_seconds=None)

Contrast with the World Representation (ContraWR) model from Yang et al. (2021) [Yang2021].

Convolution

This model is a convolutional neural network that operates on a spectral (time-frequency) representation of the EEG signal and processes it with a series of convolutional layers and residual blocks. It is designed to learn EEG representations that can be used for sleep staging.
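The spectral front end can be illustrated with a short-time Fourier transform (STFT). The sketch below is a minimal, dependency-free illustration and not braindecode's implementation: it assumes a periodic Hann window, a one-second analysis window (n_fft equal to the sampling frequency) and a hop of n_fft // steps. The last two choices are our reading of the sfreq and steps parameters, not something this page states. In this reading, each EEG channel yields one frequency-by-time map, and the maps are stacked channel-wise.

```python
import cmath
import math


def hann_window(n: int) -> list[float]:
    """Periodic Hann window of length n."""
    return [0.5 - 0.5 * math.cos(2 * math.pi * i / n) for i in range(n)]


def stft_magnitude(signal: list[float], n_fft: int, hop: int) -> list[list[float]]:
    """One-sided magnitude STFT of a 1-D signal, indexed as spec[freq_bin][frame].

    Frames are taken without edge padding, so
    n_frames = 1 + (len(signal) - n_fft) // hop.
    """
    window = hann_window(n_fft)
    n_freqs = n_fft // 2 + 1  # bins 0 .. n_fft // 2 (one-sided spectrum)
    n_frames = 1 + (len(signal) - n_fft) // hop
    spec = [[0.0] * n_frames for _ in range(n_freqs)]
    for t in range(n_frames):
        start = t * hop
        frame = [signal[start + i] * window[i] for i in range(n_fft)]
        for k in range(n_freqs):
            # Naive DFT coefficient for frequency bin k (O(n_fft) per bin).
            coeff = sum(
                frame[i] * cmath.exp(-2j * math.pi * k * i / n_fft)
                for i in range(n_fft)
            )
            spec[k][t] = abs(coeff)
    return spec


def multichannel_spectrogram(
    x: list[list[float]], sfreq: float, steps: int = 20
) -> list[list[list[float]]]:
    """Stack one spectrogram per channel: result[channel][freq_bin][frame].

    Assumptions (not taken from braindecode): n_fft = int(sfreq), i.e. a
    one-second window, and hop = n_fft // steps.
    """
    n_fft = int(sfreq)
    hop = n_fft // steps
    return [stft_magnitude(channel, n_fft, hop) for channel in x]


# Usage: a 2 s, 10 Hz sine sampled at 100 Hz peaks in bin 10 of every frame,
# because with n_fft = sfreq each frequency bin is 1 Hz wide.
sfreq = 100
sine = [math.sin(2 * math.pi * 10 * n / sfreq) for n in range(2 * sfreq)]
spec = multichannel_spectrogram([sine], sfreq, steps=20)[0]
peak_bins = {max(range(len(spec)), key=lambda k: spec[k][t]) for t in range(len(spec[0]))}
print(peak_bins)  # {10}
```

The naive DFT is used only for readability; a practical implementation would rely on an FFT-based routine such as torch.stft.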

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • emb_size (int) – Embedding size for the final layer, by default 256.

  • res_channels (list[int]) – Number of channels for each residual block, by default [32, 64, 128].

  • steps (int, optional) – Number of steps used to derive the hop_length parameter of the frequency decomposition, by default 20.

  • activation (type[Module]) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.

  • drop_prob (float) – The dropout rate for regularization. Values should be between 0 and 1, by default 0.5.

  • stride_res (int) – Stride of the convolutions in the residual blocks, by default 2.

  • kernel_size_res (int) – Kernel size of the convolutions in the residual blocks, by default 3.

  • padding_res (int) – Padding of the convolutions in the residual blocks, by default 1.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • n_times (int) – Number of time samples of the input window.

  • input_window_seconds (float) – Length of the input window in seconds.

Added in version 0.9.
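To see how sfreq, n_times, steps and the residual-block parameters interact, the following sketch traces feature-map sizes. It rests on assumptions, not on this page: n_fft = int(sfreq), hop = n_fft // steps, a centered STFT (padding n_fft // 2 samples per side, as torch.stft does with center=True), and residual blocks that map n_chans through res_channels to emb_size. Only one strided convolution per block is counted; any pooling or extra layers inside the real blocks are ignored, so these sizes are illustrative rather than the model's exact internal shapes.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def spectrogram_shape(sfreq: float, n_times: int, steps: int) -> tuple[int, int]:
    """(n_freqs, n_frames) of a centered, one-sided STFT.

    Assumptions: n_fft = int(sfreq) and hop = n_fft // steps. Centering pads
    n_fft // 2 samples on each side of the signal.
    """
    n_fft = int(sfreq)
    hop = n_fft // steps
    padded = n_times + 2 * (n_fft // 2)
    return n_fft // 2 + 1, 1 + (padded - n_fft) // hop


def trace_shapes(
    n_chans: int,
    sfreq: float,
    n_times: int,
    steps: int = 20,
    res_channels: tuple[int, ...] = (32, 64, 128),
    emb_size: int = 256,
    kernel_size_res: int = 3,
    stride_res: int = 2,
    padding_res: int = 1,
) -> list[tuple[int, int, int]]:
    """(channels, freq, time) after the STFT and after each strided convolution.

    Assumption: blocks map n_chans -> *res_channels -> emb_size; pooling ignored.
    """
    freqs, frames = spectrogram_shape(sfreq, n_times, steps)
    channels = [n_chans, *res_channels, emb_size]
    shapes = [(channels[0], freqs, frames)]
    for c_out in channels[1:]:
        freqs = conv_out(freqs, kernel_size_res, stride_res, padding_res)
        frames = conv_out(frames, kernel_size_res, stride_res, padding_res)
        shapes.append((c_out, freqs, frames))
    return shapes


# 30 s of two-channel EEG at 100 Hz with the default hyperparameters.
# Prints (2, 51, 601), (32, 26, 301), (64, 13, 151), (128, 7, 76), (256, 4, 38),
# one tuple per line.
for shape in trace_shapes(n_chans=2, sfreq=100, n_times=3000):
    print(shape)
```

Under these assumptions, a larger steps value shortens the hop and adds time frames, while stride_res roughly halves both axes per block.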

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.
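The Raises entry can be made concrete with a small sketch. infer_signal_params is a hypothetical helper, not part of braindecode's API: it only illustrates how the signal-related parameters relate to each other (the channel count is the length of chs_info, and n_times equals sfreq times input_window_seconds) and when inference fails. Which parameters ContraWR strictly requires, and how braindecode performs this inference internally, are not shown here.

```python
from typing import Any, Optional


def infer_signal_params(
    n_chans: Optional[int] = None,
    sfreq: Optional[float] = None,
    n_times: Optional[int] = None,
    input_window_seconds: Optional[float] = None,
    chs_info: Optional[list[dict[str, Any]]] = None,
) -> dict[str, Any]:
    """Fill in missing signal parameters from the ones that were given.

    Hypothetical helper for illustration only; it mirrors the relations
    n_chans = len(chs_info) and n_times = sfreq * input_window_seconds.
    """
    if n_chans is None and chs_info is not None:
        n_chans = len(chs_info)
    if sfreq is None and n_times is not None and input_window_seconds:
        sfreq = n_times / input_window_seconds
    if n_times is None and sfreq is not None and input_window_seconds is not None:
        n_times = int(round(sfreq * input_window_seconds))
    if input_window_seconds is None and n_times is not None and sfreq:
        input_window_seconds = n_times / sfreq

    # Anything still unknown cannot be inferred from the given arguments.
    missing = [
        name
        for name, value in (("n_chans", n_chans), ("sfreq", sfreq), ("n_times", n_times))
        if value is None
    ]
    if missing:
        raise ValueError(
            f"Cannot infer {', '.join(missing)}: pass them explicitly or give "
            "enough related parameters (e.g. sfreq and input_window_seconds)."
        )
    return {
        "n_chans": n_chans,
        "sfreq": sfreq,
        "n_times": n_times,
        "input_window_seconds": input_window_seconds,
    }


params = infer_signal_params(
    sfreq=100, input_window_seconds=30, chs_info=[{"ch_name": "Fpz-Cz"}, {"ch_name": "Pz-Oz"}]
)
print(params)  # {'n_chans': 2, 'sfreq': 100, 'n_times': 3000, 'input_window_seconds': 30}
```

Passing only n_chans and n_times, for example, leaves sfreq undetermined and raises ValueError.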

Notes

This implementation is not guaranteed to be correct and has not been checked by the original authors. It is adapted from the original code [Code2023] with minimal modifications, and the model is expected to work as intended.

References

[Yang2021]

Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023). Self-supervised electroencephalogram representation learning for automatic sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.

[Code2023]

Yang, C., Westover, M. B., & Sun, J. (2023). BIOT: Biosignal Transformer for Cross-data Learning in the Wild. GitHub. https://github.com/ycq091044/BIOT (accessed 2024-02-13)

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install "braindecode[hub]"

Pushing a model to the Hub:

from braindecode.models import ContraWR

# Train your model
# sfreq must be given: it cannot be inferred from n_times alone
model = ContraWR(n_chans=22, n_outputs=4, sfreq=250, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-contrawr-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import ContraWR

# Load pretrained model
model = ContraWR.from_pretrained("username/my-contrawr-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = ContraWR.from_pretrained("username/my-contrawr-model", n_outputs=2)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = ContraWR.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

forward(X)

Forward pass.

Parameters:

X (Tensor) – Input tensor of shape (batch_size, n_channels, n_times).

Returns:

Output tensor of shape (batch_size, n_outputs).

Return type:

Tensor