braindecode.visualization.amplitude_gradients

braindecode.visualization.amplitude_gradients(model, x)

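Compute amplitude gradients for one batch: the gradients of the model outputs with respect to the amplitude spectrum of each input trial.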

Parameters:
  • model (torch.nn.Module) – Model in eval mode (or otherwise deterministic for the given input). Must accept x of shape (batch, n_chans, n_times) and return outputs of shape (batch, n_outputs).

  • x (numpy.ndarray or torch.Tensor of shape (batch, n_chans, n_times)) – Input batch. It is moved to the model’s device.

Returns:

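out[i] holds, for every trial in the batch, the gradient of the i-th output unit averaged over the batch, taken with respect to that trial’s input amplitude spectrum. n_freqs is n_times // 2 + 1 (the number of bins returned by an rfft over n_times samples).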

Return type:

numpy.ndarray of shape (n_outputs, batch, n_chans, n_freqs)
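As a minimal sketch of what the returned array contains, the computation can be written in plain PyTorch. It assumes the amplitude spectrum is the magnitude of torch.fft.rfft along the time axis and that the signal is rebuilt with torch.fft.irfft while the phases are held fixed. The helper amplitude_gradients_sketch, the toy linear model and the 250 Hz sampling rate are hypothetical choices for illustration; braindecode’s own amplitude_gradients may differ in its details.

```python
import numpy as np
import torch
from torch import nn


def amplitude_gradients_sketch(model: nn.Module, x) -> np.ndarray:
    """Hypothetical illustration, not braindecode's own implementation."""
    # Assumes the model has at least one parameter, used to find its device.
    device = next(model.parameters()).device
    x = torch.as_tensor(x, dtype=torch.float32)
    n_times = x.shape[-1]

    # Real FFT along the time axis, split into amplitude and phase.
    spectrum = torch.fft.rfft(x, dim=-1)  # (batch, n_chans, n_times // 2 + 1)
    amps = spectrum.abs().to(device).requires_grad_(True)
    phases = spectrum.angle().to(device)

    # Rebuild the time-domain signal from the amplitudes and fixed phases so
    # that gradients can flow back to `amps`.
    x_rec = torch.fft.irfft(torch.polar(amps, phases), n=n_times, dim=-1)

    outs = model(x_rec)  # expected shape: (batch, n_outputs)
    grads = []
    for i in range(outs.shape[1]):
        # Gradient of the batch mean of output unit i w.r.t. all trials' amplitudes.
        (g,) = torch.autograd.grad(outs[:, i].mean(), amps, retain_graph=True)
        grads.append(g.detach().cpu().numpy())
    return np.stack(grads)  # (n_outputs, batch, n_chans, n_freqs)


# Toy example: a linear model on the flattened input (illustrative only).
torch.manual_seed(0)
n_chans, n_times, n_outputs, batch = 4, 100, 2, 3
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(n_chans * n_times, n_outputs)).eval()
x = np.random.default_rng(0).standard_normal((batch, n_chans, n_times)).astype(np.float32)

grads = amplitude_gradients_sketch(toy_model, x)
print(grads.shape)  # (2, 3, 4, 51)

# Map the frequency axis to Hz, assuming a sampling rate of 250 Hz.
sfreq = 250.0
freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
mean_grads = grads.mean(axis=1)  # average over trials: (n_outputs, n_chans, n_freqs)
```

Because the mean is taken over the batch, a model that processes trials independently (as in eval mode) yields, for each trial, that trial’s own gradient scaled by 1/batch. The np.fft.rfftfreq call gives the frequency in Hz of each of the n_freqs bins, here 0 to 125 Hz in steps of 2.5 Hz.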