Regression example on fake data

This example shows how to train braindecode's EEGRegressor with cropped
decoding to regress a continuous target (e.g. age) from dummy EEG data.

Out:

Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
(the three lines above repeat once per fake recording, five recordings in total)
1 matching events found
No baseline correction applied
Adding metadata with 4 columns
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
(the six lines above repeat once per fake recording, five recordings in total)
Loading data for 1 events and 6000 original time points ...
(this line repeats whenever a window is re-loaded during training; further duplicates omitted below)
/home/circleci/.local/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:143: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  epoch    train_loss    train_neg_root_mean_squared_error    valid_loss    valid_neg_root_mean_squared_error     dur
-------  ------------  -----------------------------------  ------------  -----------------------------------  ------
      1     1088.8221                             -28.1520       41.5739                              -6.4478  0.8293
      2     1086.1565                             -25.7864        8.2148                              -2.8661  0.5019
      3     1083.3179                             -24.8156        1.5841                              -1.2586  0.5435

<class 'braindecode.regressor.EEGRegressor'>[initialized](
  module_=Sequential(
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat)
    (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(40, 40, kernel_size=(1, 21), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square)
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log)
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(40, 1, kernel_size=(35, 1), stride=(1, 1), dilation=(15, 1))
    (squeeze): Expression(expression=squeeze_final_output)
  ),
)

# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD-3
import numpy as np
import pandas as pd
import torch
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGRegressor
from braindecode.datautil import create_fixed_length_windows
from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.training.losses import CroppedLoss
from braindecode.models import Deep4Net
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.util import set_random_seeds, create_mne_dummy_raw

model_name = "shallow"  # 'shallow' or 'deep'
n_epochs = 3
seed = 20200220

input_window_samples = 6000
batch_size = 64
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # let cuDNN pick the fastest convolution algorithms for the fixed input size
    torch.backends.cudnn.benchmark = True

n_chans = 21
# set to how many targets you want to regress (age -> 1, [x, y, z] -> 3)
n_classes = 1

# fix the random seeds for reproducibility
set_random_seeds(seed=seed, cuda=cuda)

# initialize a model, transform to dense and move to gpu
if model_name == "shallow":
    model = ShallowFBCSPNet(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=40,
        n_filters_spat=40,
        final_conv_length=35,
    )
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
elif model_name == "deep":
    model = Deep4Net(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=25,
        n_filters_spat=25,
        stride_before_pool=True,
        n_filters_2=int(n_chans * 2),
        n_filters_3=int(n_chans * (2 ** 2.0)),
        n_filters_4=int(n_chans * (2 ** 3.0)),
        final_conv_length=1,
    )
    optimizer_lr = 0.01
    optimizer_weight_decay = 0.0005
else:
    raise ValueError(f'{model_name} unknown')

# strip the final softmax layer: we regress a continuous target, so the
# network should output raw values rather than (log-)probabilities
new_model = torch.nn.Sequential()
for name, module_ in model.named_children():
    if "softmax" in name:
        continue
    new_model.add_module(name, module_)
model = new_model

if cuda:
    model.cuda()

# convert to a model that outputs a dense series of predictions over time
# (cropped decoding) and read off how many predictions one window yields
to_dense_prediction_model(model)
n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
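
# quick illustrative sanity check (follows directly from `get_output_shape`
# above): after the dense transform, one input window yields
# `n_preds_per_input` predictions along the time axis rather than a single
# prediction per window
model.eval()
with torch.no_grad():
    _dummy = torch.zeros(1, n_chans, input_window_samples)
    if cuda:
        _dummy = _dummy.cuda()
    assert model(_dummy).shape[2] == n_preds_per_input
model.train()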

def fake_regression_dataset(n_fake_recs, n_fake_chs, fake_sfreq, fake_duration_s):
    datasets = []
    for i in range(n_fake_recs):
        train_or_eval = "eval" if i == 0 else "train"
        raw, save_fname = create_mne_dummy_raw(
            n_channels=n_fake_chs, n_times=fake_duration_s * fake_sfreq,
            sfreq=fake_sfreq, savedir=None)
        target = np.random.randint(0, 100, n_classes)
        if n_classes == 1:
            target = target[0]
        fake_description = pd.Series(
            data=[target, train_or_eval],
            index=["target", "session"])
        base_ds = BaseDataset(raw, fake_description, target_name="target")
        datasets.append(base_ds)
    dataset = BaseConcatDataset(datasets)
    return dataset

dataset = fake_regression_dataset(
    n_fake_recs=5, n_fake_chs=21, fake_sfreq=100, fake_duration_s=60)
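# `dataset.description` now holds one row per fake recording with the
# columns "target" and "session" set in fake_regression_dataset above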

windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    stop_offset_samples=0,
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    drop_bad_windows=True,
)
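# note: `window_stride_samples=n_preds_per_input` makes the cropped
# predictions of consecutive windows tile a recording without gaps or
# double-counting; here each 60 s recording at 100 Hz is exactly one
# 6000-sample window anyway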

# split windows into train and eval based on the "session" description
splits = windows_dataset.split("session")
train_set = splits["train"]
valid_set = splits["eval"]

# EEGRegressor wraps the model in a skorch regressor; with cropped=True and
# CroppedLoss, the dense per-crop outputs are averaged over time before the
# MSE loss is computed against each window's target
regressor = EEGRegressor(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),
    optimizer__lr=optimizer_lr,
    optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=[
        "neg_root_mean_squared_error",
        # stepped once per epoch, T_max=n_epochs - 1 lets the cosine
        # schedule anneal the learning rate to 0 by the final epoch
        ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)

regressor.fit(train_set, y=None, epochs=n_epochs)
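
# A possible next step (a minimal sketch using the standard skorch `predict`
# API, which EEGRegressor inherits): obtain predictions on the held-out
# evaluation windows, e.g.
# preds = regressor.predict(valid_set)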

Total running time of the script: ( 0 minutes 2.802 seconds)

Estimated memory usage: 272 MB
