braindecode.training.predict_trials
- braindecode.training.predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0)
Create trialwise predictions, and optionally trialwise labels, from a cropped dataset using the given module.
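In cropped decoding, each trial is cut into windows and the network emits several predictions per window. These per-window predictions are then combined into one sequence per trial. The sketch below illustrates that combining step with NumPy only and is not braindecode's implementation. It assumes a network that outputs one prediction per time step once its receptive field of R samples is filled (stride 1, no padding), so a window of W samples yields W - R + 1 predictions covering its last samples. It also assumes that, where a later window's predictions overlap an earlier window's, the duplicates can be dropped. The helper `stitch_window_preds` and all numbers are hypothetical.

```python
import numpy as np


def stitch_window_preds(window_preds, i_stop_in_trial):
    """Merge the crop predictions of the windows of ONE trial (hypothetical helper).

    window_preds: list of arrays shaped (n_classes, n_preds); a window's
        predictions are assumed to cover the last n_preds samples before
        its stop index.
    i_stop_in_trial: stop index (exclusive) of each window, in ascending order.
    Predictions from a later window that overlap already-covered samples
    are dropped, so every sample is predicted exactly once.
    """
    pieces = []
    covered_until = None  # exclusive stop of the samples already predicted
    for preds, i_stop in zip(window_preds, i_stop_in_trial):
        if covered_until is None:
            pieces.append(preds)  # first window: keep all its predictions
        else:
            n_new = i_stop - covered_until  # samples not yet predicted
            if n_new > 0:
                pieces.append(preds[:, -n_new:])  # keep only the new tail
        covered_until = i_stop
    return np.concatenate(pieces, axis=1)  # (n_classes, n_predictions)


# Toy numbers (assumed): windows of 10 samples and a receptive field of 7,
# so each window yields 10 - 7 + 1 = 4 predictions.
window_size, receptive_field, n_classes = 10, 7, 2
n_preds = window_size - receptive_field + 1

# Two overlapping windows of a 12-sample trial: samples 0-9 and 2-11.
w1 = np.arange(n_classes * n_preds).reshape(n_classes, n_preds)
w2 = 100 + np.arange(n_classes * n_preds).reshape(n_classes, n_preds)
trial_preds = stitch_window_preds([w1, w2], [10, 12])
```

Here the stitched array has 6 predictions per class, covering samples 6 to 11 of the trial (12 - 7 + 1 = 6). This shows why the number of trialwise predictions depends on both the window layout and the network's receptive field.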
- Parameters
module (torch.nn.Module) – A PyTorch model implementing forward.
dataset (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
batch_size (int) – The batch size used to iterate the dataset.
num_workers (int) – Number of worker processes used by the DataLoader to iterate the dataset.
- Returns
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets), where the number of targets depends on the decoding paradigm and can be a single value, multiple values, or a sequence. Only returned if return_targets is True.
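A common next step is to reduce the 3-dimensional trial_predictions to one class per trial and compare it with trial_labels. The sketch below assumes single-value classification targets and uses averaging over the prediction axis followed by argmax. That reduction is one reasonable choice, not something this function does for you, and other reductions (for example, a majority vote over crops) are possible. The helper name and the synthetic arrays are hypothetical.

```python
import numpy as np


def trialwise_class_predictions(trial_predictions):
    """Reduce (n_trials, n_classes, n_predictions) scores to one class per trial.

    Averages the per-crop scores over the last (prediction) axis, then picks
    the highest-scoring class. Averaging is an assumption of this sketch.
    """
    mean_scores = trial_predictions.mean(axis=2)  # (n_trials, n_classes)
    return mean_scores.argmax(axis=1)             # (n_trials,)


# Synthetic example: 3 trials, 2 classes, 4 crop predictions per trial.
trial_predictions = np.array([
    [[0.9, 0.8, 0.7, 0.6], [0.1, 0.2, 0.3, 0.4]],  # class 0 wins
    [[0.2, 0.1, 0.3, 0.2], [0.8, 0.9, 0.7, 0.8]],  # class 1 wins
    [[0.6, 0.4, 0.5, 0.7], [0.4, 0.6, 0.5, 0.3]],  # class 0 wins on average
])
# Hypothetical single-value targets shaped (n_trials, 1).
trial_labels = np.array([[0], [1], [1]])

pred_classes = trialwise_class_predictions(trial_predictions)
# ravel() accepts labels shaped either (n_trials,) or (n_trials, 1).
accuracy = (pred_classes == trial_labels.ravel()).mean()
```

With these numbers the predicted classes are [0, 1, 0], so the trialwise accuracy is 2/3.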