# Authors: Cédric Rommel <cedric.rommel@inria.fr>
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
#
# License: BSD (3-clause)
from numbers import Real
import numpy as np
from scipy.interpolate import Rbf
from sklearn.utils import check_random_state
import torch
from torch.fft import fft, ifft
from torch.nn.functional import pad, one_hot
from mne.filter import notch_filter
def identity(X, y):
"""Identity operation.
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
"""
return X, y
def time_reverse(X, y):
"""Flip the time axis of each input.
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
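Examples
--------
A minimal sketch (the tensor shapes here are illustrative assumptions):
>>> import torch
>>> X = torch.arange(4.).reshape(1, 1, 4)  # (batch, channels, times)
>>> X_flipped, _ = time_reverse(X, y=None)
>>> X_flipped
tensor([[[3., 2., 1., 0.]]])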
"""
return torch.flip(X, [-1]), y
def sign_flip(X, y):
"""Flip the sign axis of each input.
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
"""
return -X, y
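# The two helpers below draw random spectrum phases with Hermitian
# symmetry (antisymmetric phase, zero at DC and, for even n, at the
# Nyquist bin), so that the inverse FFT of the perturbed spectrum
# remains real-valued.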
def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
rng = check_random_state(random_state)
random_phase = torch.from_numpy(
2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
).to(device)
return torch.cat([
torch.zeros((batch_size, c, 1), device=device),
random_phase,
-torch.flip(random_phase, [-1])
], dim=-1)
def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
rng = check_random_state(random_state)
random_phase = torch.from_numpy(
2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
).to(device)
return torch.cat([
torch.zeros((batch_size, c, 1), device=device),
random_phase,
torch.zeros((batch_size, c, 1), device=device),
-torch.flip(random_phase, [-1])
], dim=-1)
_new_random_fft_phase = {
0: _new_random_fft_phase_even,
1: _new_random_fft_phase_odd
}
def ft_surrogate(
X,
y,
phase_noise_magnitude,
channel_indep,
random_state=None
):
"""FT surrogate augmentation of a single EEG channel, as proposed in [1]_.
Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
and modified.
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
phase_noise_magnitude : float
Float between 0 and 1 setting the range over which the phase
perturbation is uniformly sampled:
[0, `phase_noise_magnitude` * 2 * `pi`].
channel_indep : bool
Whether to sample phase perturbations independently for each channel or
not. It is advised to set it to False when spatial information is
important for the task, like in BCI.
random_state : int | numpy.random.Generator, optional
Used to draw the phase perturbation. Defaults to None.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
Clifford, G. D. (2018). Addressing Class Imbalance in Classification
Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
preprint arXiv:1806.08675.
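Examples
--------
A rough usage sketch; shapes, magnitude and seed are illustrative
assumptions:
>>> import torch
>>> X = torch.randn(2, 3, 100)  # (batch, channels, times)
>>> y = torch.tensor([0, 1])
>>> X_aug, y_aug = ft_surrogate(
...     X, y, phase_noise_magnitude=1., channel_indep=False,
...     random_state=42)
>>> X_aug.shape == X.shape
True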
"""
assert isinstance(
phase_noise_magnitude,
(Real, torch.FloatTensor, torch.cuda.FloatTensor)
) and 0 <= phase_noise_magnitude <= 1, (
f"eps must be a float beween 0 and 1. Got {phase_noise_magnitude}.")
f = fft(X.double(), dim=-1)
device = X.device
n = f.shape[-1]
random_phase = _new_random_fft_phase[n % 2](
f.shape[0],
f.shape[-2] if channel_indep else 1,
n,
device=device,
random_state=random_state
)
if not channel_indep:
random_phase = torch.tile(random_phase, (1, f.shape[-2], 1))
if isinstance(phase_noise_magnitude, torch.Tensor):
phase_noise_magnitude = phase_noise_magnitude.to(device)
f_shifted = f * torch.exp(phase_noise_magnitude * random_phase)
shifted = ifft(f_shifted, dim=-1)
transformed_X = shifted.real.float()
return transformed_X, y
def _pick_channels_randomly(X, p_pick, random_state):
rng = check_random_state(random_state)
batch_size, n_channels, _ = X.shape
# Draw from the numpy RNG so that the same random_state always yields
# the same samples
unif_samples = torch.as_tensor(
rng.uniform(0, 1, size=(batch_size, n_channels)),
dtype=torch.float,
device=X.device,
)
# Steep sigmoid: a differentiable approximation of the binary mask
# 1{unif_samples > p_pick}
return torch.sigmoid(1000 * (unif_samples - p_pick))
def channels_dropout(X, y, p_drop, random_state=None):
"""Randomly set channels to flat signal.
Part of the CMSAugment policy proposed in [1]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
p_drop : float
Float between 0 and 1 setting the probability of dropping each channel.
random_state : int | numpy.random.Generator, optional
Seed to be used to instantiate numpy random number generator instance.
Defaults to None.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
Learning from Heterogeneous EEG Signals with Differentiable Channel
Reordering. arXiv preprint arXiv:2010.13694.
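Examples
--------
A rough usage sketch (shapes and probability are illustrative):
>>> import torch
>>> X = torch.ones(4, 8, 256)
>>> X_aug, _ = channels_dropout(X, None, p_drop=0.2, random_state=0)
>>> X_aug.shape
torch.Size([4, 8, 256])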
"""
mask = _pick_channels_randomly(X, p_drop, random_state=random_state)
return X * mask.unsqueeze(-1), y
def _make_permutation_matrix(X, mask, random_state):
rng = check_random_state(random_state)
batch_size, n_channels, _ = X.shape
hard_mask = mask.round()
batch_permutations = torch.empty(
batch_size, n_channels, n_channels, device=X.device
)
for b, sample_mask in enumerate(hard_mask):
channels_to_shuffle = torch.arange(n_channels)
channels_to_shuffle = channels_to_shuffle[sample_mask.bool()]
channels_permutation = np.arange(n_channels)
channels_permutation[channels_to_shuffle] = rng.permutation(
channels_to_shuffle
)
channels_permutation = torch.as_tensor(
channels_permutation, dtype=torch.int64, device=X.device
)
batch_permutations[b, ...] = one_hot(channels_permutation)
return batch_permutations
def channels_shuffle(X, y, p_shuffle, random_state=None):
"""Randomly shuffle channels in EEG data matrix.
Part of the CMSAugment policy proposed in [1]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
p_shuffle : float | None
Float between 0 and 1 setting the probability of including each channel
in the set of permuted channels.
random_state : int | numpy.random.Generator, optional
Seed to be used to instantiate numpy random number generator instance.
Used to sample which channels to shuffle and to carry out the shuffle.
Defaults to None.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
Learning from Heterogeneous EEG Signals with Differentiable Channel
Reordering. arXiv preprint arXiv:2010.13694.
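Examples
--------
A rough usage sketch; constant channels make the permutation visible:
>>> import torch
>>> X = torch.arange(4.).view(1, 4, 1).repeat(1, 1, 3)
>>> X_aug, _ = channels_shuffle(X, None, p_shuffle=0.5, random_state=7)
>>> X_aug.shape
torch.Size([1, 4, 3])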
"""
if p_shuffle == 0:
return X, y
mask = _pick_channels_randomly(X, 1 - p_shuffle, random_state)
batch_permutations = _make_permutation_matrix(X, mask, random_state)
return torch.matmul(batch_permutations, X), y
def gaussian_noise(X, y, std, random_state=None):
"""Randomly add white Gaussian noise to all channels.
Suggested e.g. in [1]_, [2]_ and [3]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
std : float
Standard deviation to use for the additive noise.
random_state : int | numpy.random.Generator, optional
Seed to be used to instantiate numpy random number generator instance.
Defaults to None.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
augmentation for eeg-based emotion recognition with deep convolutional
neural networks. In International Conference on Multimedia Modeling
(pp. 82-93).
.. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
Subject-aware contrastive learning for biosignals. arXiv preprint
arXiv:2007.04871.
.. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
Representation Learning for Electroencephalogram Classification. In
Machine Learning for Health (pp. 238-253). PMLR.
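Examples
--------
A minimal sketch (the noise level is an illustrative assumption):
>>> import torch
>>> X = torch.zeros(2, 3, 100)
>>> X_aug, _ = gaussian_noise(X, None, std=0.1, random_state=0)
>>> bool(X_aug.std() > 0)
True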
"""
rng = check_random_state(random_state)
if isinstance(std, torch.Tensor):
std = std.to(X.device)
noise = torch.from_numpy(
rng.normal(
loc=np.zeros(X.shape),
scale=1
),
).float().to(X.device) * std
transformed_X = X + noise
return transformed_X, y
def channels_permute(X, y, permutation):
"""Permute EEG channels according to fixed permutation matrix.
Suggested e.g. in [1]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
permutation : list
List of integers defining the new channels order.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
(2018). HAMLET: interpretable human and machine co-learning technique.
arXiv preprint arXiv:1803.09702.
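Examples
--------
A minimal sketch swapping the first two channels:
>>> import torch
>>> X = torch.arange(3.).view(1, 3, 1)
>>> X_perm, _ = channels_permute(X, None, permutation=[1, 0, 2])
>>> X_perm.squeeze()
tensor([1., 0., 2.])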
"""
return X[..., permutation, :], y
def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
"""Smoothly replace a contiguous part of all channels by zeros.
Originally proposed in [1]_ and [2]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
mask_start_per_sample : torch.Tensor
Tensor of integers containing the position (in last dimension) where to
start masking the signal. Should have the same size as the first
dimension of X (i.e. one start position per example in the batch).
mask_len_samples : int
Number of consecutive samples to zero out.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
Subject-aware contrastive learning for biosignals. arXiv preprint
arXiv:2007.04871.
.. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
Representation Learning for Electroencephalogram Classification. In
Machine Learning for Health (pp. 238-253). PMLR.
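Examples
--------
A minimal sketch masking 50 samples starting at sample 100 of each
example (all values are illustrative):
>>> import torch
>>> X = torch.ones(2, 3, 300)
>>> starts = torch.tensor([100, 100])
>>> X_aug, _ = smooth_time_mask(X, None, starts, mask_len_samples=50)
>>> bool(X_aug[0, 0, 125] < 0.1) and bool(X_aug[0, 0, 0] > 0.9)
True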
"""
batch_size, n_channels, seq_len = X.shape
t = torch.arange(seq_len, device=X.device).float()
t = t.repeat(batch_size, n_channels, 1)
mask_start_per_sample = mask_start_per_sample.view(-1, 1, 1)
s = 1000 / seq_len
mask = (torch.sigmoid(s * -(t - mask_start_per_sample)) +
torch.sigmoid(s * (t - mask_start_per_sample - mask_len_samples))
).float().to(X.device)
return X * mask, y
def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
"""Apply a band-stop filter with desired bandwidth at the desired frequency
position.
Suggested e.g. in [1]_ and [2]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
sfreq : float
Sampling frequency of the signals to be filtered.
bandwidth : float
Bandwidth of the filter, i.e. distance between the low and high cut
frequencies.
freqs_to_notch : array-like | None
Array of floats of size ``(batch_size,)`` containing the center of the
frequency band to filter out for each sample in the batch. Frequencies
should be greater than ``bandwidth/2 + transition`` and lower than
``sfreq/2 - bandwidth/2 - transition`` (where ``transition = 1 Hz``).
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
Subject-aware contrastive learning for biosignals. arXiv preprint
arXiv:2007.04871.
.. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
Representation Learning for Electroencephalogram Classification. In
Machine Learning for Health (pp. 238-253). PMLR.
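Examples
--------
A rough usage sketch; the sampling rate, bandwidth and notched
frequencies are illustrative assumptions:
>>> import torch
>>> X = torch.randn(2, 3, 1000)
>>> X_aug, _ = bandstop_filter(
...     X, None, sfreq=100., bandwidth=2., freqs_to_notch=[10., 12.])
>>> X_aug.shape
torch.Size([2, 3, 1000])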
"""
if bandwidth == 0:
return X, y
transformed_X = X.clone()
for c, (sample, notched_freq) in enumerate(
zip(transformed_X, freqs_to_notch)):
sample = sample.cpu().numpy().astype(np.float64)
transformed_X[c] = torch.as_tensor(notch_filter(
sample,
Fs=sfreq,
freqs=notched_freq,
method='fir',
notch_widths=bandwidth,
verbose=False
))
return transformed_X, y
def _analytic_transform(x):
if torch.is_complex(x):
raise ValueError("x must be real.")
N = x.shape[-1]
f = fft(x, N, dim=-1)
h = torch.zeros_like(f)
if N % 2 == 0:
h[..., 0] = h[..., N // 2] = 1
h[..., 1:N // 2] = 2
else:
h[..., 0] = 1
h[..., 1:(N + 1) // 2] = 2
return ifft(f * h, dim=-1)
def _nextpow2(n):
"""Return the first integer N such that 2**N >= abs(n)."""
return int(np.ceil(np.log2(np.abs(n))))
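# e.g. _nextpow2(1000) == 10, since 2 ** 10 == 1024 >= 1000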
def _frequency_shift(X, fs, f_shift):
"""
Shift the specified signal by the specified frequency.
See https://gist.github.com/lebedov/4428122
"""
# Pad the signal with zeros up to the next power of two so that the FFT
# used by the analytic transform stays fast:
n_channels, N_orig = X.shape[-2:]
N_padded = 2 ** _nextpow2(N_orig)
t = torch.arange(N_padded, device=X.device) / fs
padded = pad(X, (0, N_padded - N_orig))
analytical = _analytic_transform(padded)
if isinstance(f_shift, (float, int, np.ndarray, list)):
f_shift = torch.as_tensor(f_shift).float()
reshaped_f_shift = f_shift.repeat(
N_padded, n_channels, 1).T
shifted = analytical * torch.exp(2j * np.pi * reshaped_f_shift * t)
return shifted[..., :N_orig].real.float()
def frequency_shift(X, y, delta_freq, sfreq):
"""Adds a shift in the frequency domain to all channels.
Note that here, the shift is the same for all channels of a single example.
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
delta_freq : float
The amplitude of the frequency shift (in Hz).
sfreq : float
Sampling frequency of the signals to be transformed.
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
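Examples
--------
A rough usage sketch (shift and sampling rate are illustrative):
>>> import torch
>>> X = torch.randn(1, 2, 200)
>>> X_aug, _ = frequency_shift(X, None, delta_freq=2., sfreq=100.)
>>> X_aug.shape
torch.Size([1, 2, 200])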
"""
transformed_X = _frequency_shift(
X=X,
fs=sfreq,
f_shift=delta_freq,
)
return transformed_X, y
def _torch_normalize_vectors(rr):
"""Normalize surface vertices."""
norm = torch.linalg.norm(rr, dim=1, keepdim=True)
mask = (norm > 0)
norm[~mask] = 1 # in case norm is zero, divide by 1
new_rr = rr / norm
return new_rr
def _torch_legval(x, c, tensor=True):
"""
Evaluate a Legendre series at points x.
If `c` is of length `n + 1`, this function returns the value:
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
The parameter `x` is converted to an array only if it is a tuple or a
list, otherwise it is treated as a scalar. In either case, either `x`
or its elements must support multiplication and addition both with
themselves and with the elements of `c`.
If `c` is a 1-D array, then `p(x)` will have the same shape as `x`. If
`c` is multidimensional, then the shape of the result depends on the
value of `tensor`. If `tensor` is true the shape will be c.shape[1:] +
x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that
scalars have shape (,).
Trailing zeros in the coefficients will be used in the evaluation, so
they should be avoided if efficiency is a concern.
Parameters
----------
x : array_like, compatible object
If `x` is a list or tuple, it is converted to an ndarray, otherwise
it is left unchanged and treated as a scalar. In either case, `x`
or its elements must support addition and multiplication with
themselves and with the elements of `c`.
c : array_like
Array of coefficients ordered so that the coefficients for terms of
degree n are contained in c[n]. If `c` is multidimensional the
remaining indices enumerate multiple polynomials. In the two
dimensional case the coefficients may be thought of as stored in
the columns of `c`.
tensor : boolean, optional
If True, the shape of the coefficient array is extended with ones
on the right, one for each dimension of `x`. Scalars have dimension 0
for this action. The result is that every column of coefficients in
`c` is evaluated for every element of `x`. If False, `x` is broadcast
over the columns of `c` for the evaluation. This keyword is useful
when `c` is multidimensional. The default value is True.
Returns
-------
values : ndarray, algebra_like
The shape of the return value is described above.
See Also
--------
legval2d, leggrid2d, legval3d, leggrid3d
Notes
-----
Code copied and modified from Numpy:
https://github.com/numpy/numpy/blob/v1.20.0/numpy/polynomial/legendre.py#L835-L920
Copyright (c) 2005-2021, NumPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
c = torch.as_tensor(c)
c = c.double()
if isinstance(x, (tuple, list)):
x = torch.as_tensor(x)
if isinstance(x, torch.Tensor) and tensor:
c = c.view(c.shape + (1,)*x.ndim)
c = c.to(x.device)
if len(c) == 1:
c0 = c[0]
c1 = 0
elif len(c) == 2:
c0 = c[0]
c1 = c[1]
else:
nd = len(c)
c0 = c[-2]
c1 = c[-1]
for i in range(3, len(c) + 1):
tmp = c0
nd = nd - 1
c0 = c[-i] - (c1*(nd - 1))/nd
c1 = tmp + (c1*x*(2*nd - 1))/nd
return c0 + c1*x
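# For instance, _torch_legval(torch.tensor([0.5]), [0, 1]) evaluates the
# series 0 * L_0(x) + 1 * L_1(x) = x and returns tensor([0.5000],
# dtype=torch.float64).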
def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
"""Calculate spherical spline g function between points on a sphere.
Parameters
----------
cosang : array-like of float, shape(n_channels, n_channels)
cosine of angles between pairs of points on a spherical surface. This
is equivalent to the dot product of unit vectors.
stiffness : float
stiffness of the spline.
n_legendre_terms : int
number of Legendre terms to evaluate.
Returns
-------
G : np.ndarray of float, shape(n_channels, n_channels)
The G matrix.
Notes
-----
Code copied and modified from MNE-Python:
https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L35
Copyright © 2011-2019, authors of MNE-Python
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
"""
factors = [(2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness *
4 * np.pi)
for n in range(1, n_legendre_terms + 1)]
return _torch_legval(cosang, [0] + factors)
def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
"""Compute interpolation matrix based on spherical splines.
Implementation based on [1]_
Parameters
----------
pos_from : np.ndarray of float, shape(n_good_sensors, 3)
The positions to interpolate from.
pos_to : np.ndarray of float, shape(n_bad_sensors, 3)
The positions to interpolate to.
alpha : float
Regularization parameter. Defaults to 1e-5.
Returns
-------
interpolation : np.ndarray of float, shape(len(pos_to), len(pos_from))
The interpolation matrix that maps good signals to the location
of bad signals.
References
----------
.. [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
Spherical splines for scalp potential and current density mapping.
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
Notes
-----
Code copied and modified from MNE-Python:
https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L59
Copyright © 2011-2019, authors of MNE-Python
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
"""
pos_from = pos_from.clone()
pos_to = pos_to.clone()
n_from = pos_from.shape[0]
n_to = pos_to.shape[0]
# normalize sensor positions to sphere
pos_from = _torch_normalize_vectors(pos_from)
pos_to = _torch_normalize_vectors(pos_to)
# cosine angles between source positions
cosang_from = torch.matmul(pos_from, pos_from.T)
cosang_to_from = torch.matmul(pos_to, pos_from.T)
G_from = _torch_calc_g(cosang_from)
G_to_from = _torch_calc_g(cosang_to_from)
assert G_from.shape == (n_from, n_from)
assert G_to_from.shape == (n_to, n_from)
if alpha is not None:
# flatten() returns a view here, so this adds alpha to the diagonal
G_from.flatten()[::len(G_from) + 1] += alpha
device = G_from.device
C = torch.vstack([
torch.hstack([G_from, torch.ones((n_from, 1), device=device)]),
torch.hstack([
torch.ones((1, n_from), device=device),
torch.as_tensor([[0]], device=device)])
])
try:
C_inv = torch.linalg.inv(C)
except torch._C._LinAlgError:
# There is a stability issue with pinv since torch v1.8.0
# see https://github.com/pytorch/pytorch/issues/75494
C_inv = torch.linalg.pinv(C.cpu()).to(device)
interpolation = torch.hstack([
G_to_from,
torch.ones((n_to, 1), device=device)
]).matmul(C_inv[:, :-1])
assert interpolation.shape == (n_to, n_from)
return interpolation
def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
sensors_positions_matrix = sensors_positions_matrix.to(X.device)
rot_sensors_matrices = [
rotation.matmul(sensors_positions_matrix) for rotation in rotations
]
if spherical:
interpolation_matrix = torch.stack(
[torch.as_tensor(
_torch_make_interpolation_matrix(
sensors_positions_matrix.T, rot_sensors_matrix.T
), device=X.device
).float() for rot_sensors_matrix in rot_sensors_matrices]
)
return torch.matmul(interpolation_matrix, X)
else:
transformed_X = X.clone()
sensors_positions = list(sensors_positions_matrix)
for s, rot_sensors_matrix in enumerate(rot_sensors_matrices):
rot_sensors_positions = list(rot_sensors_matrix.T)
for time in range(X.shape[-1]):
interpolator_t = Rbf(*sensors_positions, X[s, :, time])
transformed_X[s, :, time] = torch.from_numpy(
interpolator_t(*rot_sensors_positions)
)
return transformed_X
def _make_rotation_matrix(axis, angle, degrees=True):
assert axis in ['x', 'y', 'z'], "axis should be either x, y or z."
if isinstance(angle, (float, int, np.ndarray, list)):
angle = torch.as_tensor(angle)
if degrees:
angle = angle * np.pi / 180
device = angle.device
zero = torch.zeros(1, device=device)
rot = torch.stack([
torch.as_tensor([1, 0, 0], device=device),
torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
])
if axis == "x":
return rot
elif axis == "y":
rot = rot[[2, 0, 1], :]
return rot[:, [2, 0, 1]]
else:
rot = rot[[1, 2, 0], :]
return rot[:, [1, 2, 0]]
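# Illustrative check (values are approximate up to float precision):
# _make_rotation_matrix('z', 90).matmul(torch.tensor([1., 0., 0.]).double())
# maps the x unit vector onto y, i.e. roughly tensor([0., 1., 0.]).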
def sensors_rotation(X, y, sensors_positions_matrix, axis, angles,
spherical_splines):
"""Interpolates EEG signals over sensors rotated around the desired axis
with the desired angle.
Suggested in [1]_
Parameters
----------
X : torch.Tensor
EEG input example or batch.
y : torch.Tensor
EEG labels for the example or batch.
sensors_positions_matrix : numpy.ndarray
Matrix giving the positions of each sensor in a 3D cartesian coordinate
system. Should have shape (3, n_channels), where n_channels is the
number of channels. Standard 10-20 positions can be obtained from
``mne`` through::
>>> ten_twenty_montage = mne.channels.make_standard_montage(
... 'standard_1020'
... ).get_positions()['ch_pos']
axis : 'x' | 'y' | 'z'
Axis around which to rotate.
angles : array-like
Array of float of shape ``(batch_size,)`` containing the rotation
angles (in degrees) for each element of the input batch.
spherical_splines : bool
Whether to use spherical splines for the interpolation or not. When
`False`, standard scipy.interpolate.Rbf (with quadratic kernel) will be
used (as in the original paper).
Returns
-------
torch.Tensor
Transformed inputs.
torch.Tensor
Transformed labels.
References
----------
.. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
electroencephalographic data. In 2017 39th Annual International
Conference of the IEEE Engineering in Medicine and Biology Society
(EMBC) (pp. 471-474).
"""
rots = [
_make_rotation_matrix(axis, angle, degrees=True)
for angle in angles
]
rotated_X = _rotate_signals(
X, rots, sensors_positions_matrix, spherical_splines
)
return rotated_X, y
def mixup(X, y, lam, idx_perm):
"""Mixes two channels of EEG data.
See [1]_ for details.
Implementation based on [2]_.
Parameters
----------
X : torch.Tensor
EEG data in form ``batch_size, n_channels, n_times``
y : torch.Tensor
Target of length ``batch_size``
lam : torch.Tensor
Values between 0 and 1 setting the linear interpolation between
examples.
idx_perm : torch.Tensor
Permuted indices of the examples that are mixed into the original
examples.
Returns
-------
tuple
``X``, ``y``. Where ``X`` is augmented and ``y`` is a tuple of length
3 containing the labels of the two mixed examples and the mixing
coefficient.
References
----------
.. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
(2018). mixup: Beyond Empirical Risk Minimization. In 2018
International Conference on Learning Representations (ICLR)
Online: https://arxiv.org/abs/1710.09412
.. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
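Examples
--------
A rough training sketch; the commented loss combination at the end is
the usual mixup recipe (with ``ce`` a per-sample cross-entropy) and is
only an illustration, not part of this module:
>>> import torch
>>> X = torch.randn(4, 3, 100)
>>> y = torch.tensor([0, 1, 0, 1])
>>> lam = torch.rand(4)
>>> idx_perm = torch.randperm(4)
>>> X_mix, (y_a, y_b, lam) = mixup(X, y, lam, idx_perm)
>>> # loss = (lam * ce(pred, y_a) + (1 - lam) * ce(pred, y_b)).mean()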
"""
device = X.device
batch_size, n_channels, n_times = X.shape
X_mix = torch.zeros((batch_size, n_channels, n_times)).to(device)
y_a = torch.arange(batch_size).to(device)
y_b = torch.arange(batch_size).to(device)
for idx in range(batch_size):
X_mix[idx] = lam[idx] * X[idx] \
+ (1 - lam[idx]) * X[idx_perm[idx]]
y_a[idx] = y[idx]
y_b[idx] = y[idx_perm[idx]]
return X_mix, (y_a, y_b, lam)