# Source code for braindecode.visualization.gradients
# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from skorch.utils import to_numpy, to_tensor
import torch
def compute_amplitude_gradients(model, dataset, batch_size):
    """Compute gradients of model outputs w.r.t. input FFT amplitudes.

    Iterates over ``dataset`` in batches and delegates the actual gradient
    computation to :func:`compute_amplitude_gradients_for_X`.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model; its parameters determine the device used.
    dataset : torch.utils.data.Dataset
        Dataset whose items unpack into 3-tuples; only the first element
        (the signal ``X``) is used.
        NOTE(review): assumes ``(X, y, ind)`` triples as yielded by
        braindecode windows datasets — confirm against callers.
    batch_size : int
        Number of examples per forward/backward pass.

    Returns
    -------
    numpy.ndarray
        Amplitude gradients with the per-batch results concatenated along
        axis 1 (the example axis of the ``(n_filters, n_examples, ...)``
        arrays returned per batch).
    """
    # Keep order deterministic and keep every example (no shuffling/dropping),
    # so results line up with the dataset order.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=False, shuffle=False)
    all_amp_grads = []
    for batch_X, _, _ in loader:
        all_amp_grads.append(
            compute_amplitude_gradients_for_X(model, batch_X))
    # axis 0 is the filter axis; axis 1 holds the examples of each batch
    return np.concatenate(all_amp_grads, axis=1)
def compute_amplitude_gradients_for_X(model, X):
    """Compute gradients of mean filter outputs w.r.t. FFT amplitudes of ``X``.

    The signal is decomposed into amplitude and phase via a real FFT along
    the time axis (axis 2); the signal is then reconstructed from these in a
    differentiable way, so that gradients of the model output can be taken
    with respect to the amplitudes.

    Parameters
    ----------
    model : torch.nn.Module
        Model mapping the reconstructed signal to outputs of shape
        ``(n_examples, n_filters, ...)``.
    X : numpy.ndarray
        Input signals, time along axis 2 — e.g. ``(n_examples, n_chans,
        n_times)``; a trailing singleton dim ``(..., 1)`` is also handled.

    Returns
    -------
    numpy.ndarray
        Float32 array of shape ``(n_filters,) + rfft(X).shape`` with the
        gradient of each filter's mean output w.r.t. each amplitude.
    """
    device = next(model.parameters()).device
    # Amplitude/phase decomposition along the time axis.
    ffted = np.fft.rfft(X, axis=2)
    amps = np.abs(ffted)
    phases = np.angle(ffted)
    # Direct torch conversions (previously skorch's to_tensor); amplitudes
    # require grad since they are what we differentiate with respect to.
    amps_th = torch.as_tensor(
        amps.astype(np.float32)).to(device).requires_grad_(True)
    phases_th = torch.as_tensor(
        phases.astype(np.float32)).to(device).requires_grad_(True)
    # Rebuild (real, imag) coefficients from amplitude and phase so the
    # amplitude tensor stays inside the autograd graph.
    fft_coefs = amps_th.unsqueeze(-1) * torch.stack(
        (torch.cos(phases_th), torch.sin(phases_th)), dim=-1)
    # Drops a singleton dim for 4D inputs (b, c, t, 1); no-op for 3D inputs
    # where dim 3 already has size 2.
    fft_coefs = fft_coefs.squeeze(3)
    try:
        # torch >= 1.8: complex FFT API
        complex_fft_coefs = torch.view_as_complex(fft_coefs)
        iffted = torch.fft.irfft(
            complex_fft_coefs, n=X.shape[2], dim=2)
    except AttributeError:
        iffted = torch.irfft(  # Deprecated since 1.7
            fft_coefs, signal_ndim=1, signal_sizes=(X.shape[2],))
    outs = model(iffted)
    n_filters = outs.shape[1]
    # Pre-fill with NaN so any filter we fail to fill is caught below.
    amp_grads_per_filter = np.full((n_filters,) + ffted.shape,
                                   np.nan, dtype=np.float32)
    for i_filter in range(n_filters):
        mean_out = torch.mean(outs[:, i_filter])
        # retain_graph: the same graph is reused for every filter
        mean_out.backward(retain_graph=True)
        amp_grads = amps_th.grad.clone().detach().cpu().numpy()
        amp_grads_per_filter[i_filter] = amp_grads
        # Gradients accumulate in torch; reset before the next filter.
        amps_th.grad.zero_()
    assert not np.any(np.isnan(amp_grads_per_filter))
    return amp_grads_per_filter