Source code for braindecode.samplers.ssl

"""
Self-supervised learning samplers.
"""

# Authors: Hubert Banville <hubert.jbanville@gmail.com>
#
# License: BSD (3-clause)

import numpy as np

from . import RecordingSampler


class RelativePositioningSampler(RecordingSampler):
    """Sample examples for the relative positioning task from [Banville2020]_.

    Sample examples as tuples of two window indices, with a label indicating
    whether the windows are close or far, as defined by tau_pos and tau_neg.

    Parameters
    ----------
    metadata : pd.DataFrame
        See RecordingSampler.
    tau_pos : int
        Size of the positive context, in samples. A positive pair contains two
        windows x1 and x2 which are separated by at most `tau_pos` samples.
    tau_neg : int
        Size of the negative context, in samples. A negative pair contains two
        windows x1 and x2 which are separated by at least `tau_neg` samples and
        at most `tau_max` samples. Ignored if `same_rec_neg` is False.
    n_examples : int
        Number of pairs to extract.
    tau_max : int | None
        Upper bound on the distance between the two windows of a negative
        pair, in samples (see `tau_neg`). If None, no upper bound is applied
        (stored internally as ``np.inf``).
    same_rec_neg : bool
        If True, sample negative pairs from within the same recording. If
        False, sample negative pairs from two different recordings.
    random_state : None | np.RandomState | int
        Random state.

    Raises
    ------
    ValueError
        If `same_rec_neg` is False and fewer than two recordings are
        available.

    References
    ----------
    .. [Banville2020] Banville, H., Chehab, O., Hyvärinen, A., Engemann, D. A.,
           & Gramfort, A. (2020). Uncovering the structure of clinical EEG
           signals with self-supervised learning.
           arXiv preprint arXiv:2007.16104.
    """

    def __init__(self, metadata, tau_pos, tau_neg, n_examples, tau_max=None,
                 same_rec_neg=True, random_state=None):
        super().__init__(metadata, random_state=random_state)

        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        # An absent upper bound is represented by infinity so that the
        # comparisons in `_sample_pair` need no special case.
        self.tau_max = np.inf if tau_max is None else tau_max
        self.n_examples = n_examples
        self.same_rec_neg = same_rec_neg

        if not same_rec_neg and self.n_recordings < 2:
            raise ValueError('More than one recording must be available when '
                             'using across-recording negative sampling.')

    def _sample_pair(self):
        """Sample a pair of two windows.

        Returns
        -------
        win_ind1 : int
            Index of the first (anchor) window.
        win_ind2 : int
            Index of the second window.
        pair_type : float
            0. for a negative pair, 1. for a positive pair.

        Raises
        ------
        NotImplementedError
            If no window satisfies the positive/negative constraint for the
            sampled anchor window.
        """
        # Sample first window
        win_ind1, rec_ind1 = self.sample_window()
        ts1 = self.metadata.iloc[win_ind1]['i_start_in_trial']
        # Start positions listed in `self.info` for the anchor's recording.
        ts = self.info.iloc[rec_ind1]['i_start_in_trial']

        # Decide whether the pair will be positive or negative
        pair_type = self.rng.binomial(1, 0.5)
        win_ind2 = None
        if pair_type == 0:  # Negative example
            if self.same_rec_neg:
                # Keep windows that lie between tau_neg and tau_max samples
                # away from the anchor, on either side.
                mask = (
                    ((ts <= ts1 - self.tau_neg) & (ts >= ts1 - self.tau_max)) |
                    ((ts >= ts1 + self.tau_neg) & (ts <= ts1 + self.tau_max))
                )
            else:
                # Redraw until the window comes from a different recording;
                # no mask is needed in this case.
                rec_ind2 = rec_ind1
                while rec_ind2 == rec_ind1:
                    win_ind2, rec_ind2 = self.sample_window()
        elif pair_type == 1:  # Positive example
            mask = (ts >= ts1 - self.tau_pos) & (ts <= ts1 + self.tau_pos)

        if win_ind2 is None:
            mask[ts == ts1] = False  # same window cannot be sampled twice
            if not np.any(mask):
                raise NotImplementedError(
                    'No window satisfies the sampling constraints for the '
                    'selected anchor window; check tau_pos, tau_neg and '
                    'tau_max against the recording length.')
            win_ind2 = self.rng.choice(self.info.iloc[rec_ind1]['index'][mask])

        return win_ind1, win_ind2, float(pair_type)

    def presample(self):
        """Presample examples.

        Once presampled, the examples are the same from one epoch to another.

        Returns
        -------
        self
            The sampler itself, with `examples` populated.
        """
        self.examples = [self._sample_pair() for _ in range(self.n_examples)]
        return self

    def __iter__(self):
        """Iterate over pairs.

        Yields
        ------
            (int): position of the first window in the dataset.
            (int): position of the second window in the dataset.
            (float): 0 for negative pair, 1 for positive pair.
        """
        # Decide once whether presampled pairs exist instead of re-checking
        # on every iteration.
        if hasattr(self, 'examples'):
            presampled = self.examples
            for i in range(self.n_examples):
                yield presampled[i]
        else:
            for _ in range(self.n_examples):
                yield self._sample_pair()

    def __len__(self):
        """Return the number of pairs produced per iteration."""
        return self.n_examples