# Source code for braindecode.augmentation.base

# Authors: Cédric Rommel <cedric.rommel@inria.fr>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#          Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>

from typing import List, Tuple, Any
from numbers import Real

from sklearn.utils import check_random_state
import torch
from torch import Tensor, nn
from torch.utils.data._utils.collate import default_collate

from .functional import identity

# Type aliases.
# ``Batch``: a list of 3-tuples typed (Tensor, int, Any); presumably one
# tuple per sample (input, target, extra) -- TODO confirm against callers.
Batch = List[Tuple[Tensor, int, Any]]
# ``Output``: a pair of tensors, presumably (transformed X, transformed y).
Output = Tuple[Tensor, Tensor]

class Transform(torch.nn.Module):
    """Basic transform class used for implementing data augmentation
    operations.

    Parameters
    ----------
    probability : float, optional
        Float between 0 and 1 defining the uniform probability of applying
        the operation. Set to 1.0 by default (i.e. always apply the
        operation).
    random_state : int | numpy.random.RandomState | None, optional
        Seed (or generator) forwarded to ``sklearn.utils.check_random_state``
        to instantiate the numpy random number generator. Used to decide
        whether or not to transform given the probability argument.
        Defaults to None.

    Attributes
    ----------
    operation : callable | None
        Class attribute holding a function that takes arrays X, y (inputs
        and targets resp.) and other required arguments, and returns the
        transformed X and y. Must be set to a callable by subclasses that do
        not override ``forward``.
    rng : numpy.random.RandomState
        Random number generator built from ``random_state``.

    Raises
    ------
    AssertionError
        If ``operation`` is not callable while ``forward`` is not
        overridden, if ``probability`` is not a real number, or if it lies
        outside ``[0, 1]`` (NaN included).
    """

    operation = None

    def __init__(self, probability=1.0, random_state=None):
        super().__init__()
        # Validation raises explicitly rather than using ``assert`` so that
        # it still runs under ``python -O`` (which strips assert statements).
        # AssertionError is kept as the exception type for backward
        # compatibility with callers that already catch it.

        # ``operation`` is only required when ``forward`` resolves to
        # ``Transform.forward``, i.e. the subclass does not override it.
        if self.forward.__func__ is Transform.forward and not callable(
            self.operation
        ):
            raise AssertionError("operation should be a ``callable``.")
        if not isinstance(probability, Real):
            raise AssertionError(
                f"probability should be a ``real``. Got {type(probability)}."
            )
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= probability <= 1.0:
            raise AssertionError("probability should be between 0 and 1.")
        self._probability = probability
        self.rng = check_random_state(random_state)

    def get_augmentation_params(self, *batch):
        """Return extra parameters for the augmentation operation.

        Hook implemented as a no-op in the base class: ``batch`` is ignored
        and an empty dict is returned. Subclasses presumably override it to
        derive batch-dependent arguments for ``operation`` (TODO confirm
        against ``forward``).

        Parameters
        ----------
        *batch
            Positional arguments describing the batch; unused here.

        Returns
        -------
        dict
            Always an empty dict in the base class.
        """
        return dict()