Regression example on fake data

Out:

Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
/home/circleci/project/braindecode/datautil/windowers.py:245: UserWarning: Meaning of `trial_stop_offset_samples`=0 has changed, use `None` to indicate end of trial/recording. Using `None`.
  'Meaning of `trial_stop_offset_samples`=0 has changed, use `None` '
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
  epoch    train_loss    train_neg_root_mean_squared_error    valid_loss    valid_neg_root_mean_squared_error      lr      dur
-------  ------------  -----------------------------------  ------------  -----------------------------------  ------  -------
      1     1088.9221                             -28.1152       40.8822                              -6.3939  0.0006  21.4957
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
      2     1086.1545                             -26.9259       21.5511                              -4.6423  0.0003  20.0886
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
      3     1084.8397                             -26.9904       22.5035                              -4.7438  0.0000  17.7809

<class 'braindecode.regressor.EEGRegressor'>[initialized](
  module_=Sequential(
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat)
    (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(40, 40, kernel_size=(1, 21), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square)
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log)
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(40, 1, kernel_size=(35, 1), stride=(1, 1), dilation=(15, 1))
    (squeeze): Expression(expression=squeeze_final_output)
  ),
)

# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD-3
import numpy as np
import pandas as pd
import torch
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGRegressor
from braindecode.datautil import create_fixed_length_windows
from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.training.losses import CroppedLoss
from braindecode.models import Deep4Net
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.util import set_random_seeds, create_mne_dummy_raw

# --- Experiment configuration ----------------------------------------------
model_name = "shallow"  # 'shallow' or 'deep'
n_epochs = 3
seed = 20200220
batch_size = 64
# Number of time samples in each window fed to the network.
input_window_samples = 6000

n_chans = 21
# Number of regression targets per recording (age -> 1, [x, y, z] -> 3).
n_classes = 1

# --- Device selection --------------------------------------------------------
cuda = torch.cuda.is_available()
if cuda:
    device = 'cuda'
    # cuDNN autotuning speeds up convolutions, but per PyTorch's
    # reproducibility notes it can make GPU results non-deterministic even
    # with the fixed seed below.
    torch.backends.cudnn.benchmark = True
else:
    device = 'cpu'

set_random_seeds(seed=seed, cuda=cuda)

# Build the requested architecture and pick matching optimizer
# hyperparameters. Only 'shallow' and 'deep' are supported.
if model_name not in ("shallow", "deep"):
    raise ValueError(f'{model_name} unknown')

if model_name == "deep":
    model = Deep4Net(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=25,
        n_filters_spat=25,
        stride_before_pool=True,
        n_filters_2=int(n_chans * 2),
        n_filters_3=int(n_chans * (2 ** 2.0)),
        n_filters_4=int(n_chans * (2 ** 3.0)),
        final_conv_length=1,
    )
    optimizer_lr, optimizer_weight_decay = 0.01, 0.0005
else:
    model = ShallowFBCSPNet(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=40,
        n_filters_spat=40,
        final_conv_length=35,
    )
    optimizer_lr, optimizer_weight_decay = 0.000625, 0

# Rebuild the network without any child whose name contains "softmax", so
# the final layer's raw values are used as regression outputs. Names and
# order of the remaining children are preserved.
kept_children = [
    (child_name, child)
    for child_name, child in model.named_children()
    if "softmax" not in child_name
]
model = torch.nn.Sequential()
for child_name, child in kept_children:
    model.add_module(child_name, child)

if cuda:
    model.cuda()

# Convert to a dense-prediction model; the return value is not used, so the
# conversion modifies ``model`` in place.
to_dense_prediction_model(model)
# NOTE(review): axis 2 of the output shape is taken as the number of
# predictions per input window, i.e. output assumed (batch, n_outputs,
# n_preds) -- confirm against get_output_shape.
n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]

def fake_regression_dataset(n_fake_recs, n_fake_chs, fake_sfreq,
                            fake_duration_s, n_targets=None):
    """Create a concatenated dataset of random recordings with random targets.

    Each recording is a dummy raw from ``create_mne_dummy_raw`` paired with
    integer target(s) drawn with ``np.random.randint(0, 100, ...)``. The first
    recording gets ``session="eval"`` and all others ``session="train"``, so
    the result can be split with ``.split("session")``.

    Parameters
    ----------
    n_fake_recs : int
        Number of recordings to create.
    n_fake_chs : int
        Number of channels per recording.
    fake_sfreq : int | float
        Sampling frequency passed to ``create_mne_dummy_raw``.
    fake_duration_s : int | float
        Duration of each recording in seconds. The number of samples is
        ``round(fake_duration_s * fake_sfreq)``, cast to ``int`` so that
        float inputs still yield an integer sample count.
    n_targets : int | None
        Number of targets drawn per recording. Defaults to the module-level
        ``n_classes``, which was the previous implicit behavior. With a single
        target the target is stored as a scalar, otherwise as an array.

    Returns
    -------
    BaseConcatDataset
        One ``BaseDataset`` per recording, with ``target_name="target"``.
    """
    if n_targets is None:
        # Backward-compatible fallback to the script-level setting.
        n_targets = n_classes
    n_times = int(round(fake_duration_s * fake_sfreq))

    datasets = []
    for i_rec in range(n_fake_recs):
        session = "eval" if i_rec == 0 else "train"
        # The raw is created before the target is drawn, preserving the
        # original order of draws from the global NumPy RNG.
        raw, _ = create_mne_dummy_raw(
            n_channels=n_fake_chs, n_times=n_times,
            sfreq=fake_sfreq, savedir=None)
        target = np.random.randint(0, 100, n_targets)
        if n_targets == 1:
            target = target[0]
        description = pd.Series(
            data=[target, session],
            index=["target", "session"])
        datasets.append(BaseDataset(raw, description, target_name="target"))
    return BaseConcatDataset(datasets)

# Five fake recordings whose channel count follows ``n_chans``. It was
# previously hard-coded as 21, which would silently mismatch the model if
# ``n_chans`` changed. 60 s at 100 Hz gives 6000 samples, the same as
# ``input_window_samples``.
dataset = fake_regression_dataset(
    n_fake_recs=5, n_fake_chs=n_chans, fake_sfreq=100, fake_duration_s=60)

# Cut each recording into windows of ``input_window_samples`` samples,
# advancing by ``n_preds_per_input`` (the dense model's predictions per window).
windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    # ``None`` means "until the end of the recording". Passing 0 emits a
    # UserWarning in braindecode and is reinterpreted as ``None`` anyway.
    stop_offset_samples=None,
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    drop_bad_windows=True,
)

# The recordings carry a "session" description of "train" or "eval";
# splitting on it yields the training and validation subsets.
splits = windows_dataset.split("session")
train_set, valid_set = splits["train"], splits["eval"]

# Cosine-annealed learning rate over the training epochs.
# NOTE(review): with ``T_max=n_epochs - 1`` the logged lr reaches 0.0 in the
# final epoch of the run shown above, so that epoch may not update the
# weights. Confirm against the installed skorch ``LRScheduler`` stepping
# semantics before relying on it.
lr_scheduler = LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)
callbacks = [
    "neg_root_mean_squared_error",
    ("lr_scheduler", lr_scheduler),
]

regressor = EEGRegressor(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW,
    optimizer__lr=optimizer_lr,
    optimizer__weight_decay=optimizer_weight_decay,
    train_split=predefined_split(valid_set),
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
)

# ``y=None``: presumably the windows dataset supplies the targets itself.
regressor.fit(train_set, y=None, epochs=n_epochs)

Total running time of the script: (1 minute 20.315 seconds)

Estimated memory usage: 111 MB

Gallery generated by Sphinx-Gallery