# Source code for braindecode.training.scoring

# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
#          Robin Tibor Schirrmeister <robintibor@gmail.com>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Lukas Gemein <l.gemein@gmail.com>
#          Mohammed Fattouh <mo.fattouh@gmail.com>
#
# License: BSD-3

from contextlib import contextmanager
import warnings

import numpy as np
import torch
from mne.utils.check import check_version
from skorch.callbacks.scoring import EpochScoring
from skorch.utils import to_numpy
from skorch.dataset import unpack_data
from torch.utils.data import DataLoader


def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
    """Assign window predictions to trials while removing duplicate predictions.

    Windows belonging to the same trial may overlap in time; overlapping
    time steps would otherwise be predicted more than once. Within each
    trial only the predictions past the previous window's stop are kept.

    Parameters
    ----------
    preds: list of ndarrays (at least 2darrays)
        List of window predictions, in each window prediction
        time is in axis=1.
    i_window_in_trials: list
        Index/number of window in trial.
    i_stop_in_trials: list
        Stop position of window in trial.

    Returns
    -------
    preds_per_trial: list of ndarrays
        Predictions in each trial, duplicates removed.
    """
    assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
        f'{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}')

    # Algorithm for assigning window predictions to trials
    # while removing duplicate predictions:
    # Loop through windows:
    # In each iteration you have predictions (assumed: #classes x #timesteps,
    # or at least #timesteps must be in axis=1)
    # and you have i_window_in_trial, i_stop_in_trial
    # (i_trial removed from variable names for brevity)
    # You first check if the i_window_in_trial is 1 larger
    # than in last iteration, then you are still in the same trial
    # Otherwise you are in a new trial
    # If you are in the same trial, you check for duplicate predictions
    # Only take predictions that are after (inclusive)
    # the stop of the last iteration (i.e., the index of final prediction
    # in the last iteration)
    # Then add the duplicate-removed predictions from this window
    # to predictions for current trial
    preds_per_trial = []
    cur_trial_preds = []
    i_last_stop = None
    i_last_window = -1
    for window_preds, i_window, i_stop in zip(
            preds, i_window_in_trials, i_stop_in_trials):
        window_preds = np.array(window_preds)
        if i_window != (i_last_window + 1):
            # Window index did not increment by 1 -> a new trial starts here.
            assert i_window == 0, (
                "window numbers in new trial should start from 0")
            preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
            cur_trial_preds = []
            i_last_stop = None

        if i_last_stop is not None:
            # Remove duplicates: keep only the predictions for time steps
            # past the previous window's stop.
            n_needed_preds = i_stop - i_last_stop
            if n_needed_preds <= 0:
                # Fully duplicate window: it contributes no new time steps.
                # Must be skipped explicitly, since slicing with `-0:`
                # would return the FULL array and duplicate every
                # prediction of this window.
                i_last_window = i_window
                continue
            window_preds = window_preds[:, -n_needed_preds:]
        cur_trial_preds.append(window_preds)
        i_last_window = i_window
        i_last_stop = i_stop
    # add last trial preds
    preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
    return preds_per_trial
@contextmanager def _cache_net_forward_iter(net, use_caching, y_preds): """Caching context for ``skorch.NeuralNet`` instance. Returns a modified version of the net whose ``forward_iter`` method will subsequently return cached predictions. Leaving the context will undo the overwrite of the ``forward_iter`` method. """ if not use_caching: yield net return y_preds = iter(y_preds) # pylint: disable=unused-argument def cached_forward_iter(*args, device=net.device, **kwargs): for yp in y_preds: yield yp.to(device=device) net.forward_iter = cached_forward_iter try: yield net finally: # By setting net.forward_iter we define an attribute # `forward_iter` that precedes the bound method # `forward_iter`. By deleting the entry from the attribute # dict we undo this. del net.__dict__["forward_iter"]
class CroppedTrialEpochScoring(EpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops.

    Window-level predictions cached by skorch during the epoch are
    stitched back into whole-trial predictions (duplicates removed via
    :func:`trial_preds_from_window_preds`) and averaged over time before
    the skorch scoring function is applied.
    """
    # XXX needs a docstring !!!
    def __init__(
            self,
            scoring,
            lower_is_better=True,
            on_train=False,
            name=None,
            target_extractor=to_numpy,
            use_caching=True,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=on_train,
            name=name,
            target_extractor=target_extractor,
            use_caching=use_caching,
        )
        # For validation scoring, window indices are collected per batch
        # (presumably filled in by the net during evaluation -- they are
        # only read in on_epoch_end below).
        if not self.on_train:
            self.window_inds_ = []

    def _initialize_cache(self):
        """Reset cached predictions/targets at the start of each epoch."""
        super()._initialize_cache()
        # Flag so trial predictions are computed only once per epoch,
        # even with several instances of this callback (see on_epoch_end).
        self.crops_to_trials_computed = False
        self.y_trues_ = []
        self.y_preds_ = []
        if not self.on_train:
            self.window_inds_ = []

    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
        """Cache batch predictions, moving them off the GPU immediately."""
        # Skorch saves the predictions without moving them from GPU
        # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
        # This can cause memory issues in case of a large number of predictions
        # Therefore here we move them to CPU already
        super().on_batch_end(net, batch, y_pred, training, **kwargs)
        if self.use_caching and training == self.on_train:
            self.y_preds_[-1] = self.y_preds_[-1].cpu()

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        """Compute per-trial predictions from window predictions and score."""
        assert self.use_caching
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(
                    dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                # Validation: reuse the per-batch caches filled in
                # on_batch_end / window_inds_ instead of re-predicting.
                # window_inds_ items are indexed [0]=i_window_in_trial,
                # [2]=i_stop_in_trial (NOTE(review): assumed layout --
                # confirm against how the net fills window_inds_).
                pred_results = {}
                pred_results['i_window_in_trials'] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results['i_window_stops'] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results['preds'] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_])
                pred_results['window_ys'] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_])

            # A new trial starts
            # when the index of the window in trials
            # does not increment by 1
            # Add dummy infinity at start
            window_0_per_trial_mask = np.diff(
                pred_results['i_window_in_trials'], prepend=[np.inf]) != 1
            trial_ys = pred_results['window_ys'][window_0_per_trial_mask]
            trial_preds = trial_preds_from_window_preds(
                pred_results['preds'],
                pred_results['i_window_in_trials'],
                pred_results['i_window_stops'])

            # Average across the timesteps of each trial so we have per-trial
            # predictions already, these will be just passed through the forward
            # method of the classifier/regressor to the skorch scoring function.
            # trial_preds is a list, each item is a 2d array classes x time
            y_preds_per_trial = np.array(
                [np.mean(p, axis=1) for p in trial_preds]
            )
            # Move into format expected by skorch (list of torch tensors)
            y_preds_per_trial = [torch.tensor(y_preds_per_trial)]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set
            cbs = net.callbacks_
            epoch_cbs = [
                cb for name, cb in cbs if
                isinstance(cb, CroppedTrialEpochScoring) and (
                    cb.on_train == self.on_train)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds_per_trial
                cb.y_trues_ = trial_ys
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid
        # Score with the cached trial-level predictions: forward_iter is
        # patched to replay them instead of running the model again.
        with _cache_net_forward_iter(
                net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)
        return
class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops
    with time series target.

    Unlike the parent class, targets are themselves time series: they are
    stitched into per-trial sequences the same way as the predictions, and
    NaN entries in the targets are masked out before scoring.
    """

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        """Stitch window preds/targets into trials, mask NaNs, and score."""
        assert self.use_caching
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(
                    dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                # Validation: reuse the caches filled during the epoch.
                # window_inds_ items are indexed [0]=i_window_in_trial,
                # [2]=i_stop_in_trial (NOTE(review): assumed layout --
                # confirm against how the net fills window_inds_).
                pred_results = {}
                pred_results['i_window_in_trials'] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results['i_window_stops'] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results['preds'] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_])
                pred_results['window_ys'] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_])

            # The model may output fewer time steps than the window has
            # (cropped decoding), so keep only the last num_preds target
            # steps of each window to align targets with predictions.
            num_preds = pred_results['preds'][-1].shape[-1]
            # slice the targets to fit preds shape
            pred_results['window_ys'] = [
                targets[:, -num_preds:] for targets in
                pred_results['window_ys']
            ]
            trial_preds = trial_preds_from_window_preds(
                pred_results['preds'],
                pred_results['i_window_in_trials'],
                pred_results['i_window_stops'])
            trial_ys = trial_preds_from_window_preds(
                pred_results['window_ys'],
                pred_results['i_window_in_trials'],
                pred_results['i_window_stops'])
            # the output is a list of predictions/targets per trial where
            # each item is a timeseries of predictions/targets of shape
            # (n_classes x timesteps)

            # mask NaNs from targets
            preds = np.hstack(trial_preds)  # n_classes x timesteps in all trials
            targets = np.hstack(trial_ys)
            # create valid targets mask
            mask = ~np.isnan(targets)
            # select valid targets that have a matching predictions
            masked_targets = targets[mask]
            # For classification there is only one row in targets and
            # n_classes rows in preds
            if mask.shape[0] != preds.shape[0]:
                masked_preds = preds[:, mask[0, :]]
            else:
                masked_preds = preds[mask]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set
            cbs = net.callbacks_
            epoch_cbs = [
                cb for name, cb in cbs if
                isinstance(cb, CroppedTimeSeriesEpochScoring) and (
                    cb.on_train == self.on_train)
            ]
            # Transpose so time/samples come first, wrapped as the list of
            # torch tensors that skorch's caching expects.
            masked_preds = [torch.tensor(masked_preds.T)]
            for cb in epoch_cbs:
                cb.y_preds_ = masked_preds
                cb.y_trues_ = masked_targets.T
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid
        # Score with the cached masked predictions replayed via forward_iter.
        with _cache_net_forward_iter(
                net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)
class PostEpochTrainScoring(EpochScoring):
    """
    Epoch Scoring class that recomputes predictions after the epoch
    on the training in validation mode.

    Note: For unknown reasons, this affects global random generator and
    therefore all results may change slightly if you add this scoring callback.

    Parameters
    ----------
    scoring : None, str, or callable (default=None)
        If None, use the ``score`` method of the model. If str, it
        should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
        callable, it should have the signature (model, X, y), and it
        should return a scalar. This works analogously to the
        ``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    name : str or None (default=None)
        If not an explicit string, tries to infer the name from the
        ``scoring`` argument.
    target_extractor : callable (default=to_numpy)
        This is called on y before it is passed to scoring.
    """

    def __init__(
            self,
            scoring,
            lower_is_better=True,
            name=None,
            target_extractor=to_numpy,
    ):
        # Always scores on the training set (on_train=True) and disables
        # skorch's built-in caching, since predictions are recomputed in
        # eval mode after the epoch (see on_epoch_end).
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=True,
            name=name,
            target_extractor=target_extractor,
            use_caching=False,
        )

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        """Recompute training-set predictions in eval mode, then score."""
        # Only recompute if another instance of this callback has not
        # already filled y_preds_ this epoch (see sharing loop below).
        if len(self.y_preds_) == 0:
            dataset = net.get_dataset(dataset_train)
            # Prevent that rng state of torch is changed by
            # creation+usage of iterator
            # Unfortunately calling __iter__() of a pytorch
            # DataLoader will change the random state
            # Note line below setting rng state back
            rng_state = torch.random.get_rng_state()
            iterator = net.get_iterator(dataset, training=False)
            y_preds = []
            y_test = []
            for batch in iterator:
                batch_X, batch_y = unpack_data(batch)
                # TODO: remove after skorch 0.10 release
                if not check_version('skorch', min_version='0.10.1'):
                    yp = net.evaluation_step(batch_X, training=False)
                # X, y unpacking has been pushed downstream in skorch 0.10
                else:
                    yp = net.evaluation_step(batch, training=False)
                # Move predictions to CPU right away to limit GPU memory use.
                yp = yp.to(device="cpu")
                y_test.append(self.target_extractor(batch_y))
                y_preds.append(yp)
            y_test = np.concatenate(y_test)
            torch.random.set_rng_state(rng_state)

            # Adding the recomputed preds to all other
            # instances of PostEpochTrainScoring of this
            # Skorch-Net (NeuralNet, BraindecodeClassifier etc.)
            # (They will be reinitialized to empty lists by skorch
            # each epoch)
            cbs = net.callbacks_
            epoch_cbs = [
                cb for name, cb in cbs
                if isinstance(cb, PostEpochTrainScoring)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds
                cb.y_trues_ = y_test
        # y pred should be same as self.y_preds_
        # Unclear if this also leads to any
        # random generator call?
        with _cache_net_forward_iter(
                net, use_caching=True, y_preds=self.y_preds_
        ) as cached_net:
            current_score = self._scoring(
                cached_net, dataset_train, self.y_trues_
            )
        self._record_score(net.history, current_score)
def predict_trials(module, dataset, return_targets=True, batch_size=1,
                   num_workers=0):
    """Create trialwise predictions and optionally also return trialwise
    labels from cropped dataset given module.

    Parameters
    ----------
    module: torch.nn.Module
        A pytorch model implementing forward.
    dataset: braindecode.datasets.BaseConcatDataset
        A braindecode dataset to be predicted.
    return_targets: bool
        If True, additionally returns the trial targets.
    batch_size: int
        The batch size used to iterate the dataset.
    num_workers: int
        Number of workers used in DataLoader to iterate the dataset.

    Returns
    -------
    trial_predictions: np.ndarray
        3-dimensional array (n_trials x n_classes x n_predictions), where
        the number of predictions depend on the chosen window size and the
        receptive field of the network.
    trial_labels: np.ndarray
        2-dimensional array (n_trials x n_targets) where the number of
        targets depends on the decoding paradigm and can be either a
        single value, multiple values, or a sequence.
    """
    # we have a cropped dataset if there exists at least one trial with more
    # than one compute window
    more_than_one_window = sum(
        dataset.get_metadata()['i_window_in_trial'] != 0) > 0
    if not more_than_one_window:
        warnings.warn('This function was designed to predict trials from '
                      'cropped datasets, which typically have multiple compute '
                      'windows per trial. The given dataset has exactly one '
                      'window per trial.')
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    # Run the model on the same device its parameters live on.
    device = next(module.parameters()).device
    all_preds, all_ys, all_inds = [], [], []
    with torch.no_grad():
        for X, y, ind in loader:
            X = X.to(device)
            preds = module(X)
            all_preds.extend(preds.cpu().numpy().astype(np.float32))
            all_ys.extend(y.cpu().numpy().astype(np.float32))
            # `ind` is expected to be a 3-element sequence of tensors per
            # batch (window index, start, stop) -- extend flattens it, so
            # below every 3rd element starting at 0 recovers the window
            # indices and starting at 2 the stops.
            # NOTE(review): assumed layout -- confirm against the dataset's
            # __getitem__.
            all_inds.extend(ind)
    preds_per_trial = trial_preds_from_window_preds(
        preds=all_preds,
        i_window_in_trials=torch.cat(all_inds[0::3]),
        i_stop_in_trials=torch.cat(all_inds[2::3]),
    )
    preds_per_trial = np.array(preds_per_trial)
    if return_targets:
        if all_ys[0].shape == ():
            # Scalar target per window: take the target of the first
            # window of each trial (where the window index resets).
            all_ys = np.array(all_ys)
            ys_per_trial = all_ys[
                np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1]
        else:
            # Time-series targets: stitch them per trial exactly like
            # the predictions, removing duplicated time steps.
            ys_per_trial = trial_preds_from_window_preds(
                preds=all_ys,
                i_window_in_trials=torch.cat(all_inds[0::3]),
                i_stop_in_trials=torch.cat(all_inds[2::3]),
            )
            ys_per_trial = np.array(ys_per_trial)
        return preds_per_trial, ys_per_trial
    return preds_per_trial