braindecode.models.Labram
- class braindecode.models.Labram(n_times=None, n_outputs=None, chs_info=None, n_chans=None, sfreq=None, input_window_seconds=None, patch_size=200, learned_patcher=False, embed_dim=200, conv_in_channels=1, conv_out_channels=8, num_layers=12, num_heads=10, mlp_ratio=4.0, qkv_bias=False, qk_norm=<class 'torch.nn.modules.normalization.LayerNorm'>, qk_scale=None, drop_prob=0.0, attn_drop_prob=0.0, drop_path_prob=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_values=0.1, use_abs_pos_emb=True, use_mean_pooling=False, init_scale=0.001, neural_tokenizer=True, attn_head_dim=None, activation=<class 'torch.nn.modules.activation.GELU'>, on_unknown_chs='warn')
Labram from Jiang, W B et al (2024) [Jiang2024].
Convolution Foundation Model
Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI from [Jiang2024].
This is an adaptation of the code [Code2024] from the Labram model.
The model is a transformer architecture strongly inspired by BEiTv2 [BeiTv2].
The model can be used in two modes:
Neural Tokenizer: Designed to produce embeddings (e.g. for classification).
Neural Decoder: Designed to extract the amplitude and phase outputs with a VQSNP.
The braindecode modification allows the model to be used with an input shape of (batch, n_chans, n_times) when neural_tokenizer is True. The original implementation uses (batch, n_chans, n_patches, patch_size) as input, with static segmentation of the input data.
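As a rough illustration of the static segmentation mentioned above, the sketch below splits each channel's samples into non-overlapping patches using plain Python lists. The helper name `segment_into_patches` is hypothetical and not part of braindecode, and the sketch assumes that `n_times` is an exact multiple of `patch_size`, which the text does not state.

```python
def segment_into_patches(signal, patch_size=200):
    """Split a 1-D sequence of samples into non-overlapping patches.

    Hypothetical helper for illustration only; not braindecode's API.
    Assumes len(signal) is an exact multiple of patch_size.
    """
    if len(signal) % patch_size != 0:
        raise ValueError("signal length must be a multiple of patch_size")
    return [signal[i:i + patch_size] for i in range(0, len(signal), patch_size)]


# One EEG window: 2 channels, 1600 samples each (dummy values).
n_chans, n_times, patch_size = 2, 1600, 200
window = [[float(t) for t in range(n_times)] for _ in range(n_chans)]

# (n_chans, n_times) -> (n_chans, n_patches, patch_size)
patched = [segment_into_patches(ch, patch_size) for ch in window]
n_patches = len(patched[0])
print(len(patched), n_patches, len(patched[0][0]))  # 2 8 200
```

With the default `patch_size=200`, a 1600-sample window yields 8 patches per channel.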
The model applies the following sequence of steps:
if neural_tokenizer:
    - SegmentPatch: Segment the input data into patches;
    - TemporalConv: Apply a temporal convolution to the segmented data;
    - Residual adding cls, temporal and position embeddings (optional);
    - WindowsAttentionBlock: Apply a windows attention block to the data;
    - LayerNorm: Apply layer normalization to the data;
    - Linear: A linear head layer to transform the data into classes.
else:
    - PatchEmbed: Apply a patch embedding to the input data;
    - Residual adding cls, temporal and position embeddings (optional);
    - WindowsAttentionBlock: Apply a windows attention block to the data;
    - LayerNorm: Apply layer normalization to the data;
    - Linear: A linear head layer to transform the data into classes.
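To make the tokenizer path more concrete, the following sketch counts the tokens that would reach the attention blocks. It assumes one token per (channel, patch) pair plus a single cls token, which is consistent with the steps listed above but is not spelled out in them; `count_tokens` is a hypothetical helper, not a braindecode function.

```python
def count_tokens(n_chans, n_times, patch_size=200, with_cls=True):
    """Return the assumed transformer sequence length.

    Assumption: each channel is cut into n_times // patch_size patches,
    every patch becomes one token, and one cls token is prepended.
    """
    n_patches = n_times // patch_size
    n_tokens = n_chans * n_patches
    return n_tokens + 1 if with_cls else n_tokens


# Values taken from the example on this page: 64 channels, 1600 samples.
seq_len = count_tokens(n_chans=64, n_times=1600, patch_size=200)
print(seq_len)  # 64 * 8 + 1 = 513
```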
Important
Pre-trained Weights Available
This model has pre-trained weights available on the Hugging Face Hub. You can load them using:
from braindecode.models import Labram

# Load pre-trained model from Hugging Face Hub
model = Labram.from_pretrained("braindecode/labram-pretrained")
To push your own trained model to the Hub:
# After training your model
model.push_to_hub(
    repo_id="username/my-labram-model",
    commit_message="Upload trained Labram model"
)
Requires installing braindecode[hug] for Hub integration.
Added in version 0.9.
Examples
Load pre-trained weights:
>>> import torch
>>> from braindecode.models import Labram
>>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
>>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
>>> state = torch.hub.load_state_dict_from_url(url, progress=True)
>>> model.load_state_dict(state)
- Parameters:
patch_size (int) – The size of the patch to be used in the patch embedding.
learned_patcher (bool) – Whether to use a learned patch embedding (via a convolutional layer) or a fixed patch embedding (via rearrangement).
embed_dim (int) – The dimension of the embedding.
conv_in_channels (int) – The number of convolutional input channels.
conv_out_channels (int) – The number of convolutional output channels.
num_layers (int (default=12)) – The number of attention layers of the model.
num_heads (int (default=10)) – The number of attention heads.
mlp_ratio (float (default=4.0)) – The expansion ratio of the MLP layer.
qkv_bias (bool (default=False)) – If True, add a learnable bias to the query, key, and value tensors.
qk_norm (PyTorch normalization layer (default=nn.LayerNorm)) – If not None, apply this normalization to the query and key tensors. Default is nn.LayerNorm for better weight transfer from the original LaBraM. Set to None to disable Q, K normalization.
qk_scale (float (default=None)) – If not None, use this value as the scale factor. If None, use head_dim**-0.5, where head_dim = dim // num_heads.
drop_prob (float (default=0.0)) – Dropout rate for the attention weights.
attn_drop_prob (float (default=0.0)) – Dropout rate for the attention weights.
drop_path_prob (float (default=0.0)) – Drop rate used for DropPath (stochastic depth).
norm_layer (PyTorch normalization layer (default=nn.LayerNorm)) – The normalization layer to be used.
init_values (float (default=0.1)) – If not None, use this value to initialize the gamma_1 and gamma_2 parameters for residual scaling. Default is 0.1 for better weight transfer from original LaBraM. Set to None to disable.
use_abs_pos_emb (bool (default=True)) – If True, use absolute position embedding.
use_mean_pooling (bool (default=False)) – If True, use mean pooling.
init_scale (float (default=0.001)) – The initial scale to be used in the parameters of the model.
neural_tokenizer (bool (default=True)) – The model can be used in two modes: Neural Tokenizer or Neural Decoder.
attn_head_dim (int or None (default=None)) – The head dimension to be used in the attention layer, to be used only during pre-training.
activation (nn.Module, default=nn.GELU) – Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.GELU.
on_unknown_chs (Literal["ignore", "warn", "raise"], default="warn") – Determines behavior when channels that are not in LABRAM_CHANNEL_ORDER are passed to the forward method. Options:
- “ignore”: Silently ignore and drop unmatched channels, then proceed with matched ones.
- “warn”: Issue a warning listing unmatched channels, and drop them.
- “raise”: Raise an error and halt execution if any unmatched channels are found.
An error is always raised when unknown channels are passed during model initialization (via chs_info).
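The default attention scale described for `qk_scale` can be worked out by hand. The sketch below applies the stated rule, `head_dim**-0.5` with `head_dim = dim // num_heads`, to the defaults `embed_dim=200` and `num_heads=10`. `default_qk_scale` is a hypothetical helper rather than a braindecode function, and it assumes that `dim` refers to `embed_dim`.

```python
def default_qk_scale(embed_dim=200, num_heads=10, qk_scale=None):
    """Return the attention scale factor following the documented rule.

    If qk_scale is given it is used as-is; otherwise head_dim ** -0.5.
    Assumption: "dim" in the parameter description means embed_dim.
    """
    if qk_scale is not None:
        return qk_scale
    head_dim = embed_dim // num_heads
    return head_dim ** -0.5


scale = default_qk_scale()   # head_dim = 200 // 10 = 20
print(round(scale, 4))       # 20 ** -0.5, about 0.2236
```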
References
[Jiang2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024, May. Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI. The Twelfth International Conference on Learning Representations, ICLR.
[Code2024] Wei-Bang Jiang, Li-Ming Zhao, Bao-Liang Lu. 2024. Labram Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI. GitHub 935963004/LaBraM (accessed 2024-03-02)
[BeiTv2] Zhiliang Peng, Li Dong, Hangbo Bao, Qixiang Ye, Furu Wei. 2022. BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers. arXiv:2208.06366 [cs.CV]
Methods
- fix_init_weight_and_init_embedding()
Fix the initial weights and the initial embedding. Initialization uses a truncated normal distribution.
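For readers unfamiliar with truncated normal initialization, here is a minimal pure-Python sketch based on rejection sampling. It only illustrates the idea: the standard deviation of 0.02 and the bounds of two standard deviations are assumptions chosen for the example, and the model itself presumably relies on PyTorch's initializers rather than on code like this.

```python
import random


def trunc_normal_sample(mean=0.0, std=0.02, a=-0.04, b=0.04, rng=random):
    """Draw one value from N(mean, std^2) restricted to [a, b].

    Rejection sampling: redraw until the sample falls inside the bounds.
    The default std and bounds are illustrative assumptions.
    """
    while True:
        value = rng.gauss(mean, std)
        if a <= value <= b:
            return value


rng = random.Random(0)  # seeded for reproducibility
weights = [trunc_normal_sample(rng=rng) for _ in range(1000)]
print(min(weights) >= -0.04 and max(weights) <= 0.04)  # True
```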
- forward(x, ch_names=None, return_patch_tokens=False, return_all_tokens=False)
Forward the input EEG data through the model.
- Parameters:
x (torch.Tensor) – The input data with shape (batch, n_chans, n_times) or (batch, n_chans, n_patches, patch_size).
ch_names (list of str or None) – Optional list of channel names corresponding to the input data. This list is used to reorder channels to match LABRAM_CHANNEL_ORDER. If not provided, the channels given during model initialization (via chs_info) are used instead. If neither is provided, an error is raised.
return_patch_tokens (bool) – Return the patch tokens.
return_all_tokens (bool) – Return all the tokens.
- Returns:
The output of the model with dimensions (batch, n_outputs)
- Return type:
torch.Tensor
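The channel handling described for `ch_names` and `on_unknown_chs` can be pictured with the sketch below. It is not braindecode's implementation: `REFERENCE_ORDER` is a made-up three-channel stand-in for LABRAM_CHANNEL_ORDER, `match_channels` is a hypothetical helper that mimics the ignore, warn, and raise options, and the case-insensitive name matching is an assumption.

```python
import warnings

# Made-up stand-in for LABRAM_CHANNEL_ORDER (the real list is much longer).
REFERENCE_ORDER = ["FP1", "FP2", "CZ"]


def match_channels(ch_names, on_unknown_chs="warn"):
    """Return indices into ch_names, ordered as in REFERENCE_ORDER.

    Unknown channels are dropped ("ignore"), dropped with a warning
    ("warn"), or cause a ValueError ("raise").
    Assumption: names are compared case-insensitively.
    """
    upper = [name.upper() for name in ch_names]
    unknown = [n for n, u in zip(ch_names, upper) if u not in REFERENCE_ORDER]
    if unknown:
        if on_unknown_chs == "raise":
            raise ValueError(f"Unknown channels: {unknown}")
        if on_unknown_chs == "warn":
            warnings.warn(f"Dropping unknown channels: {unknown}")
    return [upper.index(ref) for ref in REFERENCE_ORDER if ref in upper]


# "Cz" comes first in the input but last in the reference order; "X1" is unknown.
keep = match_channels(["Cz", "X1", "Fp1"], on_unknown_chs="ignore")
print(keep)  # [2, 0]
```

The returned indices could then be used to select and reorder the channel axis of the input before it is passed on.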
- forward_features(x, input_chans, return_patch_tokens=False, return_all_tokens=False)
Forward the features of the model.
- Parameters:
x (torch.Tensor) – The input data with shape (batch, n_chans, n_times).
input_chans (torch.Tensor) – Indices for selecting position embeddings (including the [CLS] token).
return_patch_tokens (bool) – Whether to return the patch tokens.
return_all_tokens (bool) – Whether to return all the tokens.
- Returns:
x – The output of the model.
- Return type:
torch.Tensor
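The two flags of forward_features can be read as selecting different slices of the token sequence. The sketch below assumes a BEiT-style layout in which the cls token sits at index 0 and the patch tokens follow. The layout, the precedence of `return_all_tokens` over `return_patch_tokens`, and the `select_tokens` helper are all assumptions for illustration, and the default branch ignores `use_mean_pooling` for simplicity.

```python
def select_tokens(tokens, return_patch_tokens=False, return_all_tokens=False):
    """Pick which tokens to return from a [cls, patch_1, ..., patch_n] list.

    Assumed layout: index 0 is the cls token, the rest are patch tokens.
    """
    if return_all_tokens:
        return tokens
    if return_patch_tokens:
        return tokens[1:]
    return tokens[0]  # default: a single summary token


tokens = ["cls", "p1", "p2", "p3"]
print(select_tokens(tokens))                            # cls
print(select_tokens(tokens, return_patch_tokens=True))  # ['p1', 'p2', 'p3']
print(select_tokens(tokens, return_all_tokens=True))    # ['cls', 'p1', 'p2', 'p3']
```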