braindecode.training package#
Functionality for skorch-based training.
- class braindecode.training.CroppedLoss(loss_function)[source]#
Bases:
Module
Compute loss after averaging predictions across time. Assumes predictions have shape n_batch_size x n_classes x n_predictions (in time). A usage sketch follows the forward parameters below.
- forward(preds, targets)[source]#
Forward pass.
- Parameters:
preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).
targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).
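A minimal usage sketch; the wrapped criterion (nll_loss), the tensor shapes, and the per-trial label format are illustrative assumptions rather than requirements stated on this page:
```python
import torch
from braindecode.training import CroppedLoss

# Wrap a standard criterion; nll_loss here is an assumed choice that expects
# log-probabilities and one class label per trial.
criterion = CroppedLoss(torch.nn.functional.nll_loss)

preds = torch.log_softmax(torch.randn(8, 4, 100), dim=1)  # (batch, n_classes, n_predictions)
targets = torch.randint(0, 4, (8,))                       # per-trial class labels (assumed format)
loss = criterion(preds, targets)                          # predictions are averaged over time first
print(loss.item())
```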
- class braindecode.training.CroppedTimeSeriesEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]#
Bases:
CroppedTrialEpochScoring
Class to compute scores for trials from a model that predicts (super)crops with time series target.
- class braindecode.training.CroppedTrialEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]#
Bases:
EpochScoring
Class to compute scores for trials from a model that predicts (super)crops.
- class braindecode.training.PostEpochTrainScoring(scoring, lower_is_better=True, name=None, target_extractor=<function to_numpy>)[source]#
Bases:
EpochScoring
Epoch scoring class that recomputes predictions on the training set after the epoch, in validation mode.
Note: For unknown reasons, this affects the global random generator, so all results may change slightly if you add this scoring callback.
A sketch of attaching the scoring callbacks from this module is given after the parameter list below.
- Parameters:
scoring (None, str, or callable (default=None)) – If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y) and return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
lower_is_better (bool (default=True)) – Whether lower scores should be considered better or worse.
name (str or None (default=None)) – If not an explicit string, tries to infer the name from the scoring argument.
target_extractor (callable (default=to_numpy)) – This is called on y before it is passed to scoring.
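Below is a hedged sketch of how these scoring callbacks are typically attached to a skorch-based classifier such as braindecode’s EEGClassifier. The stand-in module, optimizer settings, and callback names are illustrative assumptions; only the classifier is constructed, training data and fitting are omitted:
```python
import torch
from braindecode import EEGClassifier
from braindecode.training import (
    CroppedLoss, CroppedTrialEpochScoring, PostEpochTrainScoring)


class TinyCroppedModel(torch.nn.Module):
    """Stand-in module mapping (batch, n_chans, n_times) to (batch, n_classes, n_preds)."""

    def __init__(self, n_chans=22, n_classes=4):
        super().__init__()
        self.conv = torch.nn.Conv1d(n_chans, n_classes, kernel_size=25)

    def forward(self, x):
        return torch.log_softmax(self.conv(x), dim=1)


clf = EEGClassifier(
    TinyCroppedModel(),
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.nll_loss,
    optimizer=torch.optim.AdamW,
    train_split=None,  # no validation split in this sketch
    callbacks=[
        ("train_trial_accuracy",
         CroppedTrialEpochScoring("accuracy", on_train=True,
                                  name="train_trial_accuracy",
                                  lower_is_better=False)),
        ("train_post_epoch_accuracy",
         PostEpochTrainScoring("accuracy", name="train_post_epoch_accuracy",
                               lower_is_better=False)),
    ],
)
# clf.fit(windows_dataset, y=None) would then train on a cropped windows
# dataset (not constructed here).
```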
- class braindecode.training.TimeSeriesLoss(loss_function)[source]#
Bases:
Module
Compute loss between time series targets and predictions. Assumes predictions have shape n_batch_size x n_classes x n_predictions (in time) and targets have shape n_batch_size x n_classes x window_len (in time). If the targets contain NaNs, the NaNs are masked out and the loss is computed only for predictions corresponding to valid target values. A usage sketch follows the forward parameters below.
- forward(preds, targets)[source]#
Forward pass.
- Parameters:
preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).
targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).
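A minimal sketch of the NaN-masking behaviour; the wrapped criterion (mse_loss) and all shapes are illustrative assumptions:
```python
import torch
from braindecode.training import TimeSeriesLoss

# Wrap a regression criterion; mse_loss here is an assumed choice.
criterion = TimeSeriesLoss(torch.nn.functional.mse_loss)

preds = torch.randn(2, 3, 50)       # (batch, n_classes, n_predictions)
targets = torch.randn(2, 3, 100)    # (batch, n_classes, window_len)
targets[0, 1, -10:] = float("nan")  # NaN targets are masked out of the loss
loss = criterion(preds, targets)    # computed only on valid target values
print(loss.item())
```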
- braindecode.training.mixup_criterion(preds, target)[source]#
Implements the mixup loss for EEG data. See [1].
Implementation based on [2]. A usage sketch follows the references below.
- Parameters:
preds (torch.Tensor) – Predictions from the model.
target (torch.Tensor | list of torch.Tensor) – For predictions without mixup, the targets as a tensor. If mixup has been applied, a list containing the targets of the two mixed samples and the mixing coefficients as tensors.
- Returns:
loss – The loss value.
- Return type:
torch.Tensor
References
[1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. mixup: Beyond Empirical Risk Minimization. Online: https://arxiv.org/abs/1710.09412
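A hedged usage sketch; the shapes, the assumption that preds are log-probabilities, and the per-sample mixing coefficients are illustrative, not requirements stated here:
```python
import torch
from braindecode.training import mixup_criterion

preds = torch.log_softmax(torch.randn(8, 4), dim=1)  # per-trial class log-probabilities (assumed)
y_a = torch.randint(0, 4, (8,))                      # labels of the first mixed samples
y_b = torch.randint(0, 4, (8,))                      # labels of the second mixed samples
lam = torch.rand(8)                                  # mixing coefficients, one per sample

mixed_loss = mixup_criterion(preds, [y_a, y_b, lam])  # mixup was applied
plain_loss = mixup_criterion(preds, y_a)              # no mixup: targets as a plain tensor
print(mixed_loss.item(), plain_loss.item())
```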
- braindecode.training.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)[source]#
Create trialwise predictions, and optionally also return trialwise labels, from a cropped dataset given a module. A usage sketch follows the parameter and return descriptions below.
- Parameters:
module (torch.nn.Module) – A pytorch model implementing forward.
dataset (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
batch_size (int) – The batch size used to iterate the dataset.
num_workers (int) – Number of workers used in DataLoader to iterate the dataset.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
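A hedged end-to-end sketch; the synthetic data, the windowing parameters (via create_from_X_y), and the toy module are illustrative assumptions used only to obtain a cropped windows dataset to predict:
```python
import numpy as np
import torch
from braindecode.datasets import create_from_X_y
from braindecode.training import predict_trials

# Synthetic continuous trials: 6 trials, 2 channels, 400 samples each.
X = np.random.randn(6, 2, 400).astype("float32")
y = np.random.randint(0, 3, size=6)

# Cut each trial into overlapping windows to obtain a cropped windows dataset.
windows_ds = create_from_X_y(
    X, y, drop_last_window=False, sfreq=100.0,
    window_size_samples=100, window_stride_samples=50,
)


class TinyCroppedModel(torch.nn.Module):
    """Stand-in for a cropped model: (batch, n_chans, n_times) -> (batch, n_classes, n_preds)."""

    def __init__(self, n_chans=2, n_classes=3):
        super().__init__()
        self.conv = torch.nn.Conv1d(n_chans, n_classes, kernel_size=10)

    def forward(self, x):
        return self.conv(x.float())  # cast defensively to float32


trial_preds, trial_targets = predict_trials(
    TinyCroppedModel(), windows_ds, return_targets=True)
print(trial_preds.shape)    # (n_trials, n_classes, n_predictions)
print(trial_targets.shape)  # trialwise targets
```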
- braindecode.training.trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials)[source]#
Assign window predictions to trials while removing duplicate predictions. A usage sketch is given below.
- Parameters:
preds (list of ndarray) – Window predictions; in each window prediction, time is the last axis.
i_window_in_trials (list of int) – Index of each window within its trial.
i_stop_in_trials (list of int) – Stop sample index of each window within its trial.
- Returns:
preds_per_trial – Predictions in each trial, duplicates removed
- Return type:
list of ndarrays
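A hedged sketch with synthetic window predictions; all shapes and indices are illustrative assumptions:
```python
import numpy as np
from braindecode.training import trial_preds_from_window_preds

# Two trials with two overlapping windows each; every window prediction has
# 3 classes and 4 time steps.
window_preds = [np.random.randn(3, 4) for _ in range(4)]
i_window_in_trials = [0, 1, 0, 1]    # window index within its trial (restarts at 0)
i_stop_in_trials = [10, 12, 10, 12]  # stop sample of each window within its trial

trial_preds = trial_preds_from_window_preds(
    window_preds, i_window_in_trials, i_stop_in_trials)
print(len(trial_preds))      # one array per trial
print(trial_preds[0].shape)  # (n_classes, n_predictions_in_trial), duplicates removed
```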
Submodules#
braindecode.training.callbacks module#
braindecode.training.losses module#
- class braindecode.training.losses.CroppedLoss(loss_function)[source]#
Bases:
Module
Compute loss after averaging predictions across time. Assumes predictions have shape n_batch_size x n_classes x n_predictions (in time).
- forward(preds, targets)[source]#
Forward pass.
- Parameters:
preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).
targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).
- class braindecode.training.losses.TimeSeriesLoss(loss_function)[source]#
Bases:
Module
Compute loss between time series targets and predictions. Assumes predictions have shape n_batch_size x n_classes x n_predictions (in time) and targets have shape n_batch_size x n_classes x window_len (in time). If the targets contain NaNs, the NaNs are masked out and the loss is computed only for predictions corresponding to valid target values.
- forward(preds, targets)[source]#
Forward pass.
- Parameters:
preds (torch.Tensor) – Model’s prediction with shape (batch_size, n_classes, n_times).
targets (torch.Tensor) – Target labels with shape (batch_size, n_classes, n_times).
- braindecode.training.losses.mixup_criterion(preds, target)[source]#
Implements the mixup loss for EEG data. See [1].
Implementation based on [2].
- Parameters:
preds (torch.Tensor) – Predictions from the model.
target (torch.Tensor | list of torch.Tensor) – For predictions without mixup, the targets as a tensor. If mixup has been applied, a list containing the targets of the two mixed samples and the mixing coefficients as tensors.
- Returns:
loss – The loss value.
- Return type:
torch.Tensor
References
[1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. mixup: Beyond Empirical Risk Minimization. Online: https://arxiv.org/abs/1710.09412
braindecode.training.scoring module#
- class braindecode.training.scoring.CroppedTimeSeriesEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]#
Bases:
CroppedTrialEpochScoring
Class to compute scores for trials from a model that predicts (super)crops with time series target.
- class braindecode.training.scoring.CroppedTrialEpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]#
Bases:
EpochScoring
Class to compute scores for trials from a model that predicts (super)crops.
- class braindecode.training.scoring.PostEpochTrainScoring(scoring, lower_is_better=True, name=None, target_extractor=<function to_numpy>)[source]#
Bases:
EpochScoring
Epoch scoring class that recomputes predictions on the training set after the epoch, in validation mode.
Note: For unknown reasons, this affects the global random generator, so all results may change slightly if you add this scoring callback.
- Parameters:
scoring (None, str, or callable (default=None)) – If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y) and return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
lower_is_better (bool (default=True)) – Whether lower scores should be considered better or worse.
name (str or None (default=None)) – If not an explicit string, tries to infer the name from the scoring argument.
target_extractor (callable (default=to_numpy)) – This is called on y before it is passed to scoring.
- braindecode.training.scoring.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)[source]#
Create trialwise predictions, and optionally also return trialwise labels, from a cropped dataset given a module.
- Parameters:
module (torch.nn.Module) – A pytorch model implementing forward.
dataset (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
batch_size (int) – The batch size used to iterate the dataset.
num_workers (int) – Number of workers used in DataLoader to iterate the dataset.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
- braindecode.training.scoring.trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials)[source]#
Assign window predictions to trials while removing duplicate predictions.
- Parameters:
preds (list of ndarray) – Window predictions; in each window prediction, time is the last axis.
i_window_in_trials (list of int) – Index of each window within its trial.
i_stop_in_trials (list of int) – Stop sample index of each window within its trial.
- Returns:
preds_per_trial – Predictions in each trial, duplicates removed
- Return type:
list of ndarrays