# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
# Robin Schirrmeister <robintibor@gmail.com>
# Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD (3-clause)
import warnings
import numpy as np
from sklearn.metrics import get_scorer
from skorch.callbacks import EpochTimer, BatchScoring, PrintLog, EpochScoring
from skorch.classifier import NeuralNet
from skorch.classifier import NeuralNetClassifier
from skorch.utils import train_loss_score, valid_loss_score, noop, to_numpy
import torch
from .training.scoring import (PostEpochTrainScoring,
CroppedTrialEpochScoring,
CroppedTimeSeriesEpochScoring,
predict_trials)
from .util import ThrowAwayIndexLoader, update_estimator_docstring
class EEGClassifier(NeuralNetClassifier):
    # NOTE: ``doc`` is an ordinary string that is handed to
    # ``update_estimator_docstring`` below to build the class ``__doc__``.
    # Its text is deliberately left exactly as found (not re-wrapped or
    # re-indented) because that helper consumes it at import time.
    doc = """Classifier that does not assume softmax activation.
Calls loss function directly without applying log or anything.
Parameters
----------
cropped: bool (default=False)
Defines whether torch model passed to this class is cropped or not.
Currently used for callbacks definition.
callbacks: None or list of strings or list of Callback instances (default=None)
More callbacks, in addition to those returned by
``get_default_callbacks``. Each callback should inherit from
:class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a
list of strings specifying `sklearn` scoring functions (for scoring
functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)
or a list of callbacks where the callback names are inferred from the
class name. Name conflicts are resolved by appending a count suffix
starting with 1, e.g. ``EpochScoring_1``. Alternatively,
a tuple ``(name, callback)`` can be passed, where ``name``
should be unique. Callbacks may or may not be instantiated.
The callback name can be used to set parameters on specific
callbacks (e.g., for the callback with name ``'print_log'``, use
``net.set_params(callbacks__print_log__keys_ignored=['epoch',
'train_loss'])``).
iterator_train__shuffle: bool (default=True)
Defines whether train dataset will be shuffled. As skorch does not
shuffle the train dataset by default this one overwrites this option.
aggregate_predictions: bool (default=True)
Whether to average cropped predictions to obtain window predictions. Used only in the
cropped mode.
""" # noqa: E501
    __doc__ = update_estimator_docstring(NeuralNetClassifier, doc)

    def __init__(self, *args, cropped=False, callbacks=None,
                 iterator_train__shuffle=True, iterator_train__drop_last=True,
                 aggregate_predictions=True, **kwargs):
        # ``iterator_train__drop_last`` (default True) is forwarded to skorch
        # together with ``iterator_train__shuffle``; both override whatever
        # skorch would use when the caller does not specify them.
        self.cropped = cropped
        self.aggregate_predictions = aggregate_predictions
        # Window indices of the most recent batch. ``on_batch_end`` hands them
        # to the cropped scoring callbacks and then clears the slot.
        # Presumably populated by ``ThrowAwayIndexLoader`` -- TODO confirm.
        self._last_window_inds_ = None
        super().__init__(*args,
                         callbacks=callbacks,
                         iterator_train__shuffle=iterator_train__shuffle,
                         iterator_train__drop_last=iterator_train__drop_last,
                         **kwargs)

    def _yield_callbacks(self):
        """Yield ``(name, callback, named_by_user)`` triples for skorch.

        Callbacks supplied as plain strings (e.g. ``'accuracy'``) are
        expanded into a train scoring callback and, when a validation split
        is configured, a matching validation scoring callback. Everything
        else is passed through untouched.
        """
        for name, cb, named_by_user in super()._yield_callbacks():
            # Test the callback object itself rather than the name derived
            # from it: a real callback that a user chose to call 'str' must
            # not be routed to the scorer lookup.
            if isinstance(cb, str):
                train_cb, valid_cb = self._parse_str_callback(cb)
                yield train_cb
                if self.train_split is not None:
                    yield valid_cb
            else:
                yield name, cb, named_by_user

    def _parse_str_callback(self, cb_supplied_name):
        """Turn an sklearn scorer name into train/valid scoring callbacks.

        Parameters
        ----------
        cb_supplied_name : str
            Name understood by :func:`sklearn.metrics.get_scorer`.

        Returns
        -------
        list of tuple
            Two ``(name, callback, named_by_user)`` triples: the first for
            the train scoring callback, the second for the validation one.
        """
        scoring = get_scorer(cb_supplied_name)
        scoring_name = scoring._score_func.__name__
        known_suffixes = ('_score', '_error', '_deviance', '_loss')
        assert scoring_name.endswith(known_suffixes), (
            f"Cannot infer whether lower is better for scorer "
            f"'{cb_supplied_name}' (metric function '{scoring_name}'); "
            f"expected a function name ending in one of {known_suffixes}."
        )
        # '*_score' metrics and sklearn's negated ('neg_*') scorers are
        # maximised; every other accepted suffix denotes a quantity to
        # minimise.
        higher_is_better = (scoring_name.endswith('_score')
                            or cb_supplied_name.startswith('neg_'))
        lower_is_better = not higher_is_better

        train_name = f'train_{cb_supplied_name}'
        valid_name = f'valid_{cb_supplied_name}'
        if self.cropped:
            # TODO: use CroppedTimeSeriesEpochScoring when time series target
            # In case of cropped decoding we are using braindecode
            # specific scoring created for cropped decoding
            train_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=True,
                name=train_name
            )
            valid_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=False,
                name=valid_name
            )
        else:
            train_scoring = PostEpochTrainScoring(
                cb_supplied_name, lower_is_better, name=train_name
            )
            valid_scoring = EpochScoring(
                cb_supplied_name, lower_is_better, on_train=False,
                name=valid_name
            )
        named_by_user = True
        return [
            (train_name, train_scoring, named_by_user),
            (valid_name, valid_scoring, named_by_user),
        ]

    # pylint: disable=arguments-differ
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        """Return the loss for this batch by calling NeuralNet get_loss.

        ``NeuralNet.get_loss`` is invoked explicitly so that the
        ``NeuralNetClassifier`` override is bypassed and the criterion sees
        the module output as-is.

        Parameters
        ----------
        y_pred : torch tensor
            Predicted target values
        y_true : torch tensor
            True target values.
        *args, **kwargs
            Forwarded unchanged to ``NeuralNet.get_loss``. In skorch these
            carry ``X`` (the input data) and ``training`` (bool, whether
            train mode is used).

        Returns
        -------
        loss
            Whatever ``NeuralNet.get_loss`` returns for this batch.
        """
        return NeuralNet.get_loss(self, y_pred, y_true, *args, **kwargs)

    def get_iterator(self, dataset, training=False, drop_index=True):
        """Return the skorch iterator for ``dataset``.

        Parameters
        ----------
        dataset
            Dataset passed on to ``NeuralNet.get_iterator``.
        training : bool (default=False)
            Passed on to ``NeuralNet.get_iterator``.
        drop_index : bool (default=True)
            If True the iterator is wrapped in ``ThrowAwayIndexLoader``;
            if False the plain skorch iterator is returned, so the caller
            receives whatever extra index element the dataset yields.
        """
        iterator = super().get_iterator(dataset, training=training)
        if drop_index:
            return ThrowAwayIndexLoader(self, iterator, is_regression=False)
        return iterator

    def on_batch_end(self, net, *batch, training=False, **kwargs):
        """Pass the latest window indices to the cropped validation scorers.

        Does nothing for training batches. For non-training batches the
        indices stored in ``_last_window_inds_`` are appended to every
        cropped scoring callback that scores on validation data, after
        which the slot is cleared.
        """
        # If training is false, assume that our loader has indices for this
        # batch
        if training:
            return
        cropped_scoring_types = (CroppedTrialEpochScoring,
                                 CroppedTimeSeriesEpochScoring)
        epoch_cbs = [
            cb for _, cb in self.callbacks_
            if isinstance(cb, cropped_scoring_types)
            and hasattr(cb, 'window_inds_')
            and not cb.on_train
        ]
        # for trialwise decoding stuffs it might also be we don't have
        # cropped loader, so no indices there
        if epoch_cbs:
            assert self._last_window_inds_ is not None
            for cb in epoch_cbs:
                cb.window_inds_.append(self._last_window_inds_)
            self._last_window_inds_ = None

    def predict_with_window_inds_and_ys(self, dataset):
        """Predict every window of ``dataset`` and collect its indices.

        Iterates over the dataset with the index element kept
        (``drop_index=False``), runs the module in eval mode without
        gradient tracking, and concatenates the per-batch results.

        Returns
        -------
        dict
            With keys ``preds``, ``i_window_in_trials``, ``i_window_stops``
            and ``window_ys``, each a concatenated numpy array.
        """
        # Prefer the module skorch initialised (``module_``); ``module`` is
        # only the constructor argument and may be an uninstantiated class
        # or a different, untrained instance. Fall back to it when the net
        # has not been initialised yet.
        module = getattr(self, 'module_', self.module)
        module.eval()
        preds = []
        i_window_in_trials = []
        i_window_stops = []
        window_ys = []
        for X, y, i in self.get_iterator(dataset, drop_index=False):
            # First index element -> window position within its trial,
            # third index element -> stop of the window.
            i_window_in_trials.append(i[0].cpu().numpy())
            i_window_stops.append(i[2].cpu().numpy())
            with torch.no_grad():
                preds.append(to_numpy(module.forward(X.to(self.device))))
            window_ys.append(y.cpu().numpy())
        return dict(
            preds=np.concatenate(preds),
            i_window_in_trials=np.concatenate(i_window_in_trials),
            i_window_stops=np.concatenate(i_window_stops),
            window_ys=np.concatenate(window_ys),
        )

    # Removes default EpochScoring callback computing 'accuracy' to work properly
    # with cropped decoding.
    @property
    def _default_callbacks(self):
        return [
            ("epoch_timer", EpochTimer()),
            (
                "train_loss",
                BatchScoring(
                    train_loss_score,
                    name="train_loss",
                    on_train=True,
                    target_extractor=noop,
                ),
            ),
            (
                "valid_loss",
                BatchScoring(
                    valid_loss_score, name="valid_loss", target_extractor=noop,
                ),
            ),
            ("print_log", PrintLog()),
        ]

    def predict_proba(self, X):
        """Return the output of the module's forward method as a numpy
        array. In case of cropped decoding returns averaged values for
        each trial.

        If the module's forward method returns multiple outputs as a
        tuple, it is assumed that the first output contains the
        relevant information and the other values are ignored.
        If all values are relevant or module's output for each crop
        is needed, consider using :func:`~skorch.NeuralNet.forward`
        instead.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

              * numpy arrays
              * torch tensors
              * pandas DataFrame or Series
              * scipy sparse CSR matrices
              * a dictionary of the former three
              * a list/tuple of the former three
              * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_proba : numpy ndarray
        """
        y_pred = super().predict_proba(X)
        # Normally, we have to average the predictions across crops/timesteps
        # to get one prediction per window/trial
        # Predictions may be already averaged in CroppedTrialEpochScoring
        # (2-dimensional y_pred). However, when predictions are computed
        # outside of CroppedTrialEpochScoring we have to average them
        # ourselves, hence the check for a 3-dimensional array.
        needs_averaging = (self.cropped and self.aggregate_predictions
                           and y_pred.ndim == 3)
        if needs_averaging:
            return y_pred.mean(axis=-1)
        return y_pred

    def predict(self, X):
        """Return class labels for samples in X.

        The label is the argmax over axis 1 of :meth:`predict_proba`.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

              * numpy arrays
              * torch tensors
              * pandas DataFrame or Series
              * scipy sparse CSR matrices
              * a dictionary of the former three
              * a list/tuple of the former three
              * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_pred : numpy ndarray
        """
        return self.predict_proba(X).argmax(1)

    def predict_trials(self, X, return_targets=True):
        """Create trialwise predictions and optionally also return trialwise
        labels from cropped dataset.

        When the classifier is not in cropped mode a ``UserWarning`` is
        issued and the result of :meth:`predict` is returned instead
        (together with ``X.get_metadata()['target']`` if targets are
        requested).

        Parameters
        ----------
        X: braindecode.datasets.BaseConcatDataset
            A braindecode dataset to be predicted.
        return_targets: bool
            If True, additionally returns the trial targets.

        Returns
        -------
        trial_predictions: np.ndarray
            3-dimensional array (n_trials x n_classes x n_predictions), where
            the number of predictions depend on the chosen window size and the
            receptive field of the network.
        trial_labels: np.ndarray
            2-dimensional array (n_trials x n_targets) where the number of
            targets depends on the decoding paradigm and can be either a single
            value, multiple values, or a sequence.
        """
        if not self.cropped:
            # stacklevel=2 attributes the warning to the caller's line
            # instead of this library function.
            warnings.warn(
                "This method was designed to predict trials in cropped mode. "
                "Calling it when cropped is False will give the same result as "
                "'.predict'.", UserWarning, stacklevel=2)
            preds = self.predict(X)
            if return_targets:
                return preds, X.get_metadata()['target'].to_numpy()
            return preds
        # Inside a method body the bare name resolves to the module-level
        # ``predict_trials`` helper imported at the top of the file, not to
        # this method.
        return predict_trials(
            # Use the initialised module when available (see
            # ``predict_with_window_inds_and_ys``).
            module=getattr(self, 'module_', self.module),
            dataset=X,
            return_targets=return_targets,
            batch_size=self.batch_size,
            num_workers=self.get_iterator(X, training=False).loader.num_workers,
        )