Regression example on fake data

Out:

Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
Creating RawArray with float64 data, n_channels=21, n_times=6000
    Range : 0 ... 5999 =      0.000 ...    59.990 secs
Ready.
/home/circleci/project/braindecode/datautil/windowers.py:244: UserWarning: Meaning of `trial_stop_offset_samples`=0 has changed, use `None` to indicate end of trial/recording. Using `None`.
  'Meaning of `trial_stop_offset_samples`=0 has changed, use `None` '
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
Replacing existing metadata with 4 columns
1 matching events found
No baseline correction applied
0 projection items activated
Loading data for 1 events and 6000 original time points ...
0 bad epochs dropped
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
/home/circleci/.local/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:143: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  epoch    train_loss    train_neg_root_mean_squared_error    valid_loss    valid_neg_root_mean_squared_error      lr     dur
-------  ------------  -----------------------------------  ------------  -----------------------------------  ------  ------
      1     1088.8221                             -28.1714       41.9291                              -6.4753  0.0006  1.7092
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
/home/circleci/.local/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:143: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
      2     1086.0312                             -27.0026       22.6491                              -4.7591  0.0003  1.7439
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
Loading data for 1 events and 6000 original time points ...
/home/circleci/.local/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:143: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
      3     1084.8628                             -27.0687       23.6458                              -4.8627  0.0000  0.5531

<class 'braindecode.regressor.EEGRegressor'>[initialized](
  module_=Sequential(
    (ensuredims): Ensure4d()
    (dimshuffle): Expression(expression=transpose_time_to_spat)
    (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(40, 40, kernel_size=(1, 21), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin_exp): Expression(expression=square)
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
    (pool_nonlin_exp): Expression(expression=safe_log)
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(40, 1, kernel_size=(35, 1), stride=(1, 1), dilation=(15, 1))
    (squeeze): Expression(expression=squeeze_final_output)
  ),
)

# Authors: Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD-3
import numpy as np
import pandas as pd
import torch
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGRegressor
from braindecode.datautil import create_fixed_length_windows
from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.training.losses import CroppedLoss
from braindecode.models import Deep4Net
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.util import set_random_seeds, create_mne_dummy_raw

model_name = "shallow"  # 'shallow' or 'deep'
n_epochs = 3
seed = 20200220

input_window_samples = 6000
batch_size = 64
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True

n_chans = 21
# number of regression targets (e.g. age -> 1, [x, y, z] -> 3)
n_classes = 1

set_random_seeds(seed=seed, cuda=cuda)

# initialize the model; it is then stripped of its softmax, moved to the GPU
# (if available) and converted to a dense prediction model
if model_name == "shallow":
    model = ShallowFBCSPNet(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=40,
        n_filters_spat=40,
        final_conv_length=35,
    )
    optimizer_lr = 0.000625
    optimizer_weight_decay = 0
elif model_name == "deep":
    model = Deep4Net(
        in_chans=n_chans,
        n_classes=n_classes,
        input_window_samples=input_window_samples,
        n_filters_time=25,
        n_filters_spat=25,
        stride_before_pool=True,
        n_filters_2=int(n_chans * 2),
        n_filters_3=int(n_chans * (2 ** 2.0)),
        n_filters_4=int(n_chans * (2 ** 3.0)),
        final_conv_length=1,
    )
    optimizer_lr = 0.01
    optimizer_weight_decay = 0.0005
else:
    raise ValueError(f'{model_name} unknown')

# drop the final softmax layer: regression needs unbounded outputs
new_model = torch.nn.Sequential()
for name, module_ in model.named_children():
    if "softmax" in name:
        continue
    new_model.add_module(name, module_)
model = new_model

if cuda:
    model.cuda()

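# convert to a dense prediction model that outputs one prediction per crop
# of each input window (cropped decoding)
to_dense_prediction_model(model)
# number of predictions the model produces per input window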
n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
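As a hand check on `n_preds_per_input`, the time length can be traced through the dense shallow model printed near the top of this page. This sketch assumes only `conv_time`, `pool` and `conv_classifier` change the time length, with kernel sizes, strides and dilation copied from that printout; it does not call braindecode.

```python
def conv_out_len(n_in, kernel, stride=1, dilation=1, padding=0):
    """Output length along time of a conv/pool layer (PyTorch's formula)."""
    return (n_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


input_window_samples = 6000
after_conv_time = conv_out_len(input_window_samples, kernel=25)  # 5976
after_pool = conv_out_len(after_conv_time, kernel=75, stride=1)  # 5902
n_preds = conv_out_len(after_pool, kernel=35, dilation=15)  # 5392
# number of input samples each prediction depends on
receptive_field = input_window_samples - n_preds + 1  # 609
print(n_preds, receptive_field)  # 5392 609
```

Under these assumptions each 6000-sample window yields 5392 predictions, each depending on 609 input samples.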


def fake_regression_dataset(n_fake_recs, n_fake_chs, fake_sfreq, fake_duration_s):
    """Create random recordings with random integer targets in [0, 100).

    Recording 0 is marked for evaluation, all others for training.
    """
    datasets = []
    for i in range(n_fake_recs):
        train_or_eval = "eval" if i == 0 else "train"
        raw, save_fname = create_mne_dummy_raw(
            n_channels=n_fake_chs, n_times=fake_duration_s * fake_sfreq,
            sfreq=fake_sfreq, savedir=None)
        target = np.random.randint(0, 100, n_classes)
        if n_classes == 1:
            target = target[0]
        fake_description = pd.Series(
            data=[target, train_or_eval],
            index=["target", "session"])
        base_ds = BaseDataset(raw, fake_description, target_name="target")
        datasets.append(base_ds)
    dataset = BaseConcatDataset(datasets)
    return dataset

dataset = fake_regression_dataset(
    n_fake_recs=5, n_fake_chs=21, fake_sfreq=100, fake_duration_s=60)

windows_dataset = create_fixed_length_windows(
    dataset,
    start_offset_samples=0,
    stop_offset_samples=None,  # None: windows extend to the end of each recording
    window_size_samples=input_window_samples,
    window_stride_samples=n_preds_per_input,
    drop_last_window=False,
    drop_bad_windows=True,
)
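With `window_stride_samples=n_preds_per_input`, consecutive windows overlap just enough for their dense predictions to tile a recording without gaps. The helper below is a hypothetical sketch of how fixed-length window start indices can be computed, not braindecode's internal code. It assumes that with `drop_last_window=False` one extra window, flush with the end of the recording, is appended when the regular stride does not reach it, and it takes `n_preds_per_input` to be 5392 for the shallow model.

```python
def window_starts(n_times, window_size, stride, drop_last_window=False):
    """Hypothetical helper: start indices of fixed-length windows."""
    if n_times < window_size:
        return []  # recording too short for a single window
    starts = list(range(0, n_times - window_size + 1, stride))
    last_possible = n_times - window_size
    # keep the tail of the recording by adding one end-aligned window
    if not drop_last_window and starts[-1] != last_possible:
        starts.append(last_possible)
    return starts


# 60 s at 100 Hz = 6000 samples: exactly one 6000-sample window
print(window_starts(6000, 6000, 5392))  # [0]
# a hypothetical 90 s recording would get an extra end-aligned window
print(window_starts(9000, 6000, 5392))  # [0, 3000]
```

The first case is consistent with the log above, where each of the five recordings reports "1 matching events found".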

splits = windows_dataset.split("session")
train_set = splits["train"]
valid_set = splits["eval"]
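`fake_regression_dataset` labels the first recording `"eval"` and the others `"train"`, and `split("session")` groups recordings by that description column. A plain-Python sketch of that grouping, assuming one description per recording; the dicts below are hypothetical stand-ins for the `pd.Series` descriptions.

```python
from collections import defaultdict


def split_by(descriptions, column):
    """Group recording indices by the value of one description column."""
    groups = defaultdict(list)
    for i, desc in enumerate(descriptions):
        groups[desc[column]].append(i)
    return dict(groups)


# mirror the fake dataset: recording 0 is "eval", recordings 1-4 are "train"
descriptions = [{"session": "eval" if i == 0 else "train"} for i in range(5)]
print(split_by(descriptions, "session"))  # {'eval': [0], 'train': [1, 2, 3, 4]}
```

In this setup the model is therefore trained on four recordings and validated on one.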

regressor = EEGRegressor(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.mse_loss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),
    optimizer__lr=optimizer_lr,
    optimizer__weight_decay=optimizer_weight_decay,
    iterator_train__shuffle=True,
    batch_size=batch_size,
    callbacks=[
        "neg_root_mean_squared_error",
        # with T_max=n_epochs - 1 the cosine schedule reaches lr=0 in the final epoch
        ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
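PyTorch's `CosineAnnealingLR` follows the closed form lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2, with eta_min = 0 by default. The sketch below evaluates it for the shallow model's `optimizer_lr`, assuming skorch's `LRScheduler` steps once per epoch so that training epoch k uses t = k - 1.

```python
import math


def cosine_annealing_lr(base_lr, t, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate at step t."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


optimizer_lr = 0.000625  # shallow-model learning rate from the script
n_epochs = 3
lrs = [cosine_annealing_lr(optimizer_lr, t, n_epochs - 1) for t in range(n_epochs)]
print([f"{lr:.4f}" for lr in lrs])  # ['0.0006', '0.0003', '0.0000']
```

This reproduces the lr column of the training log (0.0006, 0.0003, 0.0000). With T_max = n_epochs instead, the last epoch would still train at a nonzero rate (about 0.000156).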

regressor.fit(train_set, y=None, epochs=n_epochs)
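With `cropped=True` and `criterion=CroppedLoss`, the model emits `n_preds_per_input` predictions per window. Below is a minimal NumPy sketch of the idea, assuming the crop predictions are averaged over the time axis before the wrapped loss (here MSE) is applied, together with the negated RMSE reported in the `neg_root_mean_squared_error` columns (scikit-learn's convention of negating errors so that higher is better). It covers the single-target case only, and the arrays are made-up toy values, not outputs of the script.

```python
import numpy as np


def cropped_mse(preds, targets):
    """preds: (batch, 1, n_preds_per_input); targets: (batch,). Single target only."""
    avg_preds = preds.mean(axis=2).squeeze(axis=1)  # average over time crops
    return np.mean((avg_preds - targets) ** 2)


def neg_rmse(y_true, y_pred):
    """Negated root mean squared error (higher is better)."""
    return -np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))


# toy batch: 2 windows, 1 target, 4 crop predictions each
preds = np.array([[[10.0, 12.0, 14.0, 16.0]],
                  [[50.0, 50.0, 54.0, 46.0]]])
targets = np.array([16.0, 47.0])
print(cropped_mse(preds, targets))  # 9.0
print(neg_rmse([10.0, 20.0], [13.0, 17.0]))  # -3.0
```

Averaging first means the loss penalizes the window-level estimate rather than each crop separately.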

Total running time of the script: (0 minutes 5.678 seconds)

Estimated memory usage: 106 MB
