braindecode.models.ContraWR#
- class braindecode.models.ContraWR(n_chans=None, n_outputs=None, sfreq=None, emb_size=256, res_channels=[32, 64, 128], steps=20, activation=<class 'torch.nn.modules.activation.ELU'>, drop_prob=0.5, stride_res=2, kernel_size_res=3, padding_res=1, chs_info=None, n_times=None, input_window_seconds=None)[source]#
Contrast with the World Representation ContraWR from Yang et al (2021) [Yang2021].
This model is a convolutional neural network that uses a spectral representation with a series of convolutional layers and residual blocks. The model is designed to learn a representation of the EEG signal that can be used for sleep staging.
- Parameters:
n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
sfreq (float) – Sampling frequency of the EEG recordings.
emb_size (int) – Embedding size for the final layer, by default 256.
res_channels (list[int]) – Number of channels for each residual block, by default [32, 64, 128].
steps (int, optional) – Number of steps used to set the hop_length of the frequency decomposition, by default 20.
activation (type[Module]) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.
drop_prob (float) – The dropout rate for regularization. Values should be between 0 and 1.
Added in version 0.9.
stride_res (int) – Stride of the convolutions in the residual blocks, by default 2.
kernel_size_res (int) – Kernel size of the convolutions in the residual blocks, by default 3.
padding_res (int) – Padding of the convolutions in the residual blocks, by default 1.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
- Raises:
ValueError – If some input signal-related parameters are not specified and can not be inferred.
Notes
This implementation is not guaranteed to be correct and has not been checked by the original authors; it is adapted from the original code [Code2023]. The modifications are minimal and the model is expected to work as intended.
References
[Yang2021]Yang, C., Xiao, C., Westover, M. B., & Sun, J. (2023). Self-supervised electroencephalogram representation learning for automatic sleep staging: model development and evaluation study. JMIR AI, 2(1), e46769.
[Code2023] Yang, C., Westover, M. B. and Sun, J., 2023. BIOT: Biosignal Transformer for Cross-data Learning in the Wild. GitHub. https://github.com/ycq091044/BIOT (accessed 2024-02-13).
Hugging Face Hub integration
When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

    pip install braindecode[hub]
Pushing a model to the Hub:
    from braindecode.models import ContraWR

    # Train your model
    model = ContraWR(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-contrawr-model",
        commit_message="Initial model upload",
    )
Loading a model from the Hub:
    from braindecode.models import ContraWR

    # Load pretrained model
    model = ContraWR.from_pretrained("username/my-contrawr-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = ContraWR.from_pretrained("username/my-contrawr-model", n_outputs=4)
Extracting features and replacing the head:
    import torch

    x = torch.randn(1, model.n_chans, model.n_times)

    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)
Saving and restoring full configuration:
    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = ContraWR.from_config(config)  # reconstruct (no weights)
All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
See Loading and Adapting Pretrained Foundation Models for a complete tutorial.
Methods