braindecode.models.BIOT#
- class braindecode.models.BIOT(embed_dim=256, num_heads=8, num_layers=4, sfreq=200, hop_length=100, return_feature=False, n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, activation=<class 'torch.nn.modules.activation.ELU'>, drop_prob=0.5, max_seq_len=1024, att_drop_prob=0.2, att_layer_drop_prob=0.2)[source]#
BIOT from Yang et al. (2023) [Yang2023].
Foundation Model
BIOT: Cross-data Biosignal Learning in the Wild.
BIOT is a foundation model for biosignal classification. It is a wrapper around the BIOTEncoder and ClassificationHead modules.
It is designed for N-dimensional biosignal data such as EEG, ECG, etc. The method was proposed by Yang et al. [Yang2023] and the code is available at [Code2023].
The model is pre-trained with a contrastive loss on large EEG datasets: the TUH Abnormal EEG Corpus (about 400K samples) and the Sleep Heart Health Study (about 5M samples). This class provides only the model architecture and not the contrastive pre-training procedure; pre-trained weights can be loaded from the Hugging Face Hub (see below).
The architecture is based on the LinearAttentionTransformer and PatchFrequencyEmbedding modules. The BIOTEncoder is a transformer that takes the input data and outputs a fixed-size representation of the input data. More details are present in the BIOTEncoder class.
The ClassificationHead is an ELU activation layer, followed by a simple linear layer that takes the output of the BIOTEncoder and outputs the classification probabilities.
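The two-stage design described above can be sketched in plain PyTorch. This is an illustrative stand-in, not braindecode's implementation: the BIOTEncoder output is simulated with a random fixed-size embedding, and the head mirrors the described ELU-plus-linear structure (the embedding size 256 matches the default embed_dim):

```python
import torch
from torch import nn

emb_size, n_outputs = 256, 4  # default embed_dim=256; 4-class example

# The ClassificationHead described above: ELU activation, then a linear layer
head = nn.Sequential(nn.ELU(), nn.Linear(emb_size, n_outputs))

# Stand-in for the BIOTEncoder output: one fixed-size embedding per input
emb = torch.randn(2, emb_size)  # (batch_size, emb_size)
logits = head(emb)              # (batch_size, n_outputs)
```

In the real model the embedding comes from the transformer encoder; only the head's shape contract is shown here.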
Important
Pre-trained Weights Available
This model has pre-trained weights available on the Hugging Face Hub. You can load them using:
from braindecode.models import BIOT

# Load the original pre-trained model from Hugging Face Hub
# For 16-channel models:
model = BIOT.from_pretrained("braindecode/biot-pretrained-prest-16chs")

# For 18-channel models:
model = BIOT.from_pretrained("braindecode/biot-pretrained-shhs-prest-18chs")
model = BIOT.from_pretrained("braindecode/biot-pretrained-six-datasets-18chs")
To push your own trained model to the Hub:
# After training your model
model.push_to_hub(
    repo_id="username/my-biot-model",
    commit_message="Upload trained BIOT model",
)
Requires installing braindecode[hug] for Hub integration.
Added in version 0.9.
- Parameters:
embed_dim (int, optional) – The size of the embedding layer, by default 256
num_heads (int, optional) – The number of attention heads, by default 8
num_layers (int, optional) – The number of transformer layers, by default 4
sfreq (int, optional) – Sampling frequency of the input data, passed to the encoder. Default is 200.
hop_length (int, optional) – The hop length for the torch.stft transformation in the encoder. Default is 100.
return_feature (bool, optional) – If True, the forward pass also returns the embedding in addition to the output tensor. Default is False.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
n_chans (int) – Number of EEG channels.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
activation (type[Module]) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.
drop_prob (float) – Dropout probability. Default is 0.5.
max_seq_len (int) – Maximum sequence length accepted by the transformer. Default is 1024.
att_drop_prob (float) – Dropout probability applied to the attention weights. Default is 0.2.
att_layer_drop_prob (float) – Layer-drop probability for the attention layers. Default is 0.2.
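As a back-of-the-envelope check on how sfreq and hop_length shape the encoder input: with a centered STFT (torch.stft's default, center=True), an n_times-sample channel yields n_times // hop_length + 1 spectral frames, each becoming one token. Tying the FFT size to sfreq is an assumption here for illustration:

```python
# Number of spectral tokens the encoder's STFT produces per channel,
# assuming a centered STFT (torch.stft's default, center=True).
sfreq, hop_length = 200, 100   # defaults; assumed here: n_fft tied to sfreq
n_times = 2000                 # e.g. a 10 s window at 200 Hz

n_frames = n_times // hop_length + 1
print(n_frames)  # 21 tokens per channel enter the transformer
```

With 18 channels this gives 18 * 21 tokens, which is what max_seq_len bounds.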
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
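For example, the signal parameters are linked by n_times = input_window_seconds * sfreq, so any one of the three can be filled in from the other two. This is a sketch of the inference rule only; the exact logic lives in braindecode's shared model mixin:

```python
sfreq = 200                  # Hz
input_window_seconds = 10.0  # s

# n_times can be inferred when it is not given explicitly
n_times = int(input_window_seconds * sfreq)
print(n_times)  # 2000
```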
References
[Yang2023] (1,2) Yang, C., Westover, M.B. and Sun, J., 2023. BIOT: Biosignal Transformer for Cross-data Learning in the Wild. In Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS).
[Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT: Biosignal Transformer for Cross-data Learning in the Wild. GitHub repository: https://github.com/ycq091044/BIOT (accessed 2024-02-13).
Hugging Face Hub integration
When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]
Pushing a model to the Hub:
from braindecode.models import BIOT

# Train your model
model = BIOT(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-biot-model",
    commit_message="Initial model upload",
)
Loading a model from the Hub:
from braindecode.models import BIOT

# Load pretrained model
model = BIOT.from_pretrained("username/my-biot-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = BIOT.from_pretrained("username/my-biot-model", n_outputs=4)
Extracting features and replacing the head:
import torch

x = torch.randn(1, model.n_chans, model.n_times)

# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)
Saving and restoring full configuration:
import json

config = model.get_config()  # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = BIOT.from_config(config)  # reconstruct (no weights)
All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
See Loading and Adapting Pretrained Foundation Models for a complete tutorial.
Methods
- forward(x, return_features=False)[source]#
Pass the input through the BIOT encoder, and then through the classification head.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, n_channels, n_times).
return_features (bool) – If True, return a dict with "features" and "cls_token" instead of the classification output.
- Returns:
Default: torch.Tensor of shape (batch_size, n_outputs). If return_features=True: a dict with "features" of shape (batch_size, emb_size) and "cls_token" (None). If the legacy init parameter return_feature=True: an (out, emb) tuple (ignored when return_features=True).
- Return type:
torch.Tensor or tuple or dict
- reset_head(n_outputs)[source]#
Replace the classification head for a new number of outputs.
This is called automatically by from_pretrained() when the user passes an n_outputs that differs from the saved config. Override in subclasses that need a model-specific head structure.
- Parameters:
n_outputs (int) – New number of output classes.
Examples
>>> from braindecode.models import BENDR
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
>>> model.reset_head(10)
>>> model.n_outputs
10
Added in version 1.4.