braindecode.modules.StatLayer
- class braindecode.modules.StatLayer(stat_fn: Callable[[...], Tensor], dim: int, keepdim: bool = True, clamp_range: tuple[float, float] | None = None, apply_log: bool = False)
Generic layer that computes a statistical function along a specified dimension.
Parameters
- stat_fn (Callable): A function such as torch.mean or torch.std.
- dim (int): Dimension along which to apply the function.
- keepdim (bool, default=True): Whether to keep the reduced dimension (with size 1) in the output.
- clamp_range (tuple(float, float), optional): Bounds used only for functions that require clamping (e.g., log variance).
- apply_log (bool, default=False): Whether to apply a logarithm after computing the statistic (used for LogVarLayer).
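The following sketch shows how the parameters above fit together. It does not import braindecode: `StatLayerSketch` is a hypothetical stand-in with the same constructor arguments. It assumes that forward applies `stat_fn(x, dim=dim, keepdim=keepdim)`, then the optional clamp, then the optional log. The input shape (batch, channels, time) and the clamp bounds `(1e-6, 1e6)` are illustrative assumptions, not necessarily the library's defaults.

```python
from typing import Callable, Optional, Tuple

import torch
from torch import nn


class StatLayerSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented StatLayer signature.

    Assumed behaviour: reduce with stat_fn along `dim`, then optionally
    clamp, then optionally take the log. This is not the braindecode source.
    """

    def __init__(
        self,
        stat_fn: Callable[..., torch.Tensor],
        dim: int,
        keepdim: bool = True,
        clamp_range: Optional[Tuple[float, float]] = None,
        apply_log: bool = False,
    ) -> None:
        super().__init__()
        self.stat_fn = stat_fn
        self.dim = dim
        self.keepdim = keepdim
        self.clamp_range = clamp_range
        self.apply_log = apply_log

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce along the configured dimension.
        out = self.stat_fn(x, dim=self.dim, keepdim=self.keepdim)
        # Optional clamping, e.g. to keep a variance away from zero.
        if self.clamp_range is not None:
            low, high = self.clamp_range
            out = torch.clamp(out, min=low, max=high)
        # Optional log, as for a log-variance feature.
        if self.apply_log:
            out = torch.log(out)
        return out


# EEG-like batch: (batch, channels, time), an assumed layout.
x = torch.randn(8, 22, 1000)

mean_layer = StatLayerSketch(torch.mean, dim=-1)
logvar_layer = StatLayerSketch(
    torch.var, dim=-1, clamp_range=(1e-6, 1e6), apply_log=True
)
flat_std = StatLayerSketch(torch.std, dim=-1, keepdim=False)

print(mean_layer(x).shape)    # torch.Size([8, 22, 1])
print(logvar_layer(x).shape)  # torch.Size([8, 22, 1])
print(flat_std(x).shape)      # torch.Size([8, 22])
```

With the real class, the same constructor calls should work after `from braindecode.modules import StatLayer`, since the arguments follow the documented signature. The default `keepdim=True` keeps a size-1 axis, which lets the result broadcast back against the input.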
Methods
- forward(x: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
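To see why clamping matters for a log-variance statistic, the snippet below traces the assumed forward steps (variance, clamp, log) by hand with plain torch calls, on a batch in which one channel is flat. The bounds `(1e-6, 1e6)` are illustrative, not taken from the library.

```python
import torch

# 2 trials, 3 channels, 50 samples; channel 0 of trial 0 is constant.
x = torch.randn(2, 3, 50)
x[0, 0] = 1.0

var = torch.var(x, dim=-1, keepdim=True)

# Without clamping, the zero variance of the flat channel becomes -inf.
unclamped = torch.log(var)
# Clamping before the log keeps every value finite.
clamped = torch.log(torch.clamp(var, min=1e-6, max=1e6))

print(torch.isinf(unclamped).any().item())  # True
print(torch.isfinite(clamped).all().item())  # True
```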
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
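The effect described in the Note can be checked with a forward hook. The module below is a hypothetical minimal stand-in rather than StatLayer itself; any `nn.Module` behaves the same way, because hooks are dispatched by `nn.Module.__call__` and not by `forward`.

```python
import torch
from torch import nn


class MeanOverTime(nn.Module):
    """Hypothetical stand-in that averages over the last axis."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=-1, keepdim=True)


calls = []
layer = MeanOverTime()
# A forward hook receives (module, inputs, output) after each __call__.
handle = layer.register_forward_hook(
    lambda mod, inp, out: calls.append(out.shape)
)

x = torch.randn(2, 4, 10)
layer(x)          # goes through nn.Module.__call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped

print(len(calls))  # 1
handle.remove()
```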