Fine-tuning a Foundation Model (Signal-JEPA)#

Foundation models are large-scale pre-trained models that serve as a starting point for a wide range of downstream tasks because the representations they learn generalize well. Fine-tuning adapts such a model to a specific task or dataset, which is usually needed to get good performance in specialized applications.

In this tutorial, we demonstrate how to load a pre-trained foundation model and fine-tune it for a specific task, using the Signal-JEPA model [1] and a MOABB motor-imagery dataset.

# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
#
import mne
import numpy as np
import torch

from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
from braindecode.models import SignalJEPA_PreLocal
from braindecode.preprocessing import create_windows_from_events

torch.use_deterministic_algorithms(True)
torch.manual_seed(12)
np.random.seed(12)

Loading and preparing the data#

Loading a dataset#

We start by loading a MOABB dataset, restricted to a single subject to keep the example fast. The dataset contains motor-imagery EEG recordings, which we will preprocess and use for fine-tuning.

subject_id = 3  # Just one subject for speed
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

# Set the standard 10-20 montage for EEG channel locations
montage = mne.channels.make_standard_montage("standard_1020")
for ds in dataset.datasets:
    ds.raw.set_montage(montage)

Preprocessing to match the pretrained model#

The pretrained SignalJEPA checkpoint expects 19 EEG channels at 128 Hz with 2-second windows. We adapt the dataset accordingly: keep only EEG channels, pick the first 19, and resample.

for ds in dataset.datasets:
    ds.raw.pick_types(eeg=True)  # drop EOG / stim channels
    ds.raw.pick(ds.raw.ch_names[:19])  # match pretrained channel count
    ds.raw.resample(128)  # match pretrained sampling frequency

Define Dataset parameters#

We extract the sampling frequency and channel information after preprocessing so they match the pretrained model.

sfreq = dataset.datasets[0].raw.info["sfreq"]
chs_info = dataset.datasets[0].raw.info["chs"]

print(f"{sfreq=}, {len(chs_info)=}")
sfreq=128.0, len(chs_info)=19

Create Windows from Events#

We use the create_windows_from_events function from Braindecode to segment the recordings into fixed-length windows around the events. The mapping assigns an integer label to each of the three classes used here.

classes = ["feet", "left_hand", "right_hand"]
classes_mapping = {c: i for i, c in enumerate(classes)}

windows_dataset = create_windows_from_events(
    dataset,
    preload=True,  # Preload the data into memory for faster processing
    mapping=classes_mapping,
    window_size_samples=256,  # 2 s at 128 Hz — matches pretrained model
    window_stride_samples=256,
)
metadata = windows_dataset.get_metadata()
print(metadata.head(10))
   i_window_in_trial  i_start_in_trial  i_stop_in_trial  ...  subject  session run
0                  0               384              640  ...        3   0train   0
1                  1               640              896  ...        3   0train   0
2                  0              1410             1666  ...        3   0train   0
3                  1              1666             1922  ...        3   0train   0
4                  0              2392             2648  ...        3   0train   0
5                  1              2648             2904  ...        3   0train   0
6                  0              3391             3647  ...        3   0train   0
7                  1              3647             3903  ...        3   0train   0
8                  0              4419             4675  ...        3   0train   0
9                  1              4675             4931  ...        3   0train   0

[10 rows x 7 columns]
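
The window boundaries in this metadata follow directly from the window size and stride. As a minimal sketch in plain Python (the helper `window_bounds` is hypothetical and not part of Braindecode, and the trial extent of samples 384 to 896 is read off the first two rows above), a 256-sample window with a 256-sample stride yields two non-overlapping windows per trial:

```python
def window_bounds(trial_start, trial_stop, size, stride):
    """Return (start, stop) sample indices of the full windows inside a trial."""
    bounds = []
    start = trial_start
    while start + size <= trial_stop:
        bounds.append((start, start + size))
        start += stride
    return bounds


sfreq = 128  # Hz, after resampling
size = int(2 * sfreq)  # 2-second windows -> 256 samples
first_trial = window_bounds(384, 896, size=size, stride=size)
print(first_trial)  # [(384, 640), (640, 896)]
```

A stride smaller than the window size would produce overlapping windows, and therefore more training examples per trial.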

Loading a pre-trained foundation model#

Load Pre-trained Weights from the Hub#

We load the pre-trained SignalJEPA downstream model from the Hugging Face Hub using from_pretrained. The SignalJEPA_PreLocal checkpoint already bundles the SSL backbone together with the downstream classification layers, so a single call is all that is needed.

For other foundation models (BENDR, BIOT, Labram, etc.) the same one-line pattern applies — see Loading and Adapting Pretrained Foundation Models.

model = SignalJEPA_PreLocal.from_pretrained(
    "braindecode/SignalJEPA-PreLocal-pretrained",
    n_outputs=len(classes),
)
print(model)
======================================================================================================================================================
Layer (type (var_name):depth-idx)                  Input Shape               Output Shape              Param #                   Kernel Shape
======================================================================================================================================================
SignalJEPA_PreLocal (SignalJEPA_PreLocal)          [1, 19, 256]              [1, 3]                    --                        --
├─Sequential (spatial_conv): 1-1                   [1, 19, 256]              [1, 4, 256]               --                        --
│    └─Rearrange (0): 2-1                          [1, 19, 256]              [1, 1, 19, 256]           --                        --
│    └─Conv2d (1): 2-2                             [1, 1, 19, 256]           [1, 4, 1, 256]            80                        [19, 1]
│    └─Rearrange (2): 2-3                          [1, 4, 1, 256]            [1, 4, 256]               --                        --
├─_ConvFeatureEncoder (feature_encoder): 1-2       [1, 4, 256]               [1, 4, 64]                --                        --
│    └─Rearrange (0): 2-4                          [1, 4, 256]               [4, 1, 256]               --                        --
│    └─Sequential (1): 2-5                         [4, 1, 256]               [4, 8, 29]                --                        --
│    │    └─Conv1d (0): 3-1                        [4, 1, 256]               [4, 8, 29]                256                       [32]
│    │    └─Dropout (1): 3-2                       [4, 8, 29]                [4, 8, 29]                --                        --
│    │    └─GroupNorm (2): 3-3                     [4, 8, 29]                [4, 8, 29]                16                        --
│    │    └─GELU (3): 3-4                          [4, 8, 29]                [4, 8, 29]                --                        --
│    └─Sequential (2): 2-6                         [4, 8, 29]                [4, 16, 14]               --                        --
│    │    └─Conv1d (0): 3-5                        [4, 8, 29]                [4, 16, 14]               256                       [2]
│    │    └─Dropout (1): 3-6                       [4, 16, 14]               [4, 16, 14]               --                        --
│    │    └─GELU (2): 3-7                          [4, 16, 14]               [4, 16, 14]               --                        --
│    └─Sequential (3): 2-7                         [4, 16, 14]               [4, 32, 7]                --                        --
│    │    └─Conv1d (0): 3-8                        [4, 16, 14]               [4, 32, 7]                1,024                     [2]
│    │    └─Dropout (1): 3-9                       [4, 32, 7]                [4, 32, 7]                --                        --
│    │    └─GELU (2): 3-10                         [4, 32, 7]                [4, 32, 7]                --                        --
│    └─Sequential (4): 2-8                         [4, 32, 7]                [4, 64, 3]                --                        --
│    │    └─Conv1d (0): 3-11                       [4, 32, 7]                [4, 64, 3]                4,096                     [2]
│    │    └─Dropout (1): 3-12                      [4, 64, 3]                [4, 64, 3]                --                        --
│    │    └─GELU (2): 3-13                         [4, 64, 3]                [4, 64, 3]                --                        --
│    └─Sequential (5): 2-9                         [4, 64, 3]                [4, 64, 1]                --                        --
│    │    └─Conv1d (0): 3-14                       [4, 64, 3]                [4, 64, 1]                8,192                     [2]
│    │    └─Dropout (1): 3-15                      [4, 64, 1]                [4, 64, 1]                --                        --
│    │    └─GELU (2): 3-16                         [4, 64, 1]                [4, 64, 1]                --                        --
│    └─Rearrange (6): 2-10                         [4, 64, 1]                [1, 4, 64]                --                        --
├─Sequential (final_layer): 1-3                    [1, 4, 64]                [1, 3]                    --                        --
│    └─Flatten (0): 2-11                           [1, 4, 64]                [1, 256]                  --                        --
│    └─Linear (1): 2-12                            [1, 256]                  [1, 3]                    771                       --
======================================================================================================================================================
Total params: 14,691
Trainable params: 14,691
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.18
======================================================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.05
Params size (MB): 0.06
Estimated Total Size (MB): 0.12
======================================================================================================================================================
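
The shapes and parameter counts in this summary can be checked by hand. The sketch below is plain Python, not Braindecode code. The strides (8 for the first temporal convolution, 2 for the others) and the absence of bias terms in the feature-encoder convolutions are assumptions inferred from the printed shapes and counts, not read from the model source.

```python
def conv_out_len(n_in, kernel, stride):
    """Output length of an unpadded 1D convolution."""
    return (n_in - kernel) // stride + 1


# (in_channels, out_channels, kernel, stride) for each temporal convolution.
# The strides are assumed; they reproduce the lengths 29, 14, 7, 3, 1.
conv_layers = [
    (1, 8, 32, 8),
    (8, 16, 2, 2),
    (16, 32, 2, 2),
    (32, 64, 2, 2),
    (64, 64, 2, 2),
]

length = 256
lengths = []
encoder_params = 16  # GroupNorm: one scale and one shift for each of 8 channels
for c_in, c_out, kernel, stride in conv_layers:
    length = conv_out_len(length, kernel, stride)
    lengths.append(length)
    encoder_params += c_in * c_out * kernel  # weights only, no bias assumed

spatial_conv_params = 19 * 4 + 4  # 4 spatial filters over 19 channels, plus biases
final_layer_params = 4 * 64 * 3 + 3  # 256 flattened features -> 3 classes, plus biases
total = spatial_conv_params + encoder_params + final_layer_params

print(lengths)  # [29, 14, 7, 3, 1]
print(total)  # 14691
```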

Fine-tuning the Model#

Signal-JEPA is trained in a self-supervised manner on a masked prediction task. For that task, the model is configured in a many-to-many fashion, which is not suited to classification. We therefore need to adapt the architecture for fine-tuning, which is what the SignalJEPA_PreLocal, SignalJEPA_Contextual, and SignalJEPA_PostLocal classes do. Each adds new layers specifically for classification, as described in the article [1] and shown in the following figure:

Figure: Signal-JEPA Pre-Local downstream architecture

With this downstream architecture, two options are possible for fine-tuning:

  1. Fine-tune only the newly added layers

  2. Fine-tune the entire model

Freezing Pre-trained Layers#

As the second option is rather straightforward to implement, we will focus on the first option here. We will freeze all layers except the newly added ones.

# Keep the task-specific head layers (spatial_conv and final_layer)
# trainable and freeze the pretrained backbone.
new_layers = {
    name
    for name, _ in model.named_parameters()
    if name.startswith(("spatial_conv.", "final_layer."))
}

for name, param in model.named_parameters():
    if name not in new_layers:
        param.requires_grad = False

print("Trainable parameters:")
other_modules = set()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
    else:
        other_modules.add(name.split(".")[0])

print("\nOther modules:")
print(other_modules)
Trainable parameters:
spatial_conv.1.weight
spatial_conv.1.bias
final_layer.1.weight
final_layer.1.bias

Other modules:
{'feature_encoder'}
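
To put the frozen and trainable parts in proportion, the parameter counts from the model summary can be split with the same prefix test used above. This is a sketch in plain Python: the counts are copied from the summary, and grouping the whole feature encoder under a single key is a simplification for brevity, not how PyTorch names its parameters.

```python
# Parameter counts taken from the model summary printed earlier.
param_counts = {
    "spatial_conv.1.weight": 76,  # 4 filters x 19 channels
    "spatial_conv.1.bias": 4,
    "feature_encoder": 13_840,  # all encoder parameters, grouped for brevity
    "final_layer.1.weight": 768,  # 3 classes x 256 features
    "final_layer.1.bias": 3,
}

trainable_prefixes = ("spatial_conv.", "final_layer.")
trainable = sum(
    count
    for name, count in param_counts.items()
    if name.startswith(trainable_prefixes)
)
total = sum(param_counts.values())

print(trainable, total)  # 851 14691
print(f"{trainable / total:.1%}")  # 5.8%
```

Under these counts, only a small fraction of the model is updated when the feature encoder is frozen.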

Fine-tuning Procedure#

Finally, we set up the fine-tuning procedure using Braindecode’s EEGClassifier. We define the loss function, optimizer, and training parameters. We then fit the model to the windows dataset.

We only train for a few epochs for demonstration purposes.

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=["accuracy"],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.3329        1.1005       0.3353            0.3353        1.0986  0.2574
      2            0.3329        1.0992       0.3353            0.3353        1.0987  0.2261
      3            0.3343        1.0989       0.3410            0.3410        1.0985  0.2296
      4            0.3502        1.0985       0.3295            0.3295        1.0985  0.2280
      5            0.3372        1.0982       0.3295            0.3295        1.0985  0.2218
      6            0.4501        1.0972       0.3295            0.3295        1.0982  0.2218
      7            0.4153        1.0960       0.3584            0.3584        1.0979  0.2225
      8            0.4370        1.0949       0.3699            0.3699        1.0976  0.2242
      9            0.4486        1.0937       0.3468            0.3468        1.0974  0.2237
     10            0.4153        1.0921       0.3873            0.3873        1.0971  0.2233
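
A useful reference point when reading this table is the loss of a classifier that assigns equal probability to every class. With three classes, the cross-entropy of such a uniform prediction is ln(3). The short computation below is a plain-Python sketch, not part of the tutorial script, and the chance accuracy assumes balanced classes.

```python
import math

n_classes = 3
chance_accuracy = 1 / n_classes  # assumes balanced classes
# Cross-entropy of a uniform prediction: -log(1 / n_classes) = log(n_classes)
chance_loss = math.log(n_classes)

print(f"{chance_accuracy:.4f}")  # 0.3333
print(f"{chance_loss:.4f}")  # 1.0986
```

In the table above, the losses stay within about 0.01 of this value, so this short run remains close to chance level.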

All-in-one Implementation#

In the implementation above, we manually loaded the weights and froze the layers. This forces us to pass an initialized model to EEGClassifier, which may create issues if we use it in a cross-validation setting.

Instead, we can implement the same procedure in a more compact and reproducible way by using skorch’s callback system.

Here, we import a callback to freeze layers and define a custom callback to load the pre-trained weights at the beginning of training:

from skorch.callbacks import Callback, Freezer


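class WeightsLoader(Callback):
    """Load pre-trained weights into the module when training starts."""

    def __init__(self, url, strict=False):
        self.url = url  # location of the checkpoint (a state dict)
        self.strict = strict  # if False, tolerate missing or unexpected keys

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        # Download the checkpoint (or reuse the local torch hub cache) and load it
        state_dict = torch.hub.load_state_dict_from_url(url=self.url)
        net.module_.load_state_dict(state_dict, strict=self.strict)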

We can now define a classifier with those callbacks, without having to pass an initialized model, and fit it as before:

clf = EEGClassifier(
    "SignalJEPA_PreLocal",
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=[
        "accuracy",
        WeightsLoader(
            url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
        ),
        Freezer(patterns="feature_encoder.*"),
    ],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
Downloading: "https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth" to /home/runner/.cache/torch/hub/checkpoints/signal-jepa_16s-60_adeuwv4s.pth

  0%|          | 0.00/13.2M [00:00<?, ?B/s]
100%|██████████| 13.2M/13.2M [00:00<00:00, 225MB/s]
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.3329        1.0992       0.3353            0.3353        1.0986  0.2604
      2            0.3459        1.0994       0.3353            0.3353        1.0986  0.2218
      3            0.3357        1.0987       0.3295            0.3295        1.0986  0.2212
      4            0.4139        1.0982       0.3584            0.3584        1.0985  0.2210
      5            0.4269        1.0978       0.3410            0.3410        1.0983  0.2212
      6            0.4110        1.0967       0.3584            0.3584        1.0984  0.2204
      7            0.4081        1.0961       0.3006            0.3006        1.0982  0.2206
      8            0.4009        1.0942       0.3410            0.3410        1.0981  0.2203
      9            0.4530        1.0931       0.3584            0.3584        1.0979  0.2198
     10            0.4501        1.0913       0.3584            0.3584        1.0979  0.2212
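
skorch's `Freezer` selects parameters by name using Unix-style glob patterns. As a rough illustration that relies only on Python's standard `fnmatch` module (a sketch of the matching rule, not a copy of skorch's code), the pattern `"feature_encoder.*"` picks out the encoder parameters and leaves the four parameters kept trainable in the manual version untouched. The encoder parameter name below is hypothetical.

```python
from fnmatch import fnmatch

pattern = "feature_encoder.*"
param_names = [
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "feature_encoder.1.0.0.weight",  # hypothetical encoder parameter name
    "final_layer.1.weight",
    "final_layer.1.bias",
]

# "*" also matches dots, so the pattern covers every nested encoder parameter.
frozen = [name for name in param_names if fnmatch(name, pattern)]
trainable = [name for name in param_names if not fnmatch(name, pattern)]

print(frozen)  # ['feature_encoder.1.0.0.weight']
print(trainable)
```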

Conclusion and Next Steps#

In this tutorial, we demonstrated how to fine-tune a pre-trained foundation model, Signal-JEPA, for a motor imagery classification task. We now have a basic implementation that can automatically load pre-trained weights and freeze specific layers.

This setup can easily be extended to explore different fine-tuning techniques, base foundation models, and downstream tasks.

References#
