braindecode.training.predict_trials

braindecode.training.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)

Create trialwise predictions from a cropped dataset using the given module, and optionally also return the trialwise labels.

Parameters:
  • module (torch.nn.Module) – A PyTorch model implementing forward.

  • dataset (braindecode.datasets.BaseConcatDataset) – The braindecode dataset to make predictions for.

  • return_targets (bool) – If True, additionally returns the trial targets.

  • batch_size (int) – The batch size used to iterate over the dataset.

  • num_workers (int) – Number of worker processes used by the DataLoader to iterate over the dataset.
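The behaviour described above can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not braindecode's implementation: it omits the DataLoader, num_workers and device handling, models the dataset as a list of (X, y, window_inds) tuples loosely mirroring braindecode window samples, and uses toy_module, a hypothetical stand-in for a network that makes one prediction per time step with a 5-sample receptive field, stride 1 and no padding. Under those assumptions a window of n_times samples yields n_times - 5 + 1 predictions per class. The names predict_trials_sketch and toy_module are illustrative only.

```python
import numpy as np


def toy_module(X, n_classes=3, receptive_field=5):
    """Hypothetical stand-in for a cropped-decoding network.

    Slides a window of `receptive_field` samples over time without padding
    (stride 1), so an input of n_times samples gives
    n_times - receptive_field + 1 predictions per class.
    """
    batch, _, n_times = X.shape
    n_preds = n_times - receptive_field + 1
    out = np.zeros((batch, n_classes, n_preds), dtype=np.float32)
    for t in range(n_preds):
        # Mean over channels and the samples inside the receptive field.
        window_mean = X[:, :, t:t + receptive_field].mean(axis=(1, 2))
        for c in range(n_classes):
            out[:, c, t] = window_mean * (c + 1)  # arbitrary per-class scores
    return out


def predict_trials_sketch(module, dataset, return_targets=True, batch_size=1):
    """Batch over `dataset` in order, apply `module`, and stack the outputs.

    Assumes each sample is an (X, y, window_inds) tuple; window_inds is ignored.
    """
    all_preds, all_ys = [], []
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]  # no shuffling: keep trial order
        X = np.stack([x for x, _, _ in batch])     # (batch, n_chans, n_times)
        y = np.stack([t for _, t, _ in batch])     # (batch, n_targets)
        preds = module(X)                          # (batch, n_classes, n_preds)
        all_preds.extend(preds.astype(np.float32))
        all_ys.extend(y.astype(np.float32))
    trial_predictions = np.array(all_preds)        # (n_trials, n_classes, n_preds)
    if return_targets:
        return trial_predictions, np.array(all_ys)  # labels: (n_trials, n_targets)
    return trial_predictions


# Hypothetical data: 4 trials, 2 channels, 20 samples, one target per trial.
rng = np.random.default_rng(0)
n_trials, n_chans, n_times = 4, 2, 20
dataset = [
    (rng.standard_normal((n_chans, n_times)).astype(np.float32),
     np.array([i % 3]),
     (0, 0, n_times))  # hypothetical window indices, unused here
    for i in range(n_trials)
]

preds, labels = predict_trials_sketch(toy_module, dataset, batch_size=3)
print(preds.shape)   # (4, 3, 16)
print(labels.shape)  # (4, 1)
```

With batch_size=3 the last batch holds a single trial, which is why the sketch accumulates results per trial with extend instead of assuming equal batch sizes.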

Returns:

  • trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.

  • trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets), only returned if return_targets is True. The number of targets depends on the decoding paradigm: a single value, multiple values, or a sequence.
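The returned arrays are usually reduced further before scoring. One common approach in cropped decoding, shown here as an assumption rather than something predict_trials does itself, is to average each trial's predictions over the last axis (n_predictions) and pick the highest-scoring class. The arrays below are made-up values with the documented shapes, and a single target per trial is assumed.

```python
import numpy as np

# Hypothetical trial predictions: 2 trials x 3 classes x 4 predictions per trial.
trial_predictions = np.array([
    [[0.1, 0.2, 0.1, 0.2],   # trial 0, class 0
     [0.8, 0.7, 0.9, 0.6],   # trial 0, class 1
     [0.1, 0.1, 0.0, 0.2]],  # trial 0, class 2
    [[0.5, 0.6, 0.4, 0.7],   # trial 1, class 0
     [0.3, 0.2, 0.4, 0.2],   # trial 1, class 1
     [0.2, 0.2, 0.2, 0.1]],  # trial 1, class 2
], dtype=np.float32)
trial_labels = np.array([[1], [2]])  # (n_trials, n_targets), one target per trial

# Average the per-crop predictions over the last axis (n_predictions) ...
mean_scores = trial_predictions.mean(axis=-1)  # (n_trials, n_classes)
# ... then pick the highest-scoring class for each trial.
predicted_class = mean_scores.argmax(axis=1)   # (n_trials,)
accuracy = np.mean(predicted_class == trial_labels[:, 0])
print(predicted_class)  # [1 0]
print(accuracy)         # 0.5
```

Averaging over time before the argmax lets every crop of a trial contribute to a single trial-level decision; other reductions (for example a majority vote over per-crop argmaxes) are equally possible.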