braindecode.training package

Functionality for skorch-based training.

class braindecode.training.CroppedLoss(loss_function)

Bases: Module

Compute the loss after averaging predictions across time. Assumes predictions have shape (batch_size, n_classes, n_predictions), with the predictions along the last (time) axis.

forward(preds, targets)

Forward pass.

Parameters:
  • preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).

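  • targets (torch.Tensor) – Target labels in the format expected by loss_function for the time-averaged predictions of shape (batch_size, n_classes), e.g. class indices of shape (batch_size,) when loss_function is torch.nn.functional.nll_loss.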
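The computation can be sketched in NumPy as below. This is an illustration, not braindecode's code. It assumes the model outputs log-probabilities and that the wrapped loss_function is a negative log-likelihood such as torch.nn.functional.nll_loss. The helpers nll_loss and cropped_loss are hypothetical.

```python
import numpy as np


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the target classes.

    log_probs: (batch_size, n_classes) log-probabilities.
    targets: (batch_size,) integer class indices.
    """
    picked = log_probs[np.arange(len(targets)), targets]
    return -picked.mean()


def cropped_loss(preds, targets, loss_function=nll_loss):
    """Average per-crop predictions over time, then apply loss_function.

    preds: (batch_size, n_classes, n_predictions) log-probabilities.
    targets: (batch_size,) one class label per window.
    """
    avg_preds = preds.mean(axis=2)  # -> (batch_size, n_classes)
    return loss_function(avg_preds, targets)


# Two windows, two classes, three crop predictions per window.
log_p = np.log(np.array([
    [[0.9, 0.8, 0.7], [0.1, 0.2, 0.3]],
    [[0.4, 0.2, 0.3], [0.6, 0.8, 0.7]],
]))
y = np.array([0, 1])
print(cropped_loss(log_p, y))
```

Averaging before the loss means each window contributes one loss term, however many crop predictions the network's receptive field yields.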

training: bool

class braindecode.training.CroppedTimeSeriesEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)

Bases: CroppedTrialEpochScoring

Class to compute scores for trials from a model that predicts (super)crops with a time series target.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.

class braindecode.training.CroppedTrialEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)

Bases: EpochScoring

Class to compute scores for trials from a model that predicts (super)crops.

on_batch_end(net, batch, y_pred, training, **kwargs)

Called at the end of each batch.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.
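The hypothetical, heavily simplified class below illustrates the collect-then-score pattern behind on_batch_end and on_epoch_end. It is not braindecode's implementation. It caches batch predictions, and at epoch end it:

- averages each window's crop predictions over time,
- averages the windows belonging to the same trial,
- and computes accuracy.

Overlapping crops are not deduplicated, and trial_of_window is an assumed input.

```python
import numpy as np


class CropScoringSketch:
    """Hypothetical illustration of caching batch predictions and scoring per trial."""

    def __init__(self):
        self.y_preds = []

    def on_batch_end(self, y_pred):
        # Cache the crop predictions of every batch during the epoch.
        self.y_preds.append(np.asarray(y_pred))

    def on_epoch_end(self, trial_of_window, trial_labels):
        # (n_windows, n_classes, n_predictions) -> (n_windows, n_classes)
        window_preds = np.concatenate(self.y_preds, axis=0).mean(axis=2)
        self.y_preds = []  # reset the cache for the next epoch
        trial_preds = np.stack([
            window_preds[trial_of_window == i].mean(axis=0)
            for i in range(len(trial_labels))
        ])
        return float(np.mean(trial_preds.argmax(axis=1) == trial_labels))


# Two batches of two windows each; windows 0-1 belong to trial 0, windows 2-3 to trial 1.
batch1 = np.array([[[0.9, 0.8, 0.7], [0.1, 0.2, 0.3]],
                   [[0.6, 0.7, 0.8], [0.4, 0.3, 0.2]]])
batch2 = np.array([[[0.2, 0.3, 0.1], [0.8, 0.7, 0.9]],
                   [[0.4, 0.4, 0.4], [0.6, 0.6, 0.6]]])
scorer = CropScoringSketch()
scorer.on_batch_end(batch1)
scorer.on_batch_end(batch2)
print(scorer.on_epoch_end(np.array([0, 0, 1, 1]), np.array([0, 1])))  # 1.0
```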

class braindecode.training.PostEpochTrainScoring(scoring, lower_is_better=True, name=None, target_extractor=<function to_numpy>)

Bases: EpochScoring

Epoch scoring class that recomputes predictions on the training set after each epoch, with the model in validation mode.

Note: For unknown reasons, this affects the global random number generator, so all results may change slightly if you add this scoring callback.

Parameters:
  • scoring (None, str, or callable (default=None)) – If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

  • lower_is_better (bool (default=True)) – Whether lower scores should be considered better or worse.

  • name (str or None (default=None)) – If not an explicit string, tries to infer the name from the scoring argument.

  • target_extractor (callable (default=to_numpy)) – This is called on y before it is passed to scoring.
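A callable passed as scoring must accept (model, X, y) and return a scalar. The hypothetical example below computes balanced accuracy. It assumes the model has a scikit-learn style predict(X) method. MajorityModel is a toy stand-in, not a braindecode class.

```python
import numpy as np


def balanced_accuracy_scorer(model, X, y):
    """Scoring callable with the (model, X, y) -> scalar signature.

    Assumes `model` exposes a scikit-learn style predict(X) method.
    """
    y_pred = np.asarray(model.predict(X))
    y = np.asarray(y)
    # Mean of the per-class recalls.
    recalls = [np.mean(y_pred[y == c] == c) for c in np.unique(y)]
    return float(np.mean(recalls))


class MajorityModel:
    """Hypothetical toy model that always predicts class 0."""

    def predict(self, X):
        return np.zeros(len(X), dtype=int)


X = np.zeros((4, 2))
y = np.array([0, 0, 0, 1])
print(balanced_accuracy_scorer(MajorityModel(), X, y))  # 0.5
```

Such a callable could be passed as PostEpochTrainScoring(balanced_accuracy_scorer, lower_is_better=False), since higher accuracy is better. The scikit-learn scorer string "balanced_accuracy" would serve the same purpose.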

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.

class braindecode.training.TimeSeriesLoss(loss_function)

Bases: Module

Compute the loss between time series targets and predictions. Assumes predictions have shape (batch_size, n_classes, n_predictions) and targets have shape (batch_size, n_classes, window_len), both with time on the last axis. If the targets contain NaNs, those positions are masked out and the loss is computed only for predictions that correspond to valid (non-NaN) target values.

forward(preds, targets)

Forward pass.

Parameters:
  • preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).

  • targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).
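The masking logic can be sketched in NumPy, with mean squared error as the wrapped loss. This is a hypothetical illustration: masked_time_series_mse is not a braindecode function. Aligning the predictions with the last n_predictions target time steps, which is needed when window_len exceeds n_predictions, is an assumption of the sketch.

```python
import numpy as np


def masked_time_series_mse(preds, targets):
    """NaN-masked MSE between time series predictions and targets.

    preds: (batch_size, n_classes, n_predictions)
    targets: (batch_size, n_classes, window_len), window_len >= n_predictions
    """
    n_preds = preds.shape[-1]
    # Assumption: predictions cover the last n_predictions target time steps.
    targets = targets[:, :, -n_preds:]
    mask = ~np.isnan(targets)  # True where the target is valid
    return float(np.mean((preds[mask] - targets[mask]) ** 2))


preds = np.array([[[1.0, 2.0, 3.0]]])
targets = np.array([[[0.0, 1.0, np.nan, 5.0]]])  # window_len = 4
print(masked_time_series_mse(preds, targets))  # 2.0
```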

training: bool

braindecode.training.mixup_criterion(preds, target)

Implements the mixup loss for EEG data. See [1].

Implementation based on [2].

Parameters:
  • preds (torch.Tensor) – Predictions from the model.

  • target (torch.Tensor | list of torch.Tensor) – For predictions without mixup, the targets as a tensor. If mixup has been applied, a list containing the targets of the two mixed samples and the mixing coefficients as tensors.

Returns:

loss – The loss value.

Return type:

torch.Tensor (a zero-dimensional scalar loss)

References

[1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. mixup: Beyond Empirical Risk Minimization. Online: https://arxiv.org/abs/1710.09412
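Following [1], the mixup loss weights the loss against each of the two mixed targets by the mixing coefficient: λ·ℓ(ŷ, y_a) + (1 − λ)·ℓ(ŷ, y_b). The NumPy sketch below assumes log-probability predictions, a negative log-likelihood loss, and one λ per sample. The helpers mixup_loss and nll_per_sample are hypothetical, not braindecode's implementation.

```python
import numpy as np


def nll_per_sample(log_probs, targets):
    # Negative log-likelihood of each sample's target class.
    return -log_probs[np.arange(len(targets)), targets]


def mixup_loss(log_probs, target):
    """Mixup loss sketch.

    target is either an array of class labels (no mixup) or a
    (y_a, y_b, lam) tuple/list with per-sample mixing coefficients lam.
    """
    if isinstance(target, (list, tuple)) and len(target) == 3:
        y_a, y_b, lam = target
        loss = lam * nll_per_sample(log_probs, y_a) + (1 - lam) * nll_per_sample(log_probs, y_b)
        return float(loss.mean())
    return float(nll_per_sample(log_probs, target).mean())


log_probs = np.log(np.array([[0.5, 0.5], [0.25, 0.75]]))
y_a, y_b = np.array([0, 1]), np.array([1, 0])
lam = np.array([1.0, 0.0])
print(mixup_loss(log_probs, (y_a, y_b, lam)))  # approx. 1.04 (= 1.5 * log 2)
print(mixup_loss(log_probs, y_a))  # plain NLL without mixup
```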

braindecode.training.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)

Create trialwise predictions from a cropped dataset with the given module, and optionally also return the trialwise labels.

Parameters:
  • module (torch.nn.Module) – A PyTorch model implementing forward.

  • dataset (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.

  • return_targets (bool) – If True, additionally returns the trial targets.

  • batch_size (int) – The batch size used to iterate the dataset.

  • num_workers (int) – Number of workers used in DataLoader to iterate the dataset.

Returns:

  • trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.

  • trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
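The returned arrays can be reduced to a trial-wise score. The sketch below assumes a classification paradigm with one label per trial. It uses synthetic arrays in place of real output from a call such as predict_trials(module, dataset). trial_accuracy is a hypothetical helper.

```python
import numpy as np


def trial_accuracy(trial_predictions, trial_labels):
    """Accuracy from predict_trials-style outputs.

    trial_predictions: (n_trials, n_classes, n_predictions)
    trial_labels: (n_trials, 1) or (n_trials,) integer class labels
    """
    # Average the crop predictions over time, then take the most likely class.
    predicted = trial_predictions.mean(axis=2).argmax(axis=1)
    return float(np.mean(predicted == np.asarray(trial_labels).reshape(-1)))


# Synthetic outputs for two trials, two classes, four predictions per trial.
trial_predictions = np.array([
    [[0.9, 0.7, 0.8, 0.6], [0.1, 0.3, 0.2, 0.4]],
    [[0.2, 0.4, 0.1, 0.3], [0.8, 0.6, 0.9, 0.7]],
])
trial_labels = np.array([[0], [0]])
print(trial_accuracy(trial_predictions, trial_labels))  # 0.5
```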

braindecode.training.trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials)

Assign window predictions to trials while removing duplicate predictions.

Parameters:
  • preds (list of ndarrays (at least 2-D arrays)) – List of window predictions; in each window prediction, time is in axis=1.

  • i_window_in_trials (list) – Index of each window within its trial.

  • i_stop_in_trials (list) – Stop position of each window within its trial.

Returns:

preds_per_trial – Predictions for each trial, with duplicates removed.

Return type:

list of ndarrays
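The deduplication can be sketched as follows. This is a hypothetical reimplementation, not braindecode's code, and it rests on two assumptions:

- the windows of a trial arrive consecutively with indices 0, 1, 2, ...
- the difference between consecutive stop positions equals the number of new predictions, i.e. the network emits one prediction per input sample.

```python
import numpy as np


def trial_preds_sketch(preds, i_window_in_trials, i_stop_in_trials):
    """Concatenate window predictions per trial, dropping overlapping time steps."""
    preds_per_trial, current = [], []
    last_stop = None
    for window_preds, i_window, i_stop in zip(preds, i_window_in_trials, i_stop_in_trials):
        window_preds = np.asarray(window_preds)
        if i_window == 0 and current:
            # Window 0 starts a new trial: store the finished one.
            preds_per_trial.append(np.concatenate(current, axis=1))
            current, last_stop = [], None
        if last_stop is not None:
            # Keep only the time steps not covered by the previous window.
            n_new = i_stop - last_stop
            assert 0 <= n_new <= window_preds.shape[1], "unexpected window overlap"
            window_preds = window_preds[:, window_preds.shape[1] - n_new:]
        current.append(window_preds)
        last_stop = i_stop
    if current:
        preds_per_trial.append(np.concatenate(current, axis=1))
    return preds_per_trial


# Trial A: two overlapping 4-step windows (stops 100 and 102); trial B: one window.
w0 = np.array([[0, 1, 2, 3], [10, 11, 12, 13]])
w1 = np.array([[2, 3, 4, 5], [12, 13, 14, 15]])
w2 = np.array([[7, 7, 7, 7], [8, 8, 8, 8]])
per_trial = trial_preds_sketch([w0, w1, w2], [0, 1, 0], [100, 102, 100])
print([p.shape for p in per_trial])  # [(2, 6), (2, 4)]
```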

Submodules

braindecode.training.callbacks module

class braindecode.training.callbacks.MaxNormConstraintCallback

Bases: Callback

on_batch_end(net, training, *args, **kwargs)

Called at the end of each batch.
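The entry above does not document which layers or norm limits the callback uses. Assuming it enforces a max-norm weight constraint, as its name suggests, the NumPy sketch below shows that generic operation. apply_max_norm is a hypothetical helper. In a callback, it would be applied to the module weights after each training batch.

```python
import numpy as np


def apply_max_norm(weight, max_norm):
    """Rescale each slice along axis 0 so that its L2 norm is at most max_norm.

    Similar in spirit to torch.renorm(weight, 2, 0, max_norm); slices that
    already satisfy the constraint are left unchanged.
    """
    flat = weight.reshape(weight.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    # Shrink factor is 1 for slices within the limit (guard against zero norms).
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return (flat * scale).reshape(weight.shape)


w = np.array([[3.0, 4.0], [0.3, 0.4]])  # row norms 5.0 and 0.5
print(apply_max_norm(w, max_norm=1.0))  # rows become [0.6, 0.8] and stay [0.3, 0.4]
```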

braindecode.training.losses module

class braindecode.training.losses.CroppedLoss(loss_function)

Bases: Module

Compute the loss after averaging predictions across time. Assumes predictions have shape (batch_size, n_classes, n_predictions), with the predictions along the last (time) axis.

forward(preds, targets)

Forward pass.

Parameters:
  • preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).

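  • targets (torch.Tensor) – Target labels in the format expected by loss_function for the time-averaged predictions of shape (batch_size, n_classes), e.g. class indices of shape (batch_size,) when loss_function is torch.nn.functional.nll_loss.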

training: bool

class braindecode.training.losses.TimeSeriesLoss(loss_function)

Bases: Module

Compute the loss between time series targets and predictions. Assumes predictions have shape (batch_size, n_classes, n_predictions) and targets have shape (batch_size, n_classes, window_len), both with time on the last axis. If the targets contain NaNs, those positions are masked out and the loss is computed only for predictions that correspond to valid (non-NaN) target values.

forward(preds, targets)

Forward pass.

Parameters:
  • preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).

  • targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).

training: bool

braindecode.training.losses.mixup_criterion(preds, target)

Implements the mixup loss for EEG data. See [1].

Implementation based on [2].

Parameters:
  • preds (torch.Tensor) – Predictions from the model.

  • target (torch.Tensor | list of torch.Tensor) – For predictions without mixup, the targets as a tensor. If mixup has been applied, a list containing the targets of the two mixed samples and the mixing coefficients as tensors.

Returns:

loss – The loss value.

Return type:

torch.Tensor (a zero-dimensional scalar loss)

References

[1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. mixup: Beyond Empirical Risk Minimization. Online: https://arxiv.org/abs/1710.09412

braindecode.training.scoring module

class braindecode.training.scoring.CroppedTimeSeriesEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)

Bases: CroppedTrialEpochScoring

Class to compute scores for trials from a model that predicts (super)crops with a time series target.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.

class braindecode.training.scoring.CroppedTrialEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)

Bases: EpochScoring

Class to compute scores for trials from a model that predicts (super)crops.

on_batch_end(net, batch, y_pred, training, **kwargs)

Called at the end of each batch.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.

class braindecode.training.scoring.PostEpochTrainScoring(scoring, lower_is_better=True, name=None, target_extractor=<function to_numpy>)

Bases: EpochScoring

Epoch scoring class that recomputes predictions on the training set after each epoch, with the model in validation mode.

Note: For unknown reasons, this affects the global random number generator, so all results may change slightly if you add this scoring callback.

Parameters:
  • scoring (None, str, or callable (default=None)) – If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

  • lower_is_better (bool (default=True)) – Whether lower scores should be considered better or worse.

  • name (str or None (default=None)) – If not an explicit string, tries to infer the name from the scoring argument.

  • target_extractor (callable (default=to_numpy)) – This is called on y before it is passed to scoring.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

Called at the end of each epoch.

braindecode.training.scoring.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)

Create trialwise predictions from a cropped dataset with the given module, and optionally also return the trialwise labels.

Parameters:
  • module (torch.nn.Module) – A PyTorch model implementing forward.

  • dataset (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.

  • return_targets (bool) – If True, additionally returns the trial targets.

  • batch_size (int) – The batch size used to iterate the dataset.

  • num_workers (int) – Number of workers used in DataLoader to iterate the dataset.

Returns:

  • trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.

  • trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.

braindecode.training.scoring.trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials)

Assign window predictions to trials while removing duplicate predictions.

Parameters:
  • preds (list of ndarrays (at least 2-D arrays)) – List of window predictions; in each window prediction, time is in axis=1.

  • i_window_in_trials (list) – Index of each window within its trial.

  • i_stop_in_trials (list) – Stop position of each window within its trial.

Returns:

preds_per_trial – Predictions for each trial, with duplicates removed.

Return type:

list of ndarrays