"""Preprocessors that work on Raw or Epochs objects.
"""
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
# Lukas Gemein <l.gemein@gmail.com>
# Simon Brandt <simonbrandt@protonmail.com>
# David Sabbagh <dav.sabbagh@gmail.com>
# Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)
from warnings import warn
from functools import partial
from collections.abc import Iterable
import numpy as np
import pandas as pd
from mne import create_info
from sklearn.utils import deprecated
from joblib import Parallel, delayed
from braindecode.datasets.base import BaseConcatDataset, BaseDataset, WindowsDataset
from braindecode.datautil.serialization import (
load_concat_dataset, _check_save_dir_empty)


class Preprocessor(object):
"""Preprocessor for an MNE Raw or Epochs object.
Applies the provided preprocessing function to the data of a Raw or Epochs
object.
If the function is provided as a string, the method with that name will be
used (e.g., 'pick_channels', 'filter', etc.).
If it is provided as a callable and `apply_on_array` is True, the
`apply_function` method of Raw and Epochs object will be used to apply the
function on the internal arrays of Raw and Epochs.
If `apply_on_array` is False, the callable must directly modify the Raw or
Epochs object (e.g., by calling its method(s) or modifying its attributes).
Parameters
----------
fn: str or callable
If str, the Raw/Epochs object must have a method with that name.
If callable, directly apply the callable to the object.
apply_on_array : bool
Ignored if `fn` is not a callable. If True, the `apply_function` of Raw
and Epochs object will be used to run `fn` on the underlying arrays
directly. If False, `fn` must directly modify the Raw or Epochs object.
kwargs:
Keyword arguments to be forwarded to the MNE function.
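
    Examples
    --------
    A minimal sketch of typical usage (the method names are standard MNE;
    band edges and the scaling factor are illustrative, and lambda-based
    choices cannot be serialized, see the warning in ``__init__``):

    >>> preprocessors = [
    ...     Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    ...     Preprocessor('filter', l_freq=4., h_freq=30.),
    ...     Preprocessor(lambda data: data * 1e6),  # scale V to uV
    ... ]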
"""
def __init__(self, fn, *, apply_on_array=True, **kwargs):
if hasattr(fn, '__name__') and fn.__name__ == '<lambda>':
warn('Preprocessing choices with lambda functions cannot be saved.')
if callable(fn) and apply_on_array:
channel_wise = kwargs.pop('channel_wise', False)
picks = kwargs.pop('picks', None)
n_jobs = kwargs.pop('n_jobs', 1)
kwargs = dict(fun=partial(fn, **kwargs), channel_wise=channel_wise,
picks=picks, n_jobs=n_jobs)
fn = 'apply_function'
self.fn = fn
self.kwargs = kwargs

    def apply(self, raw_or_epochs):
        """Apply the preprocessing function to a Raw or Epochs object."""
try:
self._try_apply(raw_or_epochs)
except RuntimeError:
# Maybe the function needs the data to be loaded and the data was
# not loaded yet. Not all MNE functions need data to be loaded,
# most importantly the 'crop' function can be lazily applied
# without preloading data which can make the overall preprocessing
# pipeline substantially faster.
raw_or_epochs.load_data()
self._try_apply(raw_or_epochs)

    def _try_apply(self, raw_or_epochs):
if callable(self.fn):
self.fn(raw_or_epochs, **self.kwargs)
else:
if not hasattr(raw_or_epochs, self.fn):
raise AttributeError(
f'MNE object does not have a {self.fn} method.')
getattr(raw_or_epochs, self.fn)(**self.kwargs)


def preprocess(concat_ds, preprocessors, save_dir=None, overwrite=False,
n_jobs=None):
"""Apply preprocessors to a concat dataset.
Parameters
----------
concat_ds: BaseConcatDataset
A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
preprocessors: list(Preprocessor)
List of Preprocessor objects to apply to the dataset.
save_dir : str | None
If a string, the preprocessed data will be saved under the specified
directory and the datasets in ``concat_ds`` will be reloaded with
`preload=False`.
overwrite : bool
When `save_dir` is provided, controls whether to delete the old
subdirectories that will be written to under `save_dir`. If False and
the corresponding subdirectories already exist, a ``FileExistsError``
will be raised.
n_jobs : int | None
Number of jobs for parallel execution.
Returns
-------
BaseConcatDataset:
Preprocessed dataset.
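
    Examples
    --------
    A minimal sketch (``concat_ds`` is assumed to already exist;
    ``resample`` is a standard MNE method, and the sampling frequency is
    illustrative):

    >>> preprocessors = [Preprocessor('resample', sfreq=100)]
    >>> concat_ds = preprocess(concat_ds, preprocessors, n_jobs=2)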
"""
# In case of serialization, make sure directory is available before
# preprocessing
if save_dir is not None and not overwrite:
_check_save_dir_empty(save_dir)
if not isinstance(preprocessors, Iterable):
raise ValueError(
'preprocessors must be a list of Preprocessor objects.')
for elem in preprocessors:
assert hasattr(elem, 'apply'), (
'Preprocessor object needs an `apply` method.')
list_of_ds = Parallel(n_jobs=n_jobs)(
delayed(_preprocess)(ds, i, preprocessors, save_dir, overwrite)
for i, ds in enumerate(concat_ds.datasets))
if save_dir is not None: # Reload datasets and replace in concat_ds
concat_ds_reloaded = load_concat_dataset(
save_dir, preload=False, target_name=None)
_replace_inplace(concat_ds, concat_ds_reloaded)
else:
if n_jobs is None or n_jobs == 1: # joblib did not make copies, the
# preprocessing happened in-place
# Recompute cumulative sizes as transforms might have changed them
concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
else: # joblib made copies
_replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))
return concat_ds


def _replace_inplace(concat_ds, new_concat_ds):
"""Replace subdatasets and preproc_kwargs of a BaseConcatDataset inplace.
Parameters
----------
concat_ds : BaseConcatDataset
Dataset to modify inplace.
new_concat_ds : BaseConcatDataset
Dataset to use to modify ``concat_ds``.
"""
if len(concat_ds.datasets) != len(new_concat_ds.datasets):
raise ValueError('Both inputs must have the same length.')
for i in range(len(new_concat_ds.datasets)):
concat_ds.datasets[i] = new_concat_ds.datasets[i]
concat_kind = 'raw' if hasattr(concat_ds.datasets[0], 'raw') else 'window'
preproc_kwargs_attr = concat_kind + '_preproc_kwargs'
if hasattr(new_concat_ds, preproc_kwargs_attr):
setattr(concat_ds, preproc_kwargs_attr,
getattr(new_concat_ds, preproc_kwargs_attr))


def _preprocess(ds, ds_index, preprocessors, save_dir=None, overwrite=False):
"""Apply preprocessor(s) to Raw or Epochs object.
Parameters
----------
ds: BaseDataset | WindowsDataset
Dataset object to preprocess.
ds_index : int
Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
is None.
preprocessors: list(Preprocessor)
List of preprocessors to apply to the dataset.
save_dir : str | None
If provided, save the preprocessed BaseDataset in the
specified directory.
overwrite : bool
If True, overwrite existing file with the same name.
"""
def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
for preproc in preprocessors:
preproc.apply(raw_or_epochs)
if hasattr(ds, 'raw'):
_preprocess_raw_or_epochs(ds.raw, preprocessors)
elif hasattr(ds, 'windows'):
_preprocess_raw_or_epochs(ds.windows, preprocessors)
else:
raise ValueError(
'Can only preprocess concatenation of BaseDataset or '
'WindowsDataset, with either a `raw` or `windows` attribute.')
# Store preprocessing keyword arguments in the dataset
_set_preproc_kwargs(ds, preprocessors)
if save_dir is not None:
concat_ds = BaseConcatDataset([ds])
concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
else:
return ds


def _get_preproc_kwargs(preprocessors):
preproc_kwargs = []
for p in preprocessors:
        # in the case of an MNE method, fn is a str and kwargs is a dict
        func_name = p.fn
        func_kwargs = p.kwargs
        # in the case of a custom callable with apply_on_array=False
        if callable(p.fn):
            func_name = p.fn.__name__
        # with apply_on_array=True, fn was replaced by 'apply_function' and
        # the callable is stored as a partial in kwargs['fun']
        else:
            if 'fun' in p.fn:
                func_name = p.kwargs['fun'].func.__name__
                func_kwargs = p.kwargs['fun'].keywords
preproc_kwargs.append((func_name, func_kwargs))
return preproc_kwargs


def _set_preproc_kwargs(ds, preprocessors):
"""Record preprocessing keyword arguments in BaseDataset or WindowsDataset.
Parameters
----------
ds : BaseDataset | WindowsDataset
Dataset in which to record preprocessing keyword arguments.
preprocessors : list
List of preprocessors.
"""
preproc_kwargs = _get_preproc_kwargs(preprocessors)
if isinstance(ds, WindowsDataset):
kind = 'window'
elif isinstance(ds, BaseDataset):
kind = 'raw'
else:
raise TypeError(
f'ds must be a BaseDataset or a WindowsDataset, got {type(ds)}')
setattr(ds, kind + '_preproc_kwargs', preproc_kwargs)


def exponential_moving_standardize(
data, factor_new=0.001, init_block_size=None, eps=1e-4
):
r"""Perform exponential moving standardization.
Compute the exponental moving mean :math:`m_t` at time `t` as
:math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.
Then, compute exponential moving variance :math:`v_t` at time `t` as
:math:`v_t=\mathrm{factornew} \cdot (m_t - x_t)^2 + (1 - \mathrm{factornew}) \cdot v_{t-1}`.
Finally, standardize the data point :math:`x_t` at time `t` as:
:math:`x'_t=(x_t - m_t) / max(\sqrt{->v_t}, eps)`.
    Parameters
    ----------
    data: np.ndarray (n_channels, n_times)
    factor_new: float
        Smoothing factor of the exponential moving statistics (the ``alpha``
        of ``pandas.DataFrame.ewm``).
    init_block_size: int
        Standardize the data up to this index with regular standardization.
    eps: float
        Stabilizer for division by zero variance.

Returns
-------
standardized: np.ndarray (n_channels, n_times)
Standardized data.
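
    Examples
    --------
    A minimal sketch on random data (the values are illustrative; the shape
    is preserved):

    >>> import numpy as np
    >>> data = np.random.randn(2, 1000)  # (n_channels, n_times)
    >>> standardized = exponential_moving_standardize(
    ...     data, factor_new=0.001, init_block_size=200)
    >>> standardized.shape
    (2, 1000)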
"""
data = data.T
df = pd.DataFrame(data)
meaned = df.ewm(alpha=factor_new).mean()
demeaned = df - meaned
squared = demeaned * demeaned
square_ewmed = squared.ewm(alpha=factor_new).mean()
standardized = demeaned / np.maximum(eps, np.sqrt(np.array(square_ewmed)))
standardized = np.array(standardized)
if init_block_size is not None:
i_time_axis = 0
init_mean = np.mean(
data[0:init_block_size], axis=i_time_axis, keepdims=True
)
init_std = np.std(
data[0:init_block_size], axis=i_time_axis, keepdims=True
)
init_block_standardized = (data[0:init_block_size] - init_mean) / np.maximum(eps, init_std)
standardized[0:init_block_size] = init_block_standardized
return standardized.T


def exponential_moving_demean(data, factor_new=0.001, init_block_size=None):
    r"""Perform exponential moving demeaning.

    Compute the exponential moving mean :math:`m_t` at time `t` as
    :math:`m_t=\mathrm{factornew} \cdot mean(x_t) + (1 - \mathrm{factornew}) \cdot m_{t-1}`.

    Demean the data point :math:`x_t` at time `t` as:
    :math:`x'_t=(x_t - m_t)`.

Parameters
----------
data: np.ndarray (n_channels, n_times)
    factor_new: float
        Smoothing factor of the exponential moving mean (the ``alpha`` of
        ``pandas.DataFrame.ewm``).
    init_block_size: int
        Demean the data up to this index with regular demeaning.

Returns
-------
demeaned: np.ndarray (n_channels, n_times)
Demeaned data.
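
    Examples
    --------
    Analogous to :func:`exponential_moving_standardize` (shape is preserved):

    >>> import numpy as np
    >>> data = np.random.randn(2, 1000)  # (n_channels, n_times)
    >>> exponential_moving_demean(data, factor_new=0.001).shape
    (2, 1000)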
"""
data = data.T
df = pd.DataFrame(data)
meaned = df.ewm(alpha=factor_new).mean()
demeaned = df - meaned
demeaned = np.array(demeaned)
if init_block_size is not None:
i_time_axis = 0
init_mean = np.mean(
data[0:init_block_size], axis=i_time_axis, keepdims=True
)
demeaned[0:init_block_size] = data[0:init_block_size] - init_mean
return demeaned.T


@deprecated(extra='will be removed in 0.8.0. Use numpy.multiply inside a '
                  'lambda function instead.')
def scale(data, factor):
"""Scale continuous or windowed data in-place
Parameters
----------
data: np.ndarray (n_channels x n_times) or (n_windows x n_channels x
n_times)
continuous or windowed signal
factor: float
multiplication factor
Returns
-------
scaled: np.ndarray (n_channels x n_times) or (n_windows x n_channels x
n_times)
        scaled continuous or windowed data

    .. note::
        If this function is used to preprocess continuous data, it should be
        given to ``raw.apply_function()``.
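
    Examples
    --------
    A sketch of the suggested replacement (the factor is illustrative):

    >>> preproc = Preprocessor(lambda data: np.multiply(data, 1e-6))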
"""
scaled = np.multiply(data, factor)
# TODO: the overriding of protected '_data' should be implemented in the
# TODO: dataset when transforms are applied to windows
if hasattr(data, '_data'):
data._data = scaled
return scaled


def filterbank(raw, frequency_bands, drop_original_signals=True,
               order_by_frequency_band=False, **mne_filter_kwargs):
    """Apply multiple bandpass filters to the signals in raw.

    The Raw object is modified in-place, and its number of channels is
    updated to len(frequency_bands) * len(raw.ch_names) (minus
    len(raw.ch_names) if drop_original_signals is True).

Parameters
----------
raw: mne.io.Raw
The raw signals to be filtered.
frequency_bands: list(tuple)
The frequency bands to be filtered for (e.g. [(4, 8), (8, 13)]).
    drop_original_signals: bool
        Whether to drop the original, unfiltered signals.
    order_by_frequency_band: bool
        If True, the returned channels are ordered by frequency band: given
        channels Cz, O1 and filterbank ranges [(4, 8), (8, 13)], the returned
        order is [Cz_4-8, O1_4-8, Cz_8-13, O1_8-13]. If False, the order is
        [Cz_4-8, Cz_8-13, O1_4-8, O1_8-13].
    mne_filter_kwargs: dict
        Keyword arguments for filtering supported by mne.io.Raw.filter().
        Please refer to the MNE documentation for a detailed explanation.
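
    Examples
    --------
    Typically applied through a :class:`Preprocessor` (a sketch; the band
    edges are illustrative). ``apply_on_array=False`` is required since
    ``filterbank`` operates on the Raw object itself:

    >>> preprocessors = [Preprocessor(
    ...     filterbank, frequency_bands=[(4, 8), (8, 13)],
    ...     apply_on_array=False)]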
"""
if not frequency_bands:
raise ValueError(f"Expected at least one frequency band, got"
f" {frequency_bands}")
if not all([len(ch_name) < 8 for ch_name in raw.ch_names]):
warn("Try to use shorter channel names, since frequency band "
"annotation requires an estimated 4-8 chars depending on the "
"frequency ranges. Will truncate to 15 chars (mne max).")
original_ch_names = raw.ch_names
all_filtered = []
for (l_freq, h_freq) in frequency_bands:
filtered = raw.copy()
filtered.filter(l_freq=l_freq, h_freq=h_freq, **mne_filter_kwargs)
        # MNE automatically changes the highpass/lowpass info values
        # when applying filters, and channels cannot be added if these
        # values differ. This is not needed when specifying picks, as
        # the highpass is not modified by filter if picks are given.
ch_names = filtered.info.ch_names
ch_types = filtered.info.get_channel_types()
sampling_freq = filtered.info['sfreq']
info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sampling_freq)
filtered.info = info
# add frequency band annotation to channel names
# truncate to a max of 15 characters, since mne does not allow for more
filtered.rename_channels({
old_name: (old_name + f"_{l_freq}-{h_freq}")[-15:]
for old_name in filtered.ch_names})
all_filtered.append(filtered)
raw.add_channels(all_filtered)
if not order_by_frequency_band:
# order channels by name and not by frequency band:
# index the list with a stepsize of the number of channels for each of
# the original channels
chs_by_freq_band = []
for i in range(len(original_ch_names)):
chs_by_freq_band.extend(raw.ch_names[i::len(original_ch_names)])
raw.reorder_channels(chs_by_freq_band)
if drop_original_signals:
raw.drop_channels(original_ch_names)