# Source code for braindecode.training.losses

# Authors: Robin Schirrmeister <robintibor@gmail.com>
#          Maciej Sliwowski <maciek.sliwowski@gmail.com>
#          Mohammed Fattouh <mo.fattouh@gmail.com>
#
# License: BSD (3-clause)

import torch
from torch import nn


class CroppedLoss(nn.Module):
    """Compute a loss after averaging predictions across time.

    Assumes predictions are in shape:
    n_batch size x n_classes x n_predictions (in time)

    Parameters
    ----------
    loss_function : callable
        Loss called as ``loss_function(averaged_preds, targets)``.
    """

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function

    def forward(self, preds, targets):
        """Forward pass.

        Parameters
        ----------
        preds: torch.Tensor
            Model's prediction with shape (batch_size, n_classes, n_times).
        targets: torch.Tensor
            Targets handed unchanged to ``loss_function`` together with the
            time-averaged predictions, which have shape
            (batch_size, n_classes), or (batch_size,) when n_classes == 1.

        Returns
        -------
        Whatever ``loss_function`` returns.
        """
        # Collapse the time axis: one averaged prediction per example/class.
        time_averaged = preds.mean(dim=2)
        # Drop a singleton class axis (e.g. single-output regression);
        # this is a no-op when n_classes > 1.
        return self.loss_function(time_averaged.squeeze(dim=1), targets)
class TimeSeriesLoss(nn.Module):
    """Compute a loss between time-series targets and predictions.

    Assumes predictions are in shape:
    n_batch size x n_classes x n_predictions (in time)
    Assumes targets are in shape:
    n_batch size x n_classes x window_len (in time)

    Only the last n_predictions time steps of the targets are used. If the
    targets contain NaNs, those entries are masked out and the loss is
    computed only over predictions whose corresponding target is valid.

    Parameters
    ----------
    loss_function : callable
        Loss called as ``loss_function(valid_preds, valid_targets)`` on the
        flattened, NaN-free selections.
    """

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function

    def forward(self, preds, targets):
        """Forward pass.

        Parameters
        ----------
        preds: torch.Tensor
            Model's prediction with shape (batch_size, n_classes, n_times).
        targets: torch.Tensor
            Target values with shape (batch_size, n_classes, window_len);
            only the trailing n_times steps are compared with ``preds``.
            NOTE(review): the sliced targets must have exactly the shape of
            ``preds`` for the boolean mask to index ``preds``.

        Returns
        -------
        Whatever ``loss_function`` returns.
        """
        n_pred_steps = preds.shape[-1]
        # Align targets with the (shorter) prediction window at its end.
        aligned = targets[:, :, -n_pred_steps:]
        # True where the target is a real number, False where it is NaN.
        is_valid = torch.isnan(aligned).logical_not()
        return self.loss_function(preds[is_valid], aligned[is_valid])
def mixup_criterion(preds, target):
    """Implements loss for Mixup for EEG data. See [1]_.

    Implementation based on [2]_.

    Parameters
    ----------
    preds : torch.Tensor
        Predictions from the model, expected to be log-probabilities as
        consumed by ``torch.nn.functional.nll_loss``.
    target : torch.Tensor | list/tuple of torch.Tensor
        For predictions without mixup, the targets as a tensor.
        If mixup has been applied, a list (or tuple) containing the targets
        of the two mixed samples and the mixing coefficients as tensors,
        i.e. ``(y_a, y_b, lam)``.

    Returns
    -------
    loss : torch.Tensor
        The (scalar) loss value.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       mixup: Beyond Empirical Risk Minimization
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """
    # Dispatch on the container type rather than on ``len(target)`` alone:
    # a plain label tensor for a batch of exactly 3 samples also has length 3
    # and would otherwise be wrongly unpacked as (y_a, y_b, lam).
    if isinstance(target, (list, tuple)) and len(target) == 3:
        y_a, y_b, lam = target
        # Per-sample losses so each can be weighted by its own coefficient.
        loss_a = torch.nn.functional.nll_loss(preds, y_a, reduction='none')
        loss_b = torch.nn.functional.nll_loss(preds, y_b, reduction='none')
        # Convex combination per sample, then average over the batch.
        mixed = torch.mul(lam, loss_a) + torch.mul(1 - lam, loss_b)
        return mixed.mean()
    return torch.nn.functional.nll_loss(preds, target)