Fine-tuning a Foundation Model (Signal-JEPA)
Foundation models are large-scale pre-trained models that serve as a starting point for a wide range of downstream tasks, leveraging their generalization capabilities. Fine-tuning adapts them to a specific task or dataset and is usually necessary to reach good performance in specialized applications.
In this tutorial, we demonstrate how to load a pre-trained foundation model and fine-tune it for a specific task. We use the Signal-JEPA model [1] and a MOABB motor-imagery dataset.
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
#
import mne
import numpy as np
import torch
from braindecode import EEGClassifier
from braindecode.datasets import MOABBDataset
from braindecode.models import SignalJEPA_PreLocal
from braindecode.preprocessing import create_windows_from_events
torch.use_deterministic_algorithms(True)
torch.manual_seed(12)
np.random.seed(12)
Loading and preparing the data
Loading a dataset
We start by loading a MOABB dataset, restricted to a single subject for speed. The dataset contains motor imagery EEG recordings, which we will preprocess and use for fine-tuning.
subject_id = 3 # Just one subject for speed
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
# Set the standard 10-20 montage for EEG channel locations
montage = mne.channels.make_standard_montage("standard_1020")
for ds in dataset.datasets:
    ds.raw.set_montage(montage)
Define Dataset parameters
We extract the sampling frequency and ensure that it is consistent across all recordings. We also extract the window size from the annotations and information about the EEG channels (names, positions, etc.).
# Extract sampling frequency
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Extract and validate window size from annotations
window_size_seconds = dataset.datasets[0].raw.annotations.duration[0]
assert all(
    d == window_size_seconds
    for ds in dataset.datasets
    for d in ds.raw.annotations.duration
)
# Extract channel information
chs_info = dataset.datasets[0].raw.info["chs"] # Channel information
print(f"{sfreq=}, {window_size_seconds=}, {len(chs_info)=}")
sfreq=250.0, window_size_seconds=np.float64(4.0), len(chs_info)=26
Create Windows from Events
We use the create_windows_from_events function from Braindecode to segment the dataset into windows based on events.
classes = ["feet", "left_hand", "right_hand"]
classes_mapping = {c: i for i, c in enumerate(classes)}
windows_dataset = create_windows_from_events(
    dataset,
    preload=True,  # Preload the data into memory for faster processing
    mapping=classes_mapping,
)
metadata = windows_dataset.get_metadata()
print(metadata.head(10))
i_window_in_trial i_start_in_trial i_stop_in_trial ... subject session run
0 0 750 1750 ... 3 0train 0
1 0 2753 3753 ... 3 0train 0
2 0 4671 5671 ... 3 0train 0
3 0 6623 7623 ... 3 0train 0
4 0 8631 9631 ... 3 0train 0
5 0 10742 11742 ... 3 0train 0
6 0 12659 13659 ... 3 0train 0
7 0 14709 15709 ... 3 0train 0
8 0 16640 17640 ... 3 0train 0
9 0 20544 21544 ... 3 0train 0
[10 rows x 7 columns]
Loading a pre-trained foundation model
Download and Load Pre-trained Weights
We download the pre-trained weights for the SignalJEPA model from the Hugging Face Hub. These weights will serve as the starting point for fine-tuning.
model_state_dict = torch.hub.load_state_dict_from_url(
    url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
)
# print(model_state_dict.keys())
Downloading: "https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth" to /home/runner/.cache/torch/hub/checkpoints/signal-jepa_16s-60_adeuwv4s.pth
Instantiate the Foundation Model
We create an instance of the SignalJEPA model using the pre-local downstream architecture. The model is initialized with the dataset’s sampling frequency, window size, and channel information.
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=window_size_seconds,
    chs_info=chs_info,
    n_outputs=len(classes),
)
print(model)
======================================================================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape Param # Kernel Shape
======================================================================================================================================================
SignalJEPA_PreLocal (SignalJEPA_PreLocal) [1, 26, 1000] [1, 3] -- --
├─Sequential (spatial_conv): 1-1 [1, 26, 1000] [1, 4, 1000] -- --
│ └─Rearrange (0): 2-1 [1, 26, 1000] [1, 1, 26, 1000] -- --
│ └─Conv2d (1): 2-2 [1, 1, 26, 1000] [1, 4, 1, 1000] 108 [26, 1]
│ └─Rearrange (2): 2-3 [1, 4, 1, 1000] [1, 4, 1000] -- --
├─_ConvFeatureEncoder (feature_encoder): 1-2 [1, 4, 1000] [1, 28, 64] -- --
│ └─Rearrange (0): 2-4 [1, 4, 1000] [4, 1, 1000] -- --
│ └─Sequential (1): 2-5 [4, 1, 1000] [4, 8, 122] -- --
│ │ └─Conv1d (0): 3-1 [4, 1, 1000] [4, 8, 122] 256 [32]
│ │ └─Dropout (1): 3-2 [4, 8, 122] [4, 8, 122] -- --
│ │ └─GroupNorm (2): 3-3 [4, 8, 122] [4, 8, 122] 16 --
│ │ └─GELU (3): 3-4 [4, 8, 122] [4, 8, 122] -- --
│ └─Sequential (2): 2-6 [4, 8, 122] [4, 16, 61] -- --
│ │ └─Conv1d (0): 3-5 [4, 8, 122] [4, 16, 61] 256 [2]
│ │ └─Dropout (1): 3-6 [4, 16, 61] [4, 16, 61] -- --
│ │ └─GELU (2): 3-7 [4, 16, 61] [4, 16, 61] -- --
│ └─Sequential (3): 2-7 [4, 16, 61] [4, 32, 30] -- --
│ │ └─Conv1d (0): 3-8 [4, 16, 61] [4, 32, 30] 1,024 [2]
│ │ └─Dropout (1): 3-9 [4, 32, 30] [4, 32, 30] -- --
│ │ └─GELU (2): 3-10 [4, 32, 30] [4, 32, 30] -- --
│ └─Sequential (4): 2-8 [4, 32, 30] [4, 64, 15] -- --
│ │ └─Conv1d (0): 3-11 [4, 32, 30] [4, 64, 15] 4,096 [2]
│ │ └─Dropout (1): 3-12 [4, 64, 15] [4, 64, 15] -- --
│ │ └─GELU (2): 3-13 [4, 64, 15] [4, 64, 15] -- --
│ └─Sequential (5): 2-9 [4, 64, 15] [4, 64, 7] -- --
│ │ └─Conv1d (0): 3-14 [4, 64, 15] [4, 64, 7] 8,192 [2]
│ │ └─Dropout (1): 3-15 [4, 64, 7] [4, 64, 7] -- --
│ │ └─GELU (2): 3-16 [4, 64, 7] [4, 64, 7] -- --
│ └─Rearrange (6): 2-10 [4, 64, 7] [1, 28, 64] -- --
├─Sequential (final_layer): 1-3 [1, 28, 64] [1, 3] -- --
│ └─Flatten (0): 2-11 [1, 28, 64] [1, 1792] -- --
│ └─Linear (1): 2-12 [1, 1792] [1, 3] 5,379 --
======================================================================================================================================================
Total params: 19,327
Trainable params: 19,327
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.90
======================================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 0.20
Params size (MB): 0.08
Estimated Total Size (MB): 0.38
======================================================================================================================================================
Load the Pre-trained Weights into the Model
We load the pre-trained weights into the model. The transformer module is excluded, as it is not used in the pre-local downstream architecture (see [1]).
# Define layers to exclude from the pre-trained weights
new_layers = {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
# Filter out transformer weights and load the state dictionary
model_state_dict = {
    k: v for k, v in model_state_dict.items() if not k.startswith("transformer.")
}
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
# Ensure no unexpected keys and validate missing keys
assert unexpected_keys == [], f"{unexpected_keys=}"
assert set(missing_keys) == new_layers, f"{missing_keys=}"
Fine-tuning the Model
Signal-JEPA is a model trained in a self-supervised manner on a masked prediction task. For this pre-training task, the model is configured in a many-to-many fashion, which is not suited for classification. Therefore, the architecture must be adjusted for fine-tuning. This is what the SignalJEPA_PreLocal, SignalJEPA_Contextual, and SignalJEPA_PostLocal classes do: they add new layers specifically for classification, as described in the article [1].
With this downstream architecture, two options are possible for fine-tuning:
1. Fine-tune only the newly added layers.
2. Fine-tune the entire model (see the sketch after this list).
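The second option requires no freezing: it amounts to fitting the model that already holds the pre-trained weights, typically with a reduced learning rate so that the pre-trained features are not destroyed. A minimal sketch, not executed here (the learning rate is an illustrative choice, not taken from the article):
# Sketch of option 2: all parameters remain trainable, so we simply fit
# the pre-loaded model, usually with a smaller learning rate.
full_clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=1e-4,  # illustrative value, not taken from the article
    batch_size=16,
    callbacks=["accuracy"],
    classes=range(3),
)
# full_clf.fit(windows_dataset, y=metadata["target"], epochs=10)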
Freezing Pre-trained Layers
As the second option is rather straightforward to implement, we will focus on the first option here. We will freeze all layers except the newly added ones.
for name, param in model.named_parameters():
    if name not in new_layers:
        param.requires_grad = False

print("Trainable parameters:")
other_modules = set()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
    else:
        other_modules.add(name.split(".")[0])
print("\nOther modules:")
print(other_modules)
Trainable parameters:
spatial_conv.1.weight
spatial_conv.1.bias
final_layer.1.weight
final_layer.1.bias
Other modules:
{'feature_encoder'}
Fine-tuning Procedure
Finally, we set up the fine-tuning procedure using Braindecode’s
EEGClassifier. We define the loss function, optimizer, and training
parameters. We then fit the model to the windows dataset.
We only train for a few epochs for demonstration purposes.
clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=["accuracy"],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.3333 1.0994 0.3333 0.3333 1.0989 0.2404
2 0.3333 1.0992 0.3333 0.3333 1.0987 0.2034
3 0.4696 1.0985 0.3448 0.3448 1.0984 0.2039
4 0.5217 1.0961 0.3563 0.3563 1.0979 0.1986
5 0.5768 1.0935 0.3793 0.3793 1.0971 0.2041
6 0.6319 1.0887 0.3448 0.3448 1.0963 0.2381
7 0.5826 1.0827 0.3563 0.3563 1.0951 0.2054
8 0.6348 1.0766 0.3678 0.3678 1.0939 0.2011
9 0.6174 1.0686 0.3793 0.3793 1.0925 0.1982
10 0.6290 1.0618 0.3793 0.3793 1.0913 0.2005
All-in-one Implementation
In the implementation above, we manually loaded the weights and froze the layers.
This forces us to pass an initialized model to EEGClassifier, which may
create issues if we use it in a cross-validation setting.
Instead, we can implement the same procedure in a more compact and reproducible way, by using skorch’s callback system.
Here, we import a callback to freeze layers and define a custom callback to load the pre-trained weights at the beginning of training:
from skorch.callbacks import Callback, Freezer
class WeightsLoader(Callback):
    def __init__(self, url, strict=False):
        self.url = url
        self.strict = strict

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        state_dict = torch.hub.load_state_dict_from_url(url=self.url)
        net.module_.load_state_dict(state_dict, strict=self.strict)
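Note that WeightsLoader defaults to strict=False. As we saw in the manual loading step, this is needed because the checkpoint contains transformer weights that the pre-local architecture does not use, while the newly added classification layers are absent from the checkpoint.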
We can now define a classifier with those callbacks, without having to pass an initialized model, and fit it as before:
clf = EEGClassifier(
    "SignalJEPA_PreLocal",
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=16,
    callbacks=[
        "accuracy",
        WeightsLoader(
            url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
        ),
        Freezer(patterns="feature_encoder.*"),
    ],
    classes=range(3),
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)
epoch train_accuracy train_loss valid_acc valid_accuracy valid_loss dur
------- ---------------- ------------ ----------- ---------------- ------------ ------
1 0.3333 1.0997 0.3333 0.3333 1.0987 0.2202
2 0.3362 1.0987 0.3333 0.3333 1.0985 0.2002
3 0.3391 1.0980 0.3448 0.3448 1.0983 0.2046
4 0.3855 1.0963 0.3448 0.3448 1.0978 0.2007
5 0.4986 1.0936 0.4023 0.4023 1.0972 0.2053
6 0.5739 1.0898 0.3563 0.3563 1.0965 0.2028
7 0.5014 1.0852 0.3908 0.3908 1.0954 0.1994
8 0.5652 1.0784 0.4138 0.4138 1.0942 0.1987
9 0.5710 1.0717 0.4023 0.4023 1.0929 0.2000
10 0.6174 1.0629 0.4138 0.4138 1.0914 0.2004
Conclusion and Next Steps
In this tutorial, we demonstrated how to fine-tune a pre-trained foundation model, Signal-JEPA, for a motor imagery classification task. We now have a basic implementation that can automatically load pre-trained weights and freeze specific layers.
This setup can easily be extended to explore different fine-tuning techniques, base foundation models, and downstream tasks.
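For instance, because the classifier above is now a self-contained scikit-learn estimator, it can be dropped into standard model-selection utilities. Below is a minimal sketch of cross-validation using scikit-learn's cross_val_score together with skorch's SliceDataset wrapper, assuming the model hyperparameters can be inferred from the data as in the fit call above; the 3-fold split is an illustrative choice:
from sklearn.model_selection import cross_val_score
from skorch.helper import SliceDataset

# Wrap the windows dataset so scikit-learn can index it; idx=0 selects
# the signals from the (X, y, ind) tuples returned by the dataset.
X = SliceDataset(windows_dataset, idx=0)
y = metadata["target"].to_numpy()
# On each fold, the classifier re-initializes the model, reloads the
# pre-trained weights (WeightsLoader) and re-freezes the feature encoder
# (Freezer), so every fold starts from the same pre-trained state.
scores = cross_val_score(clf, X, y, cv=3)
print(scores)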
References
[1] Guetschel, P., Moreau, T., & Tangermann, M. (2024). S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention. arXiv preprint.