Source code for braindecode.datasets.xy
# Authors: Lukas Gemein <l.gemein@gmail.com>
# Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
import pandas as pd
import logging
import mne
from .base import BaseDataset, BaseConcatDataset
log = logging.getLogger(__name__)
def create_from_X_y(
        X, y, drop_last_window, sfreq, ch_names=None,
        window_size_samples=None, window_stride_samples=None):
    """Create a BaseConcatDataset of WindowsDatasets from X and y.

    The result is meant to be used for decoding with skorch and braindecode,
    where X is a list of pre-cut trials and y are the corresponding targets.
    Each trial is wrapped in an ``mne.io.RawArray`` and a ``BaseDataset``
    before all trials are cut into fixed-length windows.

    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    drop_last_window: bool
        whether or not have a last overlapping window, when the
        windows do not equally divide the continuous signal
    sfreq: float
        Sampling frequency of signals.
    ch_names: array-like
        Names of the channels. If None, channels are named by their index
        ("0", "1", ...), with the channel count taken from X.
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode

    Raises
    ------
    ValueError
        If both ``window_size_samples`` and ``window_stride_samples`` are
        None and the trials do not all have the same number of samples.
    """
    # Prevent circular import
    from ..preprocessing.windowers import (
        create_fixed_length_windows, )

    if ch_names is None:
        # X is documented as array-like, so it may be a plain list (or a
        # ragged object array) without a usable second axis. Use X.shape[1]
        # when available, otherwise fall back to the first trial.
        X_shape = getattr(X, "shape", None)
        if X_shape is not None and len(X_shape) > 1:
            n_channels = X_shape[1]
        else:
            n_channels = np.shape(X[0])[0]
        ch_names = [str(i) for i in range(n_channels)]
        log.info(f"No channel names given, set to 0-{n_channels - 1}.")

    n_samples_per_x = []
    base_datasets = []
    for x, target in zip(X, y):
        # np.shape also accepts nested lists, not only ndarrays.
        n_samples_per_x.append(np.shape(x)[1])
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(raw, pd.Series({"target": target}),
                                   target_name="target")
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        # Without explicit windowing parameters every trial becomes exactly
        # one window, which is only well defined for equal-length trials.
        if not len(np.unique(n_samples_per_x)) == 1:
            raise ValueError("if 'window_size_samples' and "
                             "'window_stride_samples' are None, "
                             "all trials have to have the same length")
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]

    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=None,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window
    )
    return windows_datasets