braindecode.models.LUNA

class braindecode.models.LUNA(n_outputs=None, n_chans=None, n_times=None, sfreq=None, chs_info=None, input_window_seconds=None, patch_size=40, num_queries=4, embed_dim=64, depth=8, num_heads=2, mlp_ratio=4.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, drop_path=0.0, drop_prob_chan=0.0, attn_drop=0.0, activation=<class 'torch.nn.modules.activation.GELU'>)

LUNA from Döner et al. [LUNA].

Categories: Convolution, Foundation Model, Channel

LUNA Architecture.

LUNA is a topology-invariant EEG model that processes signals from varying numbers of channels using a channel-unification mechanism with learned queries.

The architecture consists of:

  1. Patch Feature Extraction (temporal CNN + FFT-based features)

  2. Channel-Unification Module (cross-attention with learned queries)

  3. Patch-wise Temporal Encoder (RoPE-based transformer)

  4. Decoder Heads (classification or reconstruction)
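To illustrate why learned queries make the model independent of the channel count, here is a minimal single-head cross-attention sketch in plain Python. It is not braindecode's implementation: the helper name `unify_channels`, the omission of projection matrices and multiple heads, and the toy values are all assumptions made for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def unify_channels(queries, channel_feats):
    """Hypothetical helper: single-head cross-attention without projections.

    queries:       num_queries x dim (learned parameters in a real model)
    channel_feats: n_chans x dim (features of one patch, one row per channel)
    Returns a num_queries x dim list, whatever n_chans is.
    """
    dim = len(queries[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for q in queries:
        # Attention weights of this query over all channels.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in channel_feats]
        weights = softmax(scores)
        # Weighted sum of the channel features.
        out.append(
            [sum(w * k[d] for w, k in zip(weights, channel_feats)) for d in range(dim)]
        )
    return out


queries = [[1.0, 0.0], [0.0, 1.0]]  # toy stand-in for learned queries
few = unify_channels(queries, [[0.5, 0.1], [0.2, 0.9], [0.3, 0.3]])
many = unify_channels(queries, [[0.1 * c, 1.0 - 0.1 * c] for c in range(10)])
print(len(few), len(many))  # both equal len(queries)
```

Because the output has one row per query, the downstream temporal encoder sees the same shape for a 3-channel and a 10-channel recording.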

Important

Pre-trained Weights Available

Pre-trained weights for this model are hosted on the Hugging Face Hub in the thorir/LUNA repository.

Available model variants:

  • LUNA_base.safetensors - Base model (embed_dim=64, num_queries=4, depth=8)

  • LUNA_large.safetensors - Large model (embed_dim=96, num_queries=6, depth=10)

  • LUNA_huge.safetensors - Huge model (embed_dim=128, num_queries=8, depth=24)
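When loading a checkpoint, the constructor arguments need to match the variant. One way to keep the two in sync is a small lookup table. This is a sketch only: `LUNA_VARIANTS` and `variant_kwargs` are hypothetical names, not part of braindecode, and the values are copied from the list above.

```python
# Hyperparameters of the released checkpoints, as listed above.
LUNA_VARIANTS = {
    "LUNA_base.safetensors": {"embed_dim": 64, "num_queries": 4, "depth": 8},
    "LUNA_large.safetensors": {"embed_dim": 96, "num_queries": 6, "depth": 10},
    "LUNA_huge.safetensors": {"embed_dim": 128, "num_queries": 8, "depth": 24},
}


def variant_kwargs(filename):
    """Hypothetical helper: constructor kwargs matching a checkpoint file."""
    try:
        # Return a copy so callers cannot mutate the table.
        return dict(LUNA_VARIANTS[filename])
    except KeyError:
        raise ValueError(f"Unknown LUNA checkpoint: {filename!r}") from None


kwargs = variant_kwargs("LUNA_large.safetensors")
print(kwargs)
```

The resulting dictionary can be unpacked into the constructor alongside the task-specific arguments, for example `LUNA(n_outputs=2, n_chans=22, n_times=1000, **kwargs)`.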

Example loading for fine-tuning:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from braindecode.models import LUNA

# Download pre-trained weights
model_path = hf_hub_download(
    repo_id="thorir/LUNA",
    filename="LUNA_base.safetensors",
)

# Create model for classification (fine-tuning)
model = LUNA(
    n_outputs=2,  # Number of classes for your task
    n_chans=22,
    n_times=1000,
    embed_dim=64,
    num_queries=4,
    depth=8,
)

# Load pre-trained encoder weights
state_dict = load_file(model_path)
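# Rename checkpoint keys so they match this module's parameter names
mapping = model.mapping.copy()
# The checkpoint spells this key "temparature"; map it to the corrected name
mapping["cross_attn.temparature"] = "cross_attn.temperature"
# Keys not listed in the mapping are kept unchanged
mapped_state_dict = {mapping.get(k, k): v for k, v in state_dict.items()}
# strict=False tolerates keys that are missing from or extra in the checkpoint
model.load_state_dict(mapped_state_dict, strict=False)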
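The remapping step renames only the keys present in the mapping and passes every other key through unchanged. A toy illustration with plain dictionaries follows; the float values stand in for tensors, and `patch_embed.weight` is a made-up key used only for this example.

```python
def remap_keys(state_dict, mapping):
    """Rename keys found in `mapping`; leave all other keys untouched."""
    return {mapping.get(k, k): v for k, v in state_dict.items()}


# Placeholder values stand in for tensors; "patch_embed.weight" is hypothetical.
checkpoint = {"cross_attn.temparature": 0.5, "patch_embed.weight": 1.0}
mapping = {"cross_attn.temparature": "cross_attn.temperature"}

remapped = remap_keys(checkpoint, mapping)
print(sorted(remapped))
```

With `strict=False`, PyTorch's `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` attributes. Inspecting them after loading is a quick way to confirm that the renamed keys were matched.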

To push your own trained model to the Hub:

# After training your model
model.push_to_hub(
    repo_id="username/my-luna-model", commit_message="Upload trained LUNA model"
)

Requires installing braindecode[hug] for Hub integration.

Parameters:
  • patch_size (int) – Number of time samples per patch. Default: 40.

  • num_queries (int) – Number of learned queries for channel unification. Paper uses: 4 (Base), 6 (Large), 8 (Huge). Default: 4.

  • embed_dim (int) – Embedding dimension for patch features. Paper uses: 64 (Base), 96 (Large), 128 (Huge). Default: 64.

  • depth (int) – Number of transformer encoder blocks. Paper uses: 8 (Base), 10 (Large), 24 (Huge). Default: 8.

  • num_heads (int) – Number of attention heads in channel unification. Default: 2.

  • mlp_ratio (float) – Ratio of MLP hidden dimension to embedding dimension. Default: 4.0.

  • norm_layer (nn.Module) – Normalization layer class. Default: nn.LayerNorm.

  • drop_path (float) – Stochastic depth rate. Default: 0.0.
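As a rough guide to how `patch_size` interacts with `n_times`, the sketch below counts patches per channel. It assumes non-overlapping patches, which this page does not state explicitly, and `patch_count` is a hypothetical helper rather than a braindecode function.

```python
def patch_count(n_times, patch_size=40):
    """Full patches and leftover samples, assuming non-overlapping patches."""
    full, leftover = divmod(n_times, patch_size)
    return full, leftover


# The fine-tuning example uses n_times=1000 with the default patch_size=40.
print(patch_count(1000))  # (25, 0)
```

Under that assumption, a 1000-sample window yields 25 patches per channel with no samples left over.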

References

[LUNA]

Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025). LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis. The Thirty-Ninth Annual Conference on Neural Information Processing Systems - NeurIPS. Retrieved from https://openreview.net/forum?id=uazfjnFL0G

Methods

build_channel_location_template(num_channels)

Build channel location template for the model.

Attempts to extract channel locations from chs_info. Falls back to a default linear spacing along the x-axis if real locations are unavailable.

Parameters:

num_channels (int) – Number of channels to generate locations for.

Returns:

Tensor of shape (num_channels, 3) with channel locations in 3D space.

Return type:

torch.Tensor
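The fallback described above can be pictured as points evenly spaced on the x-axis. The following sketch uses nested lists instead of a tensor. The coordinate range of -1 to 1 and the name `linear_fallback_locations` are assumptions, because the actual range used by the method is not documented here.

```python
def linear_fallback_locations(num_channels, start=-1.0, stop=1.0):
    """Hypothetical stand-in: evenly spaced x-coordinates with y = z = 0.

    Assumes num_channels >= 1; the range [start, stop] is an assumption.
    """
    if num_channels == 1:
        xs = [0.5 * (start + stop)]
    else:
        step = (stop - start) / (num_channels - 1)
        xs = [start + i * step for i in range(num_channels)]
    # One (x, y, z) row per channel, mirroring the (num_channels, 3) shape.
    return [[x, 0.0, 0.0] for x in xs]


locs = linear_fallback_locations(5)
print(len(locs), len(locs[0]))
```

Passing real electrode positions through chs_info avoids this fallback and gives the model meaningful spatial information.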

fix_init_weight()
Return type:

None

forward(X, mask=None, channel_locations=None, channel_names=None)

Forward pass.

Return type:

Tensor

get_default_channel_locations(batch_size, num_channels, device, dtype)
Return type:

Tensor

initialize_weights()
Return type:

None

prepare_tokens(x_signal, channel_locations, mask=None)
Return type:

Tuple[Tensor, Tensor]