Fine-tuning a Foundation Model (Signal-JEPA)#

Foundation models are large-scale pre-trained models that serve as a starting point for a wide range of downstream tasks thanks to their strong generalization capabilities. Fine-tuning adapts them to a specific task or dataset, which is necessary to reach optimal performance in specialized applications.

In this tutorial, we demonstrate how to load a pre-trained foundation model and fine-tune it for a specific task. We use the Signal-JEPA model [1] and a MOABB motor-imagery dataset for this tutorial.

# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
#
import mne
import numpy as np
import torch

from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
from braindecode.models import SignalJEPA_PreLocal
from braindecode.preprocessing import create_windows_from_events

torch.use_deterministic_algorithms(True)
torch.manual_seed(12)
np.random.seed(12)

Loading and preparing the data#

Loading a dataset#

We start by loading a MOABB dataset; for speed, we only load a single subject. The dataset contains motor imagery EEG recordings, which we will preprocess and use for fine-tuning.

subject_id = 3  # Just one subject for speed
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

# Set the standard 10-20 montage for EEG channel locations
montage = mne.channels.make_standard_montage("standard_1020")
for ds in dataset.datasets:
    ds.raw.set_montage(montage)

Define Dataset parameters#

We extract the sampling frequency and ensure that it is consistent across all recordings. We also extract the window size from the annotations and information about the EEG channels (names, positions, etc.).

# Extract sampling frequency
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all(ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets)

# Extract and validate window size from annotations
window_size_seconds = dataset.datasets[0].raw.annotations.duration[0]
assert all(
    d == window_size_seconds
    for ds in dataset.datasets
    for d in ds.raw.annotations.duration
)

# Extract channel information
chs_info = dataset.datasets[0].raw.info["chs"]  # Channel information

print(f"{sfreq=}, {window_size_seconds=}, {len(chs_info)=}")
sfreq=250.0, window_size_seconds=4.0, len(chs_info)=26

Create Windows from Events#

We use the create_windows_from_events function from Braindecode to segment the dataset into windows based on events.

classes = ["feet", "left_hand", "right_hand"]
classes_mapping = {c: i for i, c in enumerate(classes)}

windows_dataset = create_windows_from_events(
    dataset,
    preload=True,  # Preload the data into memory for faster processing
    mapping=classes_mapping,
)
metadata = windows_dataset.get_metadata()
print(metadata.head(10))
   i_window_in_trial  i_start_in_trial  i_stop_in_trial  ...  subject  session run
0                  0               750             1750  ...        3   0train   0
1                  0              2753             3753  ...        3   0train   0
2                  0              4671             5671  ...        3   0train   0
3                  0              6623             7623  ...        3   0train   0
4                  0              8631             9631  ...        3   0train   0
5                  0             10742            11742  ...        3   0train   0
6                  0             12659            13659  ...        3   0train   0
7                  0             14709            15709  ...        3   0train   0
8                  0             16640            17640  ...        3   0train   0
9                  0             20544            21544  ...        3   0train   0

[10 rows x 7 columns]

Loading a pre-trained foundation model#

Download and Load Pre-trained Weights#

We download the pre-trained weights for the SignalJEPA model from the Hugging Face Hub. These weights will serve as the starting point for fine-tuning.

model_state_dict = torch.hub.load_state_dict_from_url(
    url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
)
# print(model_state_dict.keys())
Downloading: "https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth" to /home/runner/.cache/torch/hub/checkpoints/signal-jepa_16s-60_adeuwv4s.pth


Instantiate the Foundation Model#

We create an instance of the SignalJEPA model using the pre-local downstream architecture. The model is initialized with the dataset’s sampling frequency, window size, and channel information.

model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=window_size_seconds,
    chs_info=chs_info,
    n_outputs=len(classes),
)
print(model)
======================================================================================================================================================
Layer (type (var_name):depth-idx)                  Input Shape               Output Shape              Param #                   Kernel Shape
======================================================================================================================================================
SignalJEPA_PreLocal (SignalJEPA_PreLocal)          [1, 26, 1000]             [1, 3]                    --                        --
├─Sequential (spatial_conv): 1-1                   [1, 26, 1000]             [1, 4, 1000]              --                        --
│    └─Rearrange (0): 2-1                          [1, 26, 1000]             [1, 1, 26, 1000]          --                        --
│    └─Conv2d (1): 2-2                             [1, 1, 26, 1000]          [1, 4, 1, 1000]           108                       [26, 1]
│    └─Rearrange (2): 2-3                          [1, 4, 1, 1000]           [1, 4, 1000]              --                        --
├─_ConvFeatureEncoder (feature_encoder): 1-2       [1, 4, 1000]              [1, 28, 64]               --                        --
│    └─Rearrange (0): 2-4                          [1, 4, 1000]              [4, 1, 1000]              --                        --
│    └─Sequential (1): 2-5                         [4, 1, 1000]              [4, 8, 122]               --                        --
│    │    └─Conv1d (0): 3-1                        [4, 1, 1000]              [4, 8, 122]               256                       [32]
│    │    └─Dropout (1): 3-2                       [4, 8, 122]               [4, 8, 122]               --                        --
│    │    └─GroupNorm (2): 3-3                     [4, 8, 122]               [4, 8, 122]               16                        --
│    │    └─GELU (3): 3-4                          [4, 8, 122]               [4, 8, 122]               --                        --
│    └─Sequential (2): 2-6                         [4, 8, 122]               [4, 16, 61]               --                        --
│    │    └─Conv1d (0): 3-5                        [4, 8, 122]               [4, 16, 61]               256                       [2]
│    │    └─Dropout (1): 3-6                       [4, 16, 61]               [4, 16, 61]               --                        --
│    │    └─GELU (2): 3-7                          [4, 16, 61]               [4, 16, 61]               --                        --
│    └─Sequential (3): 2-7                         [4, 16, 61]               [4, 32, 30]               --                        --
│    │    └─Conv1d (0): 3-8                        [4, 16, 61]               [4, 32, 30]               1,024                     [2]
│    │    └─Dropout (1): 3-9                       [4, 32, 30]               [4, 32, 30]               --                        --
│    │    └─GELU (2): 3-10                         [4, 32, 30]               [4, 32, 30]               --                        --
│    └─Sequential (4): 2-8                         [4, 32, 30]               [4, 64, 15]               --                        --
│    │    └─Conv1d (0): 3-11                       [4, 32, 30]               [4, 64, 15]               4,096                     [2]
│    │    └─Dropout (1): 3-12                      [4, 64, 15]               [4, 64, 15]               --                        --
│    │    └─GELU (2): 3-13                         [4, 64, 15]               [4, 64, 15]               --                        --
│    └─Sequential (5): 2-9                         [4, 64, 15]               [4, 64, 7]                --                        --
│    │    └─Conv1d (0): 3-14                       [4, 64, 15]               [4, 64, 7]                8,192                     [2]
│    │    └─Dropout (1): 3-15                      [4, 64, 7]                [4, 64, 7]                --                        --
│    │    └─GELU (2): 3-16                         [4, 64, 7]                [4, 64, 7]                --                        --
│    └─Rearrange (6): 2-10                         [4, 64, 7]                [1, 28, 64]               --                        --
├─Sequential (final_layer): 1-3                    [1, 28, 64]               [1, 3]                    --                        --
│    └─Flatten (0): 2-11                           [1, 28, 64]               [1, 1792]                 --                        --
│    └─Linear (1): 2-12                            [1, 1792]                 [1, 3]                    5,379                     --
======================================================================================================================================================
Total params: 19,327
Trainable params: 19,327
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.90
======================================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 0.20
Params size (MB): 0.08
Estimated Total Size (MB): 0.38
======================================================================================================================================================

Load the Pre-trained Weights into the Model#

We load the pre-trained weights into the model. The transformer module is excluded, as it is not used in the pre-local downstream architecture (see [1]).

# Define layers to exclude from the pre-trained weights
new_layers = {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}

# Filter out transformer weights and load the state dictionary
model_state_dict = {
    k: v for k, v in model_state_dict.items() if not k.startswith("transformer.")
}
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)

# Ensure no unexpected keys and validate missing keys
assert unexpected_keys == [], f"{unexpected_keys=}"
assert set(missing_keys) == new_layers, f"{missing_keys=}"
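
As an optional sanity check (not part of the original example), we can verify that every tensor from the filtered checkpoint was indeed copied into the model. This assumes the checkpoint and the model share the same dtype and device (both live on CPU here).

# Optional sanity check: each checkpoint tensor should now be present,
# unchanged, in the model's state dict.
current_state_dict = model.state_dict()
for key, tensor in model_state_dict.items():
    assert torch.equal(current_state_dict[key], tensor), key
print(f"Verified {len(model_state_dict)} pre-trained tensors.")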

Fine-tuning the Model#

Signal-JEPA is a model pre-trained in a self-supervised manner on a masked-prediction task. For that task, the model operates in a many-to-many fashion, which is not suited for classification. Therefore, we need to adjust the model architecture for fine-tuning. This is what the SignalJEPA_PreLocal, SignalJEPA_Contextual, and SignalJEPA_PostLocal classes do: they add new layers specifically for classification, as described in the article [1] and in the following figure:

Signal-JEPA Pre-Local Downstream Architecture

With this downstream architecture, two options are possible for fine-tuning:

  1. Fine-tune only the newly added layers

  2. Fine-tune the entire model (sketched just below)
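
The second option essentially amounts to fitting the model without freezing anything. The following is a minimal, hypothetical sketch of it; the smaller learning rate for the pre-trained feature encoder is our own choice (set via skorch's optimizer__param_groups mechanism) and the values are arbitrary, so adapt them to your use case.

# Hypothetical sketch of option 2: fine-tune the entire model.
clf_full = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,  # learning rate for the newly added layers
    # smaller learning rate for the pre-trained feature encoder
    optimizer__param_groups=[("feature_encoder.*", {"lr": 0.0005})],
    batch_size=16,
    callbacks=["accuracy"],
    classes=range(3),
)
# clf_full.fit(windows_dataset, y=metadata["target"], epochs=10)  # not run here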

Freezing Pre-trained Layers#

As the second option is rather straightforward to implement, we focus on the first one here: we freeze all layers except the newly added ones.

for name, param in model.named_parameters():
    if name not in new_layers:
        param.requires_grad = False

print("Trainable parameters:")
other_modules = set()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
    else:
        other_modules.add(name.split(".")[0])

print("\nOther modules:")
print(other_modules)
Trainable parameters:
spatial_conv.1.weight
spatial_conv.1.bias
final_layer.1.weight
final_layer.1.bias

Other modules:
{'feature_encoder'}
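
To quantify how small the trainable part of the model is, we can also count the parameters that will be updated (a small addition to the original example):

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {n_trainable} / {n_total}")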

Fine-tuning Procedure#

Finally, we set up the fine-tuning procedure using Braindecode’s EEGClassifier. We define the loss function, optimizer, and training parameters. We then fit the model to the windows dataset.

We only train for a few epochs for demonstration purposes.

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=["accuracy"],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.3333        1.0994       0.3333            0.3333        1.0989  0.2266
      2            0.3333        1.0992       0.3333            0.3333        1.0987  0.1970
      3            0.4696        1.0985       0.3448            0.3448        1.0984  0.1941
      4            0.5217        1.0961       0.3563            0.3563        1.0979  0.1937
      5            0.5768        1.0935       0.3793            0.3793        1.0971  0.1946
      6            0.6319        1.0887       0.3448            0.3448        1.0963  0.1942
      7            0.5826        1.0827       0.3563            0.3563        1.0951  0.1918
      8            0.6348        1.0766       0.3678            0.3678        1.0939  0.1924
      9            0.6174        1.0686       0.3793            0.3793        1.0925  0.1934
     10            0.6290        1.0618       0.3793            0.3793        1.0913  0.1970

All-in-one Implementation#

In the implementation above, we manually loaded the weights and froze the layers. This forces us to pass an initialized model to EEGClassifier, which may create issues if we use it in a cross-validation setting.

Instead, we can implement the same procedure in a more compact and reproducible way, by using skorch’s callback system.

Here, we import a callback to freeze layers and define a custom callback to load the pre-trained weights at the beginning of training:

from skorch.callbacks import Callback, Freezer


class WeightsLoader(Callback):
    """Load pre-trained weights from a URL at the beginning of training."""

    def __init__(self, url, strict=False):
        self.url = url
        self.strict = strict

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        state_dict = torch.hub.load_state_dict_from_url(url=self.url)
        # strict=False ignores the unused transformer weights from the
        # checkpoint and leaves the new classification layers as initialized.
        net.module_.load_state_dict(state_dict, strict=self.strict)

We can now define a classifier with those callbacks, without having to pass an initialized model, and fit it as before:

clf = EEGClassifier(
    "SignalJEPA_PreLocal",
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=[
        "accuracy",
        WeightsLoader(
            url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
        ),
        Freezer(patterns="feature_encoder.*"),
    ],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
  epoch    train_accuracy    train_loss    valid_acc    valid_accuracy    valid_loss     dur
-------  ----------------  ------------  -----------  ----------------  ------------  ------
      1            0.3333        1.0997       0.3333            0.3333        1.0987  0.2218
      2            0.3362        1.0987       0.3333            0.3333        1.0985  0.1919
      3            0.3391        1.0980       0.3448            0.3448        1.0983  0.1929
      4            0.3855        1.0963       0.3448            0.3448        1.0978  0.1944
      5            0.4986        1.0936       0.4023            0.4023        1.0972  0.1937
      6            0.5739        1.0898       0.3563            0.3563        1.0965  0.1922
      7            0.5014        1.0852       0.3908            0.3908        1.0954  0.1926
      8            0.5652        1.0784       0.4138            0.4138        1.0942  0.1936
      9            0.5710        1.0717       0.4023            0.4023        1.0929  0.1924
     10            0.6174        1.0629       0.4138            0.4138        1.0914  0.1915
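
Because the model is now built and initialized inside EEGClassifier, the classifier can be cloned by scikit-learn utilities. As an illustration, the following minimal sketch shows how cross-validation could be run; wrapping the windows dataset with skorch's SliceDataset and the two-fold split are our own assumptions, and the call is left commented out so it is not executed as part of this example.

from sklearn.model_selection import cross_val_score
from skorch.helper import SliceDataset

# Wrap the windows dataset so that scikit-learn can index it sample-wise.
X = SliceDataset(windows_dataset, idx=0)
y = metadata["target"].to_numpy()

# scores = cross_val_score(clf, X, y, cv=2)
# print(f"Mean cross-validation accuracy: {scores.mean():.3f}")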

Conclusion and Next Steps#

In this tutorial, we demonstrated how to fine-tune a pre-trained foundation model, Signal-JEPA, for a motor imagery classification task. We now have a basic implementation that can automatically load pre-trained weights and freeze specific layers.

This setup can easily be extended to explore different fine-tuning techniques, base foundation models, and downstream tasks.

References#

[1] Guetschel, P., Moreau, T., & Tangermann, M. (2024). S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.

Total running time of the script: (0 minutes 9.508 seconds)

Estimated memory usage: 1163 MB
