braindecode.training.CroppedLoss

class braindecode.training.CroppedLoss(loss_function)

Compute the loss after averaging predictions across time. Assumes predictions have shape (batch_size, n_classes, n_predictions), where the last axis indexes the predictions in time.
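The averaging-then-loss behaviour described above can be sketched as a small PyTorch module. This is a minimal illustration, not braindecode's source: the class name `TimeAveragedLoss` is hypothetical, and the usage assumes log-probability predictions with integer class-index targets of shape (batch_size,), as `torch.nn.functional.nll_loss` expects.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TimeAveragedLoss(nn.Module):
    """Hypothetical sketch: average predictions over time, then apply a loss."""

    def __init__(self, loss_function):
        super().__init__()
        # Any callable taking (predictions, targets), e.g. F.nll_loss.
        self.loss_function = loss_function

    def forward(self, preds, targets):
        # preds: (batch_size, n_classes, n_times) -> (batch_size, n_classes)
        avg_preds = preds.mean(dim=2)
        return self.loss_function(avg_preds, targets)


# Usage: 4 trials, 3 classes, 5 crop predictions per trial.
torch.manual_seed(0)
preds = F.log_softmax(torch.randn(4, 3, 5), dim=1)  # log-probabilities per time step
targets = torch.tensor([0, 2, 1, 0])  # assumed: one class index per trial
criterion = TimeAveragedLoss(F.nll_loss)
loss = criterion(preds, targets)
print(loss.shape)  # a 0-dim (scalar) tensor with the default "mean" reduction
```

Averaging over the time axis collapses the per-crop predictions into one prediction per trial, so any standard classification loss can then be applied unchanged.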

Methods

forward(preds, targets)

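Forward pass: average preds over the time axis, then apply loss_function to the averaged predictions and targets.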

Parameters
preds: torch.Tensor

Model’s prediction with shape (batch_size, n_classes, n_times).

targets: torch.Tensor

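Target labels, compared by loss_function against the time-averaged predictions of shape (batch_size, n_classes); their shape must be whatever loss_function expects for such predictions (for example, (batch_size,) class indices for torch.nn.functional.nll_loss).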
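A small worked example of the shapes in forward, with numbers chosen so the result can be checked by hand. It assumes the predictions are log-probabilities and the targets are class indices of shape (batch_size,) passed to `torch.nn.functional.nll_loss`; braindecode's own code is not called here.

```python
import torch
import torch.nn.functional as F

# One trial (batch_size=1), two classes, two time steps of log-probabilities.
preds = torch.tensor([[[-0.1, -0.3],    # class 0 over time
                       [-2.0, -1.0]]])  # class 1 over time
targets = torch.tensor([0])  # assumed: one class index per trial

avg_preds = preds.mean(dim=2)          # shape (1, 2): values -0.2 and -1.5
loss = F.nll_loss(avg_preds, targets)  # -avg_preds[0, 0], approximately 0.2
print(avg_preds.shape, loss.item())
```

Here the class-0 scores -0.1 and -0.3 average to -0.2, so the negative log-likelihood for target class 0 is 0.2.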
