braindecode.training.CroppedLoss
- class braindecode.training.CroppedLoss(loss_function)
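Compute the loss after averaging predictions across time: `loss_function` is applied to the time-averaged predictions. Assumes predictions have shape (batch_size, n_classes, n_predictions), where the last axis indexes predictions over time.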
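The two steps described above (average over time, then apply the loss) can be sketched at the tensor level. This is an illustration, not braindecode's source: it assumes the model outputs log-probabilities (hence `F.log_softmax`), uses `torch.nn.functional.nll_loss` as the wrapped `loss_function`, and assumes integer class-index targets of shape (batch_size,). The sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_classes, n_predictions = 4, 3, 10

# Per-crop log-probabilities laid out as (batch_size, n_classes, n_predictions)
preds = F.log_softmax(torch.randn(batch_size, n_classes, n_predictions), dim=1)
# Assumption: one integer class label per trial, shape (batch_size,)
targets = torch.randint(0, n_classes, (batch_size,))

# Step 1: average the predictions over the time axis (dim=2)
avg_preds = preds.mean(dim=2)  # shape: (batch_size, n_classes)

# Step 2: apply the wrapped loss function to the time-averaged predictions
loss = F.nll_loss(avg_preds, targets)
```

Averaging first means each trial contributes a single prediction to the loss, regardless of how many time steps the model emitted for it.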
Methods
- forward(preds, targets)
Forward pass: average `preds` over the time axis, then apply `loss_function` to the result and `targets`.
- Parameters
- preds: torch.Tensor
Model predictions with shape (batch_size, n_classes, n_times).
- targets: torch.Tensor
Target labels with shape (batch_size, n_classes, n_times).
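A minimal module with the same constructor and `forward(preds, targets)` signature is sketched below. The class name `TimeAveragedLoss` is hypothetical and the body is an assumption based on the description above, not the library's code; it again assumes log-probability predictions and class-index targets of shape (batch_size,) for use with `F.nll_loss`.

```python
import torch
import torch.nn.functional as F


class TimeAveragedLoss(torch.nn.Module):
    """Hypothetical sketch of a loss that averages predictions over time first."""

    def __init__(self, loss_function):
        super().__init__()
        self.loss_function = loss_function  # e.g. F.nll_loss

    def forward(self, preds, targets):
        # preds: (batch_size, n_classes, n_times) -> (batch_size, n_classes)
        avg_preds = preds.mean(dim=2)
        return self.loss_function(avg_preds, targets)


# Usage: 2 trials, 4 classes, 6 time steps of log-probabilities
torch.manual_seed(1)
preds = F.log_softmax(torch.randn(2, 4, 6), dim=1)
targets = torch.tensor([3, 0])  # assumed class-index targets
criterion = TimeAveragedLoss(F.nll_loss)
loss = criterion(preds, targets)
```

Because `loss_function` is passed in rather than hard-coded, the same wrapper works with any loss that accepts (batch_size, n_classes) inputs. The documented class is constructed the same way, e.g. `CroppedLoss(torch.nn.functional.nll_loss)`.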