Training a Braindecode model in PyTorch

This tutorial shows you how to train a Braindecode model with PyTorch. The data preparation and model instantiation steps are identical to those of the tutorial How to train, test and tune your model.

We will use the BCIC IV 2a dataset as a showcase example.

The methods shown can be applied to any standard supervised trial-based decoding setting. This tutorial also includes code for loading and preprocessing the data, defining a model, and other details that are not specific to this page (compare the Cropped Decoding Tutorial). We will therefore not elaborate on these parts further, and you can skip them.

The goal of this tutorial is to present Braindecode from a PyTorch perspective.

Why should I care about model evaluation?

Short answer: To produce reliable results!

In machine learning, we usually follow the scheme of splitting the data into two parts, training and testing sets. It sounds like a simple division, right? But the story does not end here.

While developing an ML model, you usually have to adjust and tune the hyperparameters of your model or pipeline (e.g., number of layers, learning rate, number of epochs). Deep learning models usually have many free parameters, so they can be considered complex models with many degrees of freedom. If you kept using the test set to evaluate your adjustments, you would run into data leakage.

This means that if you use the test set to adjust the hyperparameters of your model, the model implicitly learns or memorizes the test set. Therefore, the trained model is no longer independent of the test set (even though it was never used for training explicitly!). If you perform any hyperparameter tuning, you need a third split, the so-called validation set.
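The following is a minimal sketch of setting aside such a validation set, using only the Python standard library. The helper name `train_val_indices`, the 20% fraction and the trial count of 288 are illustrative assumptions. 288 is the size of one BCIC IV 2a session. With Braindecode you would typically apply the resulting indices to `train_set`, for example with `torch.utils.data.Subset`.

```python
import random


def train_val_indices(n_trials, val_fraction=0.2, seed=20200220):
    """Hypothetical helper: return disjoint (train, validation) index lists."""
    indices = list(range(n_trials))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    n_val = int(round(n_trials * val_fraction))
    return sorted(indices[n_val:]), sorted(indices[:n_val])


train_idx, val_idx = train_val_indices(288)
print(len(train_idx), len(val_idx))  # 230 58
```

You then compare hyperparameter candidates on the validation indices only. The test set is used once, for the final evaluation.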

This tutorial shows two ways to train and evaluate the model: a pure PyTorch training loop and PyTorch Lightning. Hyperparameter tuning with a validation set is covered in How to train, test and tune your model.

Warning

You might notice that the accuracy improves over the course of this tutorial. This happens because every section keeps training the same model instance instead of starting from a freshly initialized one, which keeps the tutorial short and readable. In your own experiments, always reinitialize the model before each training run.
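One way to follow this advice is to build the model through a small factory that re-seeds before creating it. Every experiment then starts from the same fresh initialization. This sketch uses a tiny stand-in `nn.Sequential` and a hypothetical helper called `make_model`. With Braindecode, you would instead create a new `ShallowFBCSPNet` inside the factory, after calling `set_random_seeds`.

```python
import torch
from torch import nn


def make_model(seed=20200220):
    """Hypothetical factory returning a freshly initialized stand-in model."""
    torch.manual_seed(seed)  # same seed -> same initial weights
    return nn.Sequential(nn.Linear(22, 16), nn.ReLU(), nn.Linear(16, 4))


model_a = make_model()
# Simulate a training run that changes the weights
with torch.no_grad():
    for p in model_a.parameters():
        p.add_(1.0)

model_b = make_model()  # next experiment: fresh, untrained model
reference = make_model()
same_init = all(
    torch.equal(p, q) for p, q in zip(model_b.parameters(), reference.parameters())
)
print(same_init)  # True
```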

Loading, preprocessing, defining a model, etc.

Loading the Dataset Structure

MOABBDataset provides a data structure that behaves like a PyTorch Dataset.
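Behaving like a PyTorch Dataset mostly means implementing `__len__` and `__getitem__`. Below is a minimal sketch of such a map-style dataset built on random data. `ToyWindowsDataset` is hypothetical. Its `(X, y, index)` triple mimics the three values that the training loops in this tutorial unpack from each item of the windows dataset. In Braindecode, the third element carries window metadata rather than a plain integer.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyWindowsDataset(Dataset):
    """Hypothetical map-style dataset returning (X, y, index) triples."""

    def __init__(self, n_trials=10, n_channels=22, n_times=1125, n_classes=4, seed=0):
        rng = np.random.default_rng(seed)
        self.X = rng.standard_normal((n_trials, n_channels, n_times)).astype(np.float32)
        self.y = rng.integers(0, n_classes, size=n_trials)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], int(self.y[idx]), idx


ds = ToyWindowsDataset()
# The default collate function stacks arrays and ints into batched tensors
X, y, _ = next(iter(DataLoader(ds, batch_size=4)))
print(X.shape, y.shape)  # torch.Size([4, 22, 1125]) torch.Size([4])
```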

from braindecode.datasets import MOABBDataset

subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

Preprocessing, the offline transformation of the raw dataset

import numpy as np

from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

transforms = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, transforms, n_jobs=-1)
/home/runner/work/braindecode/braindecode/braindecode/preprocessing/preprocess.py:55: UserWarning: Preprocessing choices with lambda functions cannot be saved.
  warn('Preprocessing choices with lambda functions cannot be saved.')

<braindecode.datasets.moabb.MOABBDataset object at 0x7f4541374df0>
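To make the last preprocessing step less of a black box, here is a simplified NumPy sketch of exponential moving standardization. Each sample is demeaned and scaled using running estimates of the mean and variance, weighted exponentially, with `factor_new` as the weight of the newest sample. The first `init_block_size` samples are standardized with that block's own statistics. This is an illustration under those assumptions, not Braindecode's `exponential_moving_standardize` itself. That function computes its running statistics with pandas and may differ in details such as the initialization.

```python
import numpy as np


def ema_standardize(data, factor_new=1e-3, init_block_size=1000, eps=1e-4):
    """Simplified exponential moving standardization of (n_channels, n_times) data."""
    data = np.asarray(data, dtype=float)
    out = np.empty_like(data)
    mean = data[:, 0].copy()  # running mean, started at the first sample
    var = np.zeros(data.shape[0])  # running variance
    for t in range(data.shape[1]):
        x = data[:, t]
        mean = factor_new * x + (1 - factor_new) * mean
        demeaned = x - mean
        var = factor_new * demeaned ** 2 + (1 - factor_new) * var
        out[:, t] = demeaned / np.maximum(np.sqrt(var), eps)
    if init_block_size:
        # The first block is standardized with its own mean and std
        block = data[:, :init_block_size]
        mu = block.mean(axis=1, keepdims=True)
        sd = np.maximum(block.std(axis=1, keepdims=True), eps)
        out[:, :init_block_size] = (block - mu) / sd
    return out


rng = np.random.default_rng(0)
signal = 5.0 + 2.0 * rng.standard_normal((2, 3000))  # 2 channels with offset and scale
standardized = ema_standardize(signal)
print(standardized.shape)  # (2, 3000)
```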

Cut Compute Windows

from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract the sampling frequency and check that it is the same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
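The window length follows from simple arithmetic. The sketch below assumes a sampling rate of 250 Hz and a 4 s motor-imagery trial. These are assumptions here, but they are consistent with the 1125-sample input shown in the model summary below. The 0.5 s offset adds 125 samples before each trial, and `trial_stop_offset_samples=0` leaves the trial end unchanged. The event onset and the random recording are made up for illustration.

```python
import numpy as np

sfreq = 250.0  # assumed sampling rate (Hz)
trial_duration_s = 4.0  # assumed trial length (s)
trial_start_offset_seconds = -0.5

trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)  # -125
trial_stop_offset_samples = 0
trial_samples = int(trial_duration_s * sfreq)  # 1000

# Slice one window out of a continuous (n_channels, n_times) recording
raw = np.random.default_rng(0).standard_normal((22, 10_000))
onset = 2_000  # hypothetical event onset (sample index)
start = onset + trial_start_offset_samples
stop = onset + trial_samples + trial_stop_offset_samples
window = raw[:, start:stop]
print(trial_start_offset_samples, window.shape)  # -125 (22, 1125)
```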

Create PyTorch model

import torch
from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))
# Extract number of chans and time steps from dataset
n_channels = windows_dataset[0][0].shape[0]
input_window_samples = windows_dataset[0][0].shape[1]

# The ShallowFBCSPNet is a `nn.Sequential` model

model = ShallowFBCSPNet(
    n_channels,
    n_classes,
    n_times=input_window_samples,
    final_conv_length="auto",
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    model.cuda()
/home/runner/work/braindecode/braindecode/braindecode/models/base.py:180: UserWarning: LogSoftmax final layer will be removed! Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
  warnings.warn("LogSoftmax final layer will be removed! " +
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 1125]             [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 22, 1125]             [1, 22, 1125, 1]          --                        --
├─Rearrange (dimshuffle): 1-2            [1, 22, 1125, 1]          [1, 1, 1125, 22]          --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 1125, 22]          [1, 40, 1101, 1]          36,240                    --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 1101, 1]          [1, 40, 1101, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 1101, 1]          [1, 40, 1101, 1]          --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 1101, 1]          [1, 40, 69, 1]            --                        [75, 1]
├─Expression (pool_nonlin_exp): 1-7      [1, 40, 69, 1]            [1, 40, 69, 1]            --                        --
├─Dropout (drop): 1-8                    [1, 40, 69, 1]            [1, 40, 69, 1]            --                        --
├─Sequential (final_layer): 1-9          [1, 40, 69, 1]            [1, 4]                    --                        --
│    └─Conv2d (conv_classifier): 2-1     [1, 40, 69, 1]            [1, 4, 1, 1]              11,044                    [69, 1]
│    └─LogSoftmax (logsoftmax): 2-2      [1, 4, 1, 1]              [1, 4, 1, 1]              --                        --
│    └─Expression (squeeze): 2-3         [1, 4, 1, 1]              [1, 4]                    --                        --
============================================================================================================================================
Total params: 47,364
Trainable params: 47,364
Non-trainable params: 0
Total mult-adds (M): 0.01
============================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 0.35
Params size (MB): 0.04
Estimated Total Size (MB): 0.50
============================================================================================================================================
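You can reproduce the shapes and parameter counts in this table by hand. The sketch below assumes the following ShallowFBCSPNet settings: a temporal kernel of 25 samples, 40 temporal and 40 spatial filters, and average pooling with length 75 and stride 15. As far as we know these are the defaults, and they reproduce the table exactly (1125 → 1101 → 69 time steps). With `final_conv_length="auto"`, the classifier kernel spans all time steps that remain after pooling.

```python
n_chans, n_times, n_classes = 22, 1125, 4
n_filters, filter_time_length = 40, 25  # assumed defaults
pool_length, pool_stride = 75, 15  # assumed defaults

# Temporal convolution ('valid' padding, stride 1)
t_conv = n_times - filter_time_length + 1  # 1101
# Average pooling over time
t_pool = (t_conv - pool_length) // pool_stride + 1  # 69
# final_conv_length="auto": the classifier kernel covers all remaining steps
final_conv_length = t_pool

# Parameter counts per layer
temporal = n_filters * filter_time_length + n_filters  # weights + bias
spatial = n_filters * n_filters * n_chans  # no bias; matches the table's count
batch_norm = 2 * n_filters  # scale and shift
classifier = n_classes * n_filters * final_conv_length + n_classes
total = temporal + spatial + batch_norm + classifier
print(t_conv, t_pool, temporal + spatial, classifier, total)  # 1101 69 36240 11044 47364
```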

How to train and evaluate your model

Split dataset into train and test

We can easily split the dataset using additional information stored in the description attribute, in this case the session column. We select session 0train for training and session 1test for testing. For other datasets, you might have to choose another column.

Note

No matter which of the two training options you use, this initial two-fold split into train_set and test_set always remains the same. Remember that you must not use the test_set during any stage of training or tuning.

splitted = windows_dataset.split("session")
train_set = splitted['0train']  # Session train
test_set = splitted['1test']  # Session evaluation
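Conceptually, splitting by a description column means grouping recordings by the value in that column. Here is a pure-Python sketch of the idea. The toy `description` rows are made up; the real `windows_dataset.description` is a pandas DataFrame. `split_by` is a hypothetical helper, not Braindecode's implementation. The six runs per session match the twelve "Used Annotations descriptions" lines printed above.

```python
from collections import defaultdict

# Hypothetical description: one row per recording (run)
description = [
    {"subject": 3, "session": "0train", "run": str(r)} for r in range(6)
] + [
    {"subject": 3, "session": "1test", "run": str(r)} for r in range(6)
]


def split_by(rows, column):
    """Map each value of `column` to the list of row indices that have it."""
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        groups[row[column]].append(i)
    return dict(groups)


groups = split_by(description, "session")
print(sorted(groups))  # ['0train', '1test']
print(groups["1test"])  # [6, 7, 8, 9, 10, 11]
```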

Option 1: Pure PyTorch training loop


model is an instance of torch.nn.Module, so it can be trained with PyTorch's optimization tools. The following training scheme is simple because the dataset is only split into two distinct sets (train_set and test_set). It uses no separate validation split. Use it only for the final evaluation of a hyperparameter configuration that you found beforehand.

Warning

If you make any use of the test_set during training (e.g. by using EarlyStopping) there will be data leakage which will make the reported generalization capability/decoding performance of your model less credible.
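If you do want early stopping, monitor a validation set carved out of the training data instead. Below is a minimal sketch of the stopping rule with a patience counter. `early_stopping_epoch` and the loss values are hypothetical and serve only to illustrate the logic.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (best_epoch, stop_epoch), both 1-indexed, for a list of validation losses."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop training here
    return best_epoch, len(val_losses)


print(early_stopping_epoch([1.4, 1.2, 1.1, 1.15, 1.13, 1.2, 0.9]))  # (3, 6)
```

After stopping, you would typically restore the weights saved at `best_epoch` before evaluating once on the test set.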

Warning

The parameter values showcased here for optimizing the network are chosen to make this tutorial fast to run and build. Real-world values would be higher, especially when it comes to n_epochs.

from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 2
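With these settings, the number of optimizer steps per epoch equals the number of training trials divided by `batch_size`, rounded up. The sketch assumes 288 trials in the training session, the standard BCIC IV 2a session size. That gives the 5 batches per epoch visible in the progress bars below. Passing `drop_last=True` to `DataLoader` would instead discard the incomplete last batch.

```python
import math

n_train_trials = 288  # assumed size of the '0train' session
batch_size = 64
n_epochs = 2

batches_per_epoch = math.ceil(n_train_trials / batch_size)
last_batch_size = n_train_trials - (batches_per_epoch - 1) * batch_size
total_steps = batches_per_epoch * n_epochs
print(batches_per_epoch, last_batch_size, total_steps)  # 5 32 10
```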

The following function runs one training epoch over the dataloader for the given model. It takes a loss function, an optimizer, and a learning rate scheduler, which it steps once at the end of the epoch.
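The scheduler passed to `train_one_epoch` in this tutorial is `CosineAnnealingLR` with `T_max=n_epochs - 1`, stepped once per epoch. The sketch below evaluates the closed form given in the PyTorch documentation, assuming the default `eta_min=0`. When only `step()` is called, as here, the scheduler should follow this closed form. With `T_max=n_epochs - 1`, the learning rate has already reached zero when the last epoch starts. Setting `T_max=n_epochs` keeps it positive in every epoch.

```python
import math


def cosine_lr(t, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate after t scheduler steps."""
    # eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))


lr = 0.0625 * 0.01
n_epochs = 2
# The learning rate used in epoch e (1-indexed) is the value after e - 1 steps
for t_max in (n_epochs - 1, n_epochs):
    lrs = [cosine_lr(epoch - 1, lr, t_max) for epoch in range(1, n_epochs + 1)]
    print(t_max, [round(v, 8) for v in lrs])
# 1 [0.000625, 0.0]
# 2 [0.000625, 0.0003125]
```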

from tqdm import tqdm
# Define a method for training one epoch


def train_one_epoch(
        dataloader: DataLoader, model: Module, loss_fn, optimizer,
        scheduler: LRScheduler, epoch: int, device, print_batch_stats=True
):
    model.train()  # Set the model to training mode
    train_loss, correct = 0, 0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader),
                        disable=not print_batch_stats)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()  # update the model weights

        train_loss += loss.item()
        correct += (pred.argmax(1) == y).sum().item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Epoch {epoch}/{n_epochs}, "
                f"Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {loss.item():.6f}"
            )

    # Update the learning rate
    scheduler.step()

    correct /= len(dataloader.dataset)
    return train_loss / len(dataloader), correct

Similarly, the evaluation function loops over the entire dataloader and accumulates the metrics, but does not update the model weights.
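Note that the loss returned by the evaluation function is the mean of per-batch mean losses. When the last batch is smaller than the others, this can differ from the mean over all trials. Accuracy is unaffected, because correct predictions are summed per trial and divided by the dataset size. The per-trial loss values below are made up for illustration.

```python
# Hypothetical per-trial losses, split into batches of 4, 4 and 2 trials
batches = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [4.0, 4.0]]

batch_means = [sum(b) / len(b) for b in batches]
mean_of_batch_means = sum(batch_means) / len(batches)  # how the loss is averaged here
# Weighting each batch mean by its size recovers the per-trial mean
n_trials = sum(len(b) for b in batches)
per_trial_mean = sum(m * len(b) for m, b in zip(batch_means, batches)) / n_trials
print(mean_of_batch_means, per_trial_mean)  # about 2.33 versus 2.0
```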

@torch.no_grad()
def test_model(
    dataloader: DataLoader, model: Module, loss_fn, print_batch_stats=True
):
    size = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()  # Switch to evaluation mode
    test_loss, correct = 0, 0

    if print_batch_stats:
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        progress_bar = enumerate(dataloader)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        batch_loss = loss_fn(pred, y).item()

        test_loss += batch_loss
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {batch_loss:.6f}"
            )

    test_loss /= n_batches
    correct /= size

    print(
        f"Test Accuracy: {100 * correct:.1f}%, Test Loss: {test_loss:.6f}\n"
    )
    return test_loss, correct


# Define the optimization
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=n_epochs - 1)
# Define the loss function
# We use the NLLLoss function, which expects log-probabilities as input
# (this is the case for our model, whose last layer is a LogSoftmax)
loss_fn = torch.nn.NLLLoss()

# train_set and test_set are instances of torch Datasets, and can seamlessly be
# wrapped in data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_accuracy = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device,
    )

    test_loss, test_accuracy = test_model(test_loader, model, loss_fn)

    print(
        f"Train Accuracy: {100 * train_accuracy:.2f}%, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Test Accuracy: {100 * test_accuracy:.1f}%, "
        f"Average Test Loss: {test_loss:.6f}\n"
    )
Epoch 1/2:
Epoch 1/2, Batch 5/5, Loss: 1.648800: 100%|██████████| 5/5 [00:01<00:00,  3.69it/s]

Batch 5/5, Loss: 4.520595: 100%|██████████| 5/5 [00:00<00:00, 16.31it/s]
Test Accuracy: 25.3%, Test Loss: 4.324450

Train Accuracy: 30.90%, Average Train Loss: 1.605750, Test Accuracy: 25.3%, Average Test Loss: 4.324450

Epoch 2/2:
Epoch 2/2, Batch 5/5, Loss: 1.144602: 100%|██████████| 5/5 [00:01<00:00,  3.77it/s]

Batch 5/5, Loss: 3.836547: 100%|██████████| 5/5 [00:00<00:00, 15.58it/s]
Test Accuracy: 24.7%, Test Loss: 3.646811

Train Accuracy: 44.10%, Average Train Loss: 1.262205, Test Accuracy: 24.7%, Average Test Loss: 3.646811
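Both options pair `NLLLoss` with the model's final `LogSoftmax` layer. The warning printed when the model was created announces that this layer will be removed. `NLLLoss` applied to log-probabilities gives the same value as `CrossEntropyLoss` applied to the raw scores. Once the model outputs logits, switching the loss function is therefore enough. A quick check on random tensors:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 4)  # raw scores: 8 trials, 4 classes
targets = torch.randint(0, 4, (8,))  # class labels

nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), targets)
ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(nll, ce))  # True
```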

Option 2: Train it with PyTorch Lightning


Alternatively, PyTorch Lightning provides a convenient interface around torch modules that takes care of the training and evaluation loops written by hand above.

import lightning as L
from torchmetrics.functional import accuracy


class LitModule(L.LightningModule):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.loss = torch.nn.NLLLoss()

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.module(x)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.module(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y, "multiclass", num_classes=4)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self):
        # Optimize the parameters of the wrapped module
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr,
                                      weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=n_epochs - 1)
        return [optimizer], [scheduler]


# Creating the trainer with max_epochs=2 for demonstration purposes
trainer = L.Trainer(max_epochs=n_epochs)
# Create and train the LightningModule
lit_model = LitModule(model)
trainer.fit(lit_model, train_loader)

# After training, you can test the model using the test DataLoader
trainer.test(dataloaders=test_loader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/runner/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Missing logger folder: /home/runner/work/braindecode/braindecode/examples/model_building/lightning_logs

  | Name   | Type            | Params
-------------------------------------------
0 | module | ShallowFBCSPNet | 47.4 K
1 | loss   | NLLLoss         | 0
-------------------------------------------
47.4 K    Trainable params
0         Non-trainable params
47.4 K    Total params
0.189     Total estimated model params size (MB)
/home/runner/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

Epoch 1: 100%|██████████| 5/5 [00:01<00:00,  3.93it/s, v_num=0]`Trainer.fit` stopped: `max_epochs=2` reached.

/home/runner/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:145: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.
Restoring states from the checkpoint path at /home/runner/work/braindecode/braindecode/examples/model_building/lightning_logs/version_0/checkpoints/epoch=1-step=10.ckpt
Loaded model weights from the checkpoint at /home/runner/work/braindecode/braindecode/examples/model_building/lightning_logs/version_0/checkpoints/epoch=1-step=10.ckpt
/home/runner/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

Testing DataLoader 0: 100%|██████████| 5/5 [00:00<00:00, 13.93it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                  0.34375
        test_loss           1.5098872184753418
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

[{'test_acc': 0.34375, 'test_loss': 1.5098872184753418}]
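The `test_acc` value is a classification accuracy: the fraction of trials whose highest-scoring class matches the label. torchmetrics' multiclass accuracy uses micro averaging by default, which should reduce to this fraction. You can compute the same quantity in plain PyTorch, as the pure PyTorch loop in Option 1 does. The scores and labels here are made up.

```python
import torch

log_probs = torch.log_softmax(torch.tensor([[2.0, 0.1, 0.3, 0.2],
                                            [0.1, 1.5, 0.2, 0.3],
                                            [0.3, 0.2, 0.1, 2.2],
                                            [1.0, 0.4, 0.9, 0.2]]), dim=1)
labels = torch.tensor([0, 1, 2, 3])

predicted = log_probs.argmax(dim=1)  # most probable class per trial
acc = (predicted == labels).float().mean().item()
print(predicted.tolist(), acc)  # [0, 1, 3, 0] 0.5
```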

Total running time of the script: (0 minutes 11.348 seconds)

Estimated memory usage: 289 MB
