Loading and Adapting Pretrained Foundation Models#

Braindecode + Hugging Face

All braindecode models can load and save weights on the Hugging Face Hub via from_pretrained and push_to_hub. For foundation models we additionally provide curated pretrained checkpoints with mapped weights, so you can fine-tune them out of the box.

Important

The curation and standardization of the pretrained foundation model weights available on the Braindecode Hugging Face organization was carried out as part of the OpenEEG-Bench benchmark [2]. If you use these pretrained weights, please cite the paper.

pip install braindecode[hub]

This tutorial shows how to load pretrained EEG foundation models, adapt them to new tasks, extract features, and save/restore full model configurations, using a unified API inspired by the Hugging Face transformers library.

# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

import json
import os
import warnings

import torch
from huggingface_hub import login

warnings.simplefilter("ignore")

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

Loading a pretrained model#

All braindecode foundation models support the from_pretrained method, which downloads model weights and configuration from the Hugging Face Hub.

We start with BENDR [1] as an example:

from braindecode.models import BENDR

model = BENDR.from_pretrained("braindecode/braindecode-bendr", n_outputs=2)
print(f"Loaded BENDR with n_outputs={model.n_outputs}")
Loaded BENDR with n_outputs=2

The loaded model is ready for inference:

x = torch.randn(2, 20, 768)  # (batch, n_chans, n_times)
model.eval()
with torch.no_grad():
    out = model(x)
print(f"Output shape: {out.shape}")
Output shape: torch.Size([2, 2])

Adapting to a new task#

Changing the number of outputs#

When fine-tuning a pretrained model on a dataset with a different number of classes, pass n_outputs directly to from_pretrained. The backbone weights are loaded and the classification head is automatically rebuilt:

model = BENDR.from_pretrained("braindecode/braindecode-bendr", n_outputs=4)
print(f"n_outputs after loading: {model.n_outputs}")

with torch.no_grad():
    out = model(x)
print(f"Output shape with 4 classes: {out.shape}")
n_outputs after loading: 4
Output shape with 4 classes: torch.Size([2, 4])

You can also swap the head at any time using reset_head:

model.reset_head(10)
print(f"n_outputs after reset_head: {model.n_outputs}")

with torch.no_grad():
    out = model(x)
print(f"Output shape after reset: {out.shape}")
n_outputs after reset_head: 10
Output shape after reset: torch.Size([2, 10])

Extracting features#

All foundation models support return_features=True in their forward() method. This returns a dictionary with:

  • "features" – encoder embeddings before the classification head

  • "cls_token" – the CLS token (if the model has one, otherwise None)

model.eval()
with torch.no_grad():
    out = model(x, return_features=True)

print(f"Type: {type(out)}")
print(f"Features shape: {out['features'].shape}")
print(f"CLS token: {out['cls_token']}")
Type: <class 'dict'>
Features shape: torch.Size([2, 512])
CLS token: None

Tip

This is useful for transfer learning: freeze the backbone and train only a new head on the extracted features. See Fine-tuning a Foundation Model (Signal-JEPA) for a complete example.
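
The frozen-backbone recipe from the tip can be sketched as follows. Since downloading real weights is out of scope here, a tiny stand-in module (TinyBackbone, hypothetical, not part of braindecode) mimics the return_features=True contract; with a real model you would substitute the object returned by from_pretrained.

```python
import torch
from torch import nn

# Hypothetical stand-in for a pretrained foundation model: it returns a
# dict with "features" and "cls_token", like braindecode models do when
# called with return_features=True.
class TinyBackbone(nn.Module):
    def __init__(self, n_chans=20, feat_dim=512):
        super().__init__()
        self.proj = nn.Linear(n_chans, feat_dim)

    def forward(self, x, return_features=False):
        feats = self.proj(x.mean(dim=-1))  # (batch, feat_dim)
        if return_features:
            return {"features": feats, "cls_token": None}
        return feats

backbone = TinyBackbone()
# Freeze the backbone: no gradients flow into the pretrained weights.
for p in backbone.parameters():
    p.requires_grad = False
backbone.eval()

# New trainable head for a hypothetical 4-class downstream task.
head = nn.Linear(512, 4)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

x = torch.randn(8, 20, 768)          # (batch, n_chans, n_times)
y = torch.randint(0, 4, (8,))        # fake labels for the sketch
with torch.no_grad():                # backbone acts as a fixed extractor
    feats = backbone(x, return_features=True)["features"]
logits = head(feats)
loss = nn.functional.cross_entropy(logits, y)
loss.backward()                      # gradients reach only the head
optimizer.step()
```

In a real fine-tuning run you would loop over a DataLoader instead of a single batch; the key point is that only head.parameters() are passed to the optimizer.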

Saving and restoring configurations#

get_config returns a JSON-serializable dictionary of all __init__ parameters, not only the six shared EEG parameters (n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq). It also includes model-specific hyperparameters such as encoder_h, drop_prob, and activation.

config = model.get_config()
print(json.dumps({k: v for k, v in config.items() if k != "chs_info"}, indent=2))
{
  "n_chans": 20,
  "n_outputs": 4,
  "n_times": null,
  "input_window_seconds": null,
  "sfreq": null,
  "encoder_h": 512,
  "contextualizer_hidden": 3076,
  "projection_head": false,
  "drop_prob": 0.1,
  "layer_drop": 0.0,
  "activation": "torch.nn.modules.activation.GELU",
  "transformer_layers": 8,
  "transformer_heads": 8,
  "position_encoder_length": 25,
  "enc_width": [
    3,
    2,
    2,
    2,
    2,
    2
  ],
  "enc_downsample": [
    3,
    2,
    2,
    2,
    2,
    2
  ],
  "start_token": -5,
  "final_layer": true,
  "encoder_only": false,
  "n_chans_pretrained": null,
  "chan_proj_max_norm": 1.0
}

You can reconstruct the model (without weights) using from_config:

model_copy = BENDR.from_config(config)
print(f"Reconstructed: n_outputs={model_copy.n_outputs}")
Reconstructed: n_outputs=4
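
Because get_config returns only JSON-native values (numbers, strings, lists, null), the dictionary survives a round trip through json unchanged. A minimal sketch, using a small hypothetical subset of the config printed above:

```python
import json

# Hypothetical subset of the config dictionary shown above.
config = {
    "n_chans": 20,
    "n_outputs": 4,
    "drop_prob": 0.1,
    "activation": "torch.nn.modules.activation.GELU",
    "enc_width": [3, 2, 2, 2, 2, 2],
}

# Round-trip through JSON: nothing is lost because every value is a
# JSON-native type.
restored = json.loads(json.dumps(config))
assert restored == config
```

The restored dictionary can then be handed to Model.from_config to rebuild the architecture (without weights), exactly as shown above.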

When pushing to the Hub, the full config is saved automatically in config.json alongside the model weights:

model.push_to_hub("username/my-bendr-model")
# Later:
model = BENDR.from_pretrained("username/my-bendr-model")

Unified API across foundation models#

The same API works across all foundation models:

| Model               | from_pretrained | reset_head | return_features   | get_config |
|---------------------|-----------------|------------|-------------------|------------|
| BENDR               | ✔               | ✔          | ✔                 | ✔          |
| BIOT                | ✔               | ✔          | ✔                 | ✔          |
| CBraMod             | ✔               | ✔          | ✔                 | ✔          |
| EEGPT               | ✔               | ✔          | ✔                 | ✔          |
| Labram              | ✔               | ✔          | ✔ (+ cls_token)   | ✔          |
| LUNA                | ✔               | ✔          | ✔                 | ✔          |
| REVE                | ✔               | ✔          | ✔                 | ✔          |
| SignalJEPA variants | ✔               | ✔          | ✔                 | ✔          |

The feature shapes differ between models (reflecting their architecture), but the API is always the same.
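
Because the contract is identical everywhere, a single model-agnostic helper covers every foundation model. The helper below is a hypothetical convenience function (not a braindecode API); it relies only on the shared return_features=True dictionary keys:

```python
import torch

def extract_features(model, x):
    """Return (features, cls_token) from any model that follows the
    shared return_features=True contract.

    Hypothetical helper, not part of braindecode itself.
    """
    model.eval()
    with torch.no_grad():
        out = model(x, return_features=True)
    return out["features"], out["cls_token"]
```

The same call works for BENDR, Labram, EEGPT, and the rest; only the shape of the returned features differs between architectures.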

Available pretrained weights#

The figure below shows all available pretrained checkpoints ranked by parameter count, with colors indicating the hosting organization. Parameter counts are read directly from the Hub.

import matplotlib.pyplot as plt
import numpy as np

from braindecode.models import (
    BENDR,
    BIOT,
    EEGPT,
    LUNA,
    REVE,
    CBraMod,
    Labram,
    SignalJEPA,
)

# (display_name, model_class, from_pretrained kwargs, org)
checkpoints = [
    (
        "BENDR",
        BENDR,
        dict(
            pretrained_model_name_or_path="braindecode/braindecode-bendr", n_outputs=2
        ),
        "braindecode",
    ),
    (
        "BIOT (16ch)",
        BIOT,
        dict(pretrained_model_name_or_path="braindecode/biot-pretrained-prest-16chs"),
        "braindecode",
    ),
    (
        "BIOT (18ch)",
        BIOT,
        dict(
            pretrained_model_name_or_path="braindecode/biot-pretrained-shhs-prest-18chs"
        ),
        "braindecode",
    ),
    (
        "CBraMod",
        CBraMod,
        dict(
            pretrained_model_name_or_path="braindecode/cbramod-pretrained",
            n_outputs=2,
            n_chans=22,
            n_times=1000,
            sfreq=250,
        ),
        "braindecode",
    ),
    (
        "EEGPT",
        EEGPT,
        dict(
            pretrained_model_name_or_path="braindecode/eegpt-pretrained",
            n_chans=62,
            chan_proj_type="none",
        ),
        "braindecode",
    ),
    (
        "Labram",
        Labram,
        dict(
            pretrained_model_name_or_path="braindecode/labram-pretrained", n_chans=128
        ),
        "braindecode",
    ),
    (
        "SignalJEPA",
        SignalJEPA,
        dict(
            pretrained_model_name_or_path="braindecode/SignalJEPA-pretrained",
            n_outputs=2,
            n_chans=19,
            input_window_seconds=5,
            sfreq=256,
        ),
        "braindecode",
    ),
    (
        "REVE (base)",
        REVE,
        dict(
            pretrained_model_name_or_path="brain-bzh/reve-base",
            n_outputs=2,
            n_chans=64,
            n_times=512,
            sfreq=256,
        ),
        "brain-bzh",
    ),
    (
        "REVE (large)",
        REVE,
        dict(
            pretrained_model_name_or_path="brain-bzh/reve-large",
            n_outputs=2,
            n_chans=64,
            n_times=512,
            sfreq=256,
        ),
        "brain-bzh",
    ),
    (
        "LUNA (base)",
        LUNA,
        dict(
            pretrained_model_name_or_path="PulpBio/LUNA",
            filename="LUNA_base.safetensors",
            n_outputs=2,
            n_chans=22,
            n_times=1000,
            embed_dim=64,
            num_queries=4,
            depth=8,
        ),
        "PulpBio",
    ),
    (
        "LUNA (large)",
        LUNA,
        dict(
            pretrained_model_name_or_path="PulpBio/LUNA",
            filename="LUNA_large.safetensors",
            n_outputs=2,
            n_chans=22,
            n_times=1000,
            embed_dim=96,
            num_queries=6,
            depth=10,
        ),
        "PulpBio",
    ),
    (
        "LUNA (huge)",
        LUNA,
        dict(
            pretrained_model_name_or_path="PulpBio/LUNA",
            filename="LUNA_huge.safetensors",
            n_outputs=2,
            n_chans=22,
            n_times=1000,
            embed_dim=128,
            num_queries=8,
            depth=24,
        ),
        "PulpBio",
    ),
]

# Skip gated models when no HF token is available (e.g. fork PRs)
if not hf_token:
    checkpoints = [(d, c, k, o) for d, c, k, o in checkpoints if o != "brain-bzh"]

names, params_m, orgs = [], [], []
for display, cls, kwargs, org in checkpoints:
    mdl = cls.from_pretrained(**kwargs)
    n_params = sum(
        p.numel()
        for p in mdl.parameters()
        if not isinstance(p, torch.nn.UninitializedParameter)
    )
    names.append(display)
    params_m.append(n_params / 1e6)
    orgs.append(org)
    print(f"  {display:15s}  {n_params / 1e6:8.1f}M params")

params_m = np.array(params_m)

# Sort by parameter count (ascending) for horizontal bar chart
order = np.argsort(params_m)
names = [names[i] for i in order]
params_m = params_m[order]
orgs = [orgs[i] for i in order]

from matplotlib.ticker import FuncFormatter

# -- palette ----------------------------------------------------------
org_palette = {
    "braindecode": "#2D6A9F",
    "brain-bzh": "#C04E3E",
    "PulpBio": "#4A8B6F",
}
org_markers = {"braindecode": "o", "brain-bzh": "D", "PulpBio": "s"}
colors = [org_palette[o] for o in orgs]
markers = [org_markers[o] for o in orgs]

# -- figure setup ------------------------------------------------------
fig, ax = plt.subplots(figsize=(7.5, 5.0), dpi=120)
fig.patch.set_facecolor("#FAFAFA")
ax.set_facecolor("#FAFAFA")

y_pos = np.arange(len(names))

# Horizontal reference lines (subtle, behind everything)
for y in y_pos:
    ax.axhline(y, color="#E8E8E8", linewidth=0.6, zorder=0)

# Stem lines from left edge to dot
ax.hlines(y_pos, 0.5, params_m, color="#D0D0D0", linewidth=0.9, zorder=1)

# Dots (different shapes per org for accessibility)
for i in range(len(names)):
    ax.scatter(
        params_m[i],
        y_pos[i],
        color=colors[i],
        s=70,
        zorder=3,
        marker=markers[i],
        edgecolors="white",
        linewidths=0.8,
    )

# Value labels always to the right of the dot
for i, pm in enumerate(params_m):
    ax.text(
        pm * 1.12,
        y_pos[i],
        f"{pm:.1f}M",
        ha="left",
        va="center",
        fontsize=7.8,
        color="#444444",
        fontweight="bold",
        zorder=4,
    )

# Y-axis
ax.set_yticks(y_pos)
ax.set_yticklabels(names, fontsize=9, color="#333333", fontweight="medium")

# X-axis: log scale with human-readable labels
ax.set_xscale("log")
ax.set_xlim(1.5, params_m.max() * 2.5)
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}M"))
ax.set_xticks([3, 10, 30, 100, 300])
ax.tick_params(axis="x", colors="#999999", labelsize=8, length=0, pad=4)
ax.tick_params(axis="y", length=0)

# Title + subtitle
ax.set_title(
    "EEG Foundation Models",
    fontsize=13,
    fontweight="bold",
    color="#1a1a1a",
    loc="left",
    pad=22,
)
ax.text(
    0.0,
    1.02,
    "Parameter count of curated checkpoints on Hugging Face Hub",
    transform=ax.transAxes,
    fontsize=8.5,
    color="#777777",
    va="bottom",
)

# Organization legend with shape + color
for org_name in org_palette:
    ax.scatter(
        [],
        [],
        c=org_palette[org_name],
        s=40,
        marker=org_markers[org_name],
        label=org_name,
        edgecolors="white",
    )
legend = ax.legend(
    fontsize=7.5,
    loc="lower right",
    frameon=True,
    framealpha=0.95,
    edgecolor="#E0E0E0",
    handletextpad=0.4,
    borderpad=0.8,
    labelspacing=0.7,
)
legend.get_frame().set_facecolor("#FAFAFA")

# Clean spines
for spine in ax.spines.values():
    spine.set_visible(False)

# Minimal grid on x only
ax.xaxis.grid(True, color="#E8E8E8", linewidth=0.5, which="major")
ax.yaxis.grid(False)
ax.set_axisbelow(True)

fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
[Figure: "EEG Foundation Models", parameter count of curated checkpoints on the Hugging Face Hub, colored by hosting organization]
BENDR               157.1M params
BIOT (16ch)           3.2M params
BIOT (18ch)           3.2M params
CBraMod               5.0M params
EEGPT                25.3M params
Labram                5.8M params
SignalJEPA            3.5M params
REVE (base)          69.5M params
REVE (large)        390.6M params
LUNA (base)           7.1M params
LUNA (large)         43.1M params
LUNA (huge)         311.8M params

Braindecode organization#

The braindecode organization on Hugging Face re-hosts the official pretrained weights. All models below follow the same one-line loading pattern:

model = Model.from_pretrained("<repo-id>", n_outputs=...)

| Model                 | Hub Repository                                 | Details             |
|-----------------------|------------------------------------------------|---------------------|
| BENDR                 | braindecode/braindecode-bendr                  | 20 channels         |
| BIOT                  | braindecode/biot-pretrained-prest-16chs        | 16 ch, PREST        |
| BIOT                  | braindecode/biot-pretrained-shhs-prest-18chs   | 18 ch, SHHS + PREST |
| BIOT                  | braindecode/biot-pretrained-six-datasets-18chs | 18 ch, 6 datasets   |
| CBraMod               | braindecode/cbramod-pretrained                 | channel-agnostic    |
| EEGPT                 | braindecode/eegpt-pretrained                   | 62 ch, 250 Hz       |
| Labram                | braindecode/labram-pretrained                  | 128 channels        |
| SignalJEPA            | braindecode/SignalJEPA-pretrained              | 19 channels         |
| SignalJEPA_Contextual | braindecode/SignalJEPA-Contextual-pretrained   | 19 channels         |
| SignalJEPA_PostLocal  | braindecode/SignalJEPA-PostLocal-pretrained    | 19 channels         |
| SignalJEPA_PreLocal   | braindecode/SignalJEPA-PreLocal-pretrained     | 19 channels         |

External organizations#

Some pretrained weights are hosted by the original model authors:

| Model | Organization | Hub Repository      | Details             |
|-------|--------------|---------------------|---------------------|
| REVE  | brain-bzh    | brain-bzh/reve-base | 69M params          |
| REVE  | brain-bzh    | brain-bzh/reve-large| 400M params         |
| LUNA  | PulpBio      | PulpBio/LUNA        | base / large / huge |

Note

Loading LUNA: the PulpBio/LUNA repository stores all weight variants together. Use the filename parameter to select one:

from braindecode.models import LUNA

model = LUNA.from_pretrained(
    "PulpBio/LUNA",
    filename="LUNA_base.safetensors",
    n_outputs=2, n_chans=22, n_times=1000,
    embed_dim=64, num_queries=4, depth=8,
)

Available files: LUNA_base.safetensors (7M), LUNA_large.safetensors (43M), LUNA_huge.safetensors (311M).

References#

Total running time of the script: (0 minutes 40.850 seconds)

Estimated memory usage: 4284 MB

Gallery generated by Sphinx-Gallery