braindecode.modules package#
Submodules#
braindecode.modules.activation module#
- class braindecode.modules.activation.LogActivation(epsilon: float = 1e-06, *args, **kwargs)[source]#
Bases:
Module
Logarithm activation function.
- forward(x: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.activation.SafeLog(epsilon: float = 1e-06)[source]#
Bases:
Module
Safe logarithm activation function module.
\[\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)\]
- Parameters:
epsilon (float, optional) – A small value at which the input tensor is clamped from below, to prevent computing log(0) or the log of negative numbers. Default is 1e-6.
- extra_repr() str [source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x) Tensor [source]#
Forward pass of the SafeLog module.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output tensor after applying safe logarithm.
- Return type:
torch.Tensor
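The formula above can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not the library's implementation; the function name `safe_log` is hypothetical.

```python
import torch


def safe_log(x: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Hypothetical stand-in for SafeLog.forward: log(max(x, epsilon))."""
    # Clamp from below so zeros and negative values map to log(epsilon).
    return torch.log(torch.clamp(x, min=epsilon))


x = torch.tensor([-1.0, 0.0, 1.0])
y = safe_log(x)
```

The first two entries of `y` equal log(1e-6) and the last is 0, so no `-inf` or `nan` values appear.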
braindecode.modules.attention module#
Attention modules used in the AttentionBaseNet from Martin Wimpff (2023).
Here, we implement some popular attention modules that can be used in the AttentionBaseNet class.
- class braindecode.modules.attention.CAT(in_channels: int, reduction_rate: int, kernel_size: int, bias=False)[source]#
Bases:
Module
Attention Mechanism from [Wu2023].
- Parameters:
References
[Wu2023] Wu, Z. et al., 2023. CAT: Learning to Collaborate Channel and Spatial Attention from Multi-Information Fusion. IET Computer Vision 2023.
- class braindecode.modules.attention.CATLite(in_channels: int, reduction_rate: int, bias: bool = True)[source]#
Bases:
Module
Modification of CAT without the convolutional layer from [Wu2023].
- Parameters:
References
[Wu2023] Wu, Z. et al., 2023. CAT: Learning to Collaborate Channel and Spatial Attention from Multi-Information Fusion. IET Computer Vision 2023.
- class braindecode.modules.attention.CBAM(in_channels: int, reduction_rate: int, kernel_size: int)[source]#
Bases:
Module
Convolutional Block Attention Module from [Woo2018].
- Parameters:
References
[Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018. CBAM: Convolutional Block Attention Module. ECCV 2018.
- class braindecode.modules.attention.ECA(in_channels: int, kernel_size: int)[source]#
Bases:
Module
Efficient Channel Attention [Wang2021].
- Parameters:
References
[Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks. CVPR 2021.
- class braindecode.modules.attention.EncNet(in_channels: int, n_codewords: int)[source]#
Bases:
Module
Context Encoding for Semantic Segmentation from [Zhang2018].
- Parameters:
References
[Zhang2018] Zhang, H. et al., 2018. Context Encoding for Semantic Segmentation. CVPR 2018.
- class braindecode.modules.attention.FCA(in_channels, seq_len: int = 62, reduction_rate: int = 4, freq_idx: int = 0)[source]#
Bases:
Module
Frequency Channel Attention Networks from [Qin2021].
- Parameters:
References
[Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021. FcaNet: Frequency Channel Attention Networks. ICCV 2021.
- class braindecode.modules.attention.GCT(in_channels: int)[source]#
Bases:
Module
Gated Channel Transformation from [Yang2020].
- Parameters:
in_channels (int) – number of input feature channels
References
[Yang2020] Yang, Z., Zhu, L., Wu, Y., Yang, Y., 2020. Gated Channel Transformation for Visual Recognition. CVPR 2020.
- class braindecode.modules.attention.GSoP(in_channels: int, reduction_rate: int, bias: bool = True)[source]#
Bases:
Module
Global Second-order Pooling Convolutional Networks from [Gao2018].
- Parameters:
References
[Gao2018] Gao, Z., Xie, J., Wang, Q., Li, P., 2018. Global Second-order Pooling Convolutional Networks. CVPR 2018.
- class braindecode.modules.attention.GatherExcite(in_channels: int, seq_len: int = 62, extra_params: bool = False, use_mlp: bool = False, reduction_rate: int = 4)[source]#
Bases:
Module
Gather-Excite Networks from [Hu2018b].
- Parameters:
in_channels (int) – number of input feature channels
seq_len (int, default=62) – sequence length along temporal dimension
extra_params (bool, default=False) – whether to use a convolutional layer as a gather module
use_mlp (bool, default=False) – whether to use an excite block with fully-connected layers
reduction_rate (int, default=4) – reduction ratio of the excite block (if used)
References
[Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018. Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks. NeurIPS 2018.
- class braindecode.modules.attention.MultiHeadAttention(emb_size, num_heads, dropout)[source]#
Bases:
Module
- forward(x: Tensor, mask: Tensor | None = None) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.attention.SRM(in_channels: int, use_mlp: bool = False, reduction_rate: int = 4, bias: bool = False)[source]#
Bases:
Module
Attention module from [Lee2019].
- Parameters:
References
[Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based Recalibration Module for Convolutional Neural Networks. ICCV 2019.
- class braindecode.modules.attention.SqueezeAndExcitation(in_channels: int, reduction_rate: int, bias: bool = False)[source]#
Bases:
Module
Squeeze-and-Excitation Networks from [Hu2018].
- Parameters:
References
[Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018. Squeeze-and-Excitation Networks. CVPR 2018.
braindecode.modules.blocks module#
- class braindecode.modules.blocks.FeedForwardBlock(emb_size, expansion, drop_p, activation: torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.GELU'>)[source]#
Bases:
Sequential
- class braindecode.modules.blocks.InceptionBlock(branches)[source]#
Bases:
Module
Inception block module.
This module applies multiple convolutional branches to the input and concatenates their outputs along the channel dimension. Each branch can have a different configuration, allowing the model to capture multi-scale features.
- Parameters:
branches (list of nn.Module) – List of convolutional branches to apply to the input.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.blocks.MLP(in_features: int, hidden_features=None, out_features=None, activation=<class 'torch.nn.modules.activation.GELU'>, drop=0.0, normalize=False)[source]#
Bases:
Sequential
Multilayer Perceptron (MLP) with GELU activation and optional dropout.
Also known as fully connected feedforward network, an MLP is a sequence of non-linear parametric functions
\[h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),\]
over feature vectors \(h_i\), with the input and output feature vectors \(x = h_0\) and \(y = h_L\), respectively. The non-linear functions \(a_i\) are called activation functions. The trainable parameters of an MLP are its weights and biases \(\phi = \{W_i, b_i \mid i = 1, \dots, L\}\).
braindecode.modules.convolution module#
- class braindecode.modules.convolution.AvgPool2dWithConv(kernel_size, stride, dilation=1, padding=0)[source]#
Bases:
Module
Compute average pooling using a convolution, which makes a dilation parameter available.
- Parameters:
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
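The idea of average pooling through a convolution can be illustrated as follows. This is a sketch under assumptions, not the class's actual code: the helper `avg_pool2d_via_conv` is hypothetical and uses a depthwise convolution with uniform weights, which exposes the `dilation` argument that ordinary average pooling lacks.

```python
import torch
import torch.nn.functional as F


def avg_pool2d_via_conv(x, kernel_size, stride=1, dilation=1, padding=0):
    """Hypothetical helper: average pooling as a depthwise conv with uniform weights."""
    n_channels = x.shape[1]
    kh, kw = kernel_size
    # One uniform kernel per channel; groups=n_channels keeps channels independent.
    weight = torch.full((n_channels, 1, kh, kw), 1.0 / (kh * kw), dtype=x.dtype)
    return F.conv2d(x, weight, stride=stride, dilation=dilation,
                    padding=padding, groups=n_channels)


x = torch.randn(2, 3, 8, 8)
pooled = avg_pool2d_via_conv(x, kernel_size=(2, 2), stride=2)
reference = F.avg_pool2d(x, kernel_size=2, stride=2)
dilated = avg_pool2d_via_conv(x, kernel_size=(2, 2), dilation=2)
```

With `dilation=1` the result matches `F.avg_pool2d`; with `dilation=2` the averaged samples are spaced two positions apart.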
- class braindecode.modules.convolution.CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs)[source]#
Bases:
Conv1d
Causal 1-dimensional convolution.
Code modified from [1] and [2].
- Parameters:
in_channels (int) – Input channels.
out_channels (int) – Output channels (number of filters).
kernel_size (int) – Kernel size.
dilation (int, optional) – Dilation (number of elements to skip within kernel multiplication). Defaults to 1.
**kwargs – Other keyword arguments to pass to torch.nn.Conv1d, except for padding.
References
- forward(X)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
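One common way to make a 1-D convolution causal is to pad by `(kernel_size - 1) * dilation` and drop the trailing samples. The class `CausalConv1dSketch` below is a hypothetical illustration of that technique and is assumed, not confirmed, to resemble what this class does; it also shows why `padding` cannot be passed by the caller.

```python
import torch
from torch import nn


class CausalConv1dSketch(nn.Conv1d):
    """Hypothetical sketch: pad by (kernel_size - 1) * dilation, then trim the tail."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size,
                         dilation=dilation,
                         padding=(kernel_size - 1) * dilation, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        # Conv1d pads both sides; dropping the last `padding` samples leaves
        # each output depending only on current and past inputs.
        trim = self.padding[0]
        return out[..., :-trim] if trim > 0 else out


torch.manual_seed(0)
conv = CausalConv1dSketch(1, 1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 20)
x_future = x.clone()
x_future[..., 10:] += 1.0  # change only samples from t = 10 onward
with torch.no_grad():
    y, y_future = conv(x), conv(x_future)
```

Outputs before t = 10 are identical for `x` and `x_future`, which is the causality property: changing future samples does not affect earlier outputs.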
- class braindecode.modules.convolution.CombinedConv(in_chans, n_filters_time=40, n_filters_spat=40, filter_time_length=25, bias_time=True, bias_spat=True)[source]#
Bases:
Module
Merged convolutional layer for the temporal and spatial convolutions in Deep4/ShallowFBCSP.
Numerically equivalent to the separate sequential approach, but this should be faster.
- Parameters:
in_chans (int) – Number of EEG input channels.
n_filters_time (int) – Number of temporal filters.
filter_time_length (int) – Length of the temporal filter.
n_filters_spat (int) – Number of spatial filters.
bias_time (bool) – Whether to use bias in the temporal convolution.
bias_spat (bool) – Whether to use bias in the spatial convolution.
- forward(x: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.convolution.Conv2dWithConstraint(*args, max_norm=1, **kwargs)[source]#
Bases:
Conv2d
- class braindecode.modules.convolution.DepthwiseConv2d(in_channels, depth_multiplier=2, kernel_size=3, stride=1, padding=0, dilation=1, bias=True, padding_mode='zeros')[source]#
Bases:
Conv2d
Depthwise convolution layer.
This class implements a depthwise convolution, where each input channel is convolved separately with its own filter (channel multiplier), effectively performing a spatial convolution independently over each channel.
- Parameters:
in_channels (int) – Number of channels in the input tensor.
depth_multiplier (int, optional) – Multiplier for the number of output channels. The total number of output channels will be in_channels * depth_multiplier. Default is 2.
kernel_size (int or tuple, optional) – Size of the convolutional kernel. Default is 3.
stride (int or tuple, optional) – Stride of the convolution. Default is 1.
padding (int or tuple, optional) – Padding added to both sides of the input. Default is 0.
dilation (int or tuple, optional) – Spacing between kernel elements. Default is 1.
bias (bool, optional) – If True, adds a learnable bias to the output. Default is True.
padding_mode (str, optional) – Padding mode to use. Options are ‘zeros’, ‘reflect’, ‘replicate’, or ‘circular’. Default is ‘zeros’.
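As a rough illustration of the relationship between `in_channels`, `depth_multiplier` and the number of output channels, the sketch below builds a plain `torch.nn.Conv2d` with `groups=in_channels`. That this matches the class's internal construction is an assumption.

```python
import torch
from torch import nn

in_channels, depth_multiplier = 4, 2
# Assumed equivalent plain-PyTorch layer: groups=in_channels gives each input
# channel its own set of `depth_multiplier` filters.
depthwise = nn.Conv2d(
    in_channels,
    in_channels * depth_multiplier,
    kernel_size=3,
    groups=in_channels,
)
x = torch.randn(2, in_channels, 16, 16)
y = depthwise(x)
```

Each filter has a single input channel (weight shape `(8, 1, 3, 3)`), and the output has `in_channels * depth_multiplier = 8` channels.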
braindecode.modules.filter module#
- class braindecode.modules.filter.FilterBankLayer(n_chans: int, sfreq: float, band_filters: list[tuple[float, float]] | int | None = None, method: str = 'fir', filter_length: str | float | int = 'auto', l_trans_bandwidth: str | float | int = 'auto', h_trans_bandwidth: str | float | int = 'auto', phase: str = 'zero', iir_params: dict | None = None, fir_window: str = 'hamming', fir_design: str = 'firwin', verbose: bool = True)[source]#
Bases:
Module
Apply multiple band-pass filters to generate multiview signal representation.
This layer constructs a bank of signals filtered in specific bands for each channel. It uses MNE’s create_filter function to create the band-specific filters and applies them to multi-channel time-series data. Each filter in the bank corresponds to a specific frequency band and is applied to all channels of the input data. The filtering is performed using FFT-based convolution via the fftconvolve function from torchaudio.functional if the method is FIR, and via the filtfilt() function from torchaudio.functional if the method is IIR.
The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth, spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, …, 36-40 Hz). This setup is based on the reference: FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface.
- Parameters:
n_chans (int) – Number of channels in the input signal.
sfreq (float) – Sampling frequency of the input signal in Hz.
band_filters (Optional[list[tuple[float, float]]] or int, default=None) – List of frequency bands as (low_freq, high_freq) tuples. Each tuple defines the frequency range for one filter in the bank. If not provided, defaults to 9 non-overlapping bands with 4 Hz bandwidths spanning from 4 to 40 Hz.
method (str, default='fir') – 'fir' will use FIR filtering, 'iir' will use IIR forward-backward filtering (via filtfilt()). For more details, please check the MNE Preprocessing Tutorial.
filter_length (str | float | int, default='auto') – Length of the FIR filter to use (if applicable):
'auto' (default): The filter length is chosen based on the size of the transition regions (6.6 times the reciprocal of the shortest transition band for fir_window='hamming' and fir_design="firwin2", and half that for "firwin").
str: A human-readable time in units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted to that number of samples if phase="zero", or the shortest power-of-two length at least that duration for phase="zero-double".
int: Specified length in samples. For fir_design="firwin", this should not be used.
l_trans_bandwidth (Union[str, float, int], default='auto') – Width of the transition band at the low cut-off frequency in Hz (high pass or cutoff 1 in bandpass). Can be "auto" (default) to use a multiple of l_freq: min(max(l_freq * 0.25, 2), l_freq). Only used for method='fir'.
h_trans_bandwidth (Union[str, float, int], default='auto') – Width of the transition band at the high cut-off frequency in Hz (low pass or cutoff 2 in bandpass). Can be "auto" (default in 0.14) to use a multiple of h_freq: min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq). Only used for method='fir'.
phase (str, default='zero') – Phase of the filter. When method='fir', symmetric linear-phase FIR filters are constructed with the following behaviors:
"zero" (default): The delay of this filter is compensated for, making it non-causal.
"minimum": A minimum-phase filter will be constructed by decomposing the zero-phase filter into a minimum-phase and all-pass systems, and then retaining only the minimum-phase system (of the same length as the original zero-phase filter) via scipy.signal.minimum_phase().
"zero-double": This is a legacy option for compatibility with MNE <= 0.13. The filter is applied twice, once forward, and once backward (also making it non-causal).
"minimum-half": This is a legacy option for compatibility with MNE <= 1.6. A minimum-phase filter will be reconstructed from the zero-phase filter with half the length of the original filter.
When method='iir', phase='zero' (default) or equivalently 'zero-double' constructs and applies the IIR filter twice, once forward, and once backward (making it non-causal) using filtfilt(); phase='forward' will apply the filter once in the forward (causal) direction using lfilter().
The behavior for phase="minimum" was fixed to use a filter of the requested length and improved suppression.
iir_params (Optional[dict], default=None) – Dictionary of parameters to use for IIR filtering. If iir_params=None and method="iir", 4th order Butterworth will be used. For more information, see mne.filter.construct_iir_filter().
fir_window (str, default='hamming') – The window to use in FIR design, can be "hamming" (default), "hann" (default in 0.13), or "blackman".
fir_design (str, default='firwin') – Can be "firwin" (default) to use scipy.signal.firwin(), or "firwin2" to use scipy.signal.firwin2(). "firwin" uses a time-domain design technique that generally gives improved attenuation using fewer samples than "firwin2".
pad (str, default='reflect_limited') – The type of padding to use. Supports all numpy.pad() mode options. Can also be "reflect_limited", which pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. Only used for method='fir'.
verbose (bool | str | int | None, default=True) – Control verbosity of the logging output. If None, use the default verbosity level. See mne.verbose for details. Should only be passed as a keyword argument.
- forward(x: Tensor) Tensor [source]#
Apply the filter bank to the input signal.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, n_chans, time_points).
- Returns:
Filtered output tensor of shape (batch_size, n_bands, n_chans, filtered_time_points).
- Return type:
torch.Tensor
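The default band layout described for this layer (9 bands of 4 Hz between 4 Hz and 40 Hz) can be written out explicitly. The snippet below only builds the list of `(low_freq, high_freq)` tuples in the format documented for `band_filters`; the variable names are illustrative.

```python
# Default bands described above: 9 bands, 4 Hz wide, from 4 Hz to 40 Hz.
n_bands, bandwidth, start = 9, 4.0, 4.0
band_filters = [
    (start + i * bandwidth, start + (i + 1) * bandwidth) for i in range(n_bands)
]
print(band_filters[0], band_filters[-1])
```

A list of this form could be passed as `band_filters` to choose other bands; the number of tuples then determines `n_bands` in the output shape.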
- class braindecode.modules.filter.GeneralizedGaussianFilter(in_channels, out_channels, sequence_length, sample_rate, inverse_fourier=True, affine_group_delay=False, group_delay=(20.0,), f_mean=(23.0,), bandwidth=(44.0,), shape=(2.0,), clamp_f_mean=(1.0, 45.0))[source]#
Bases:
Module
Generalized Gaussian Filter from Ludwig et al. (2024) [eegminer].
Implements trainable temporal filters based on generalized Gaussian functions in the frequency domain.
This module creates filters in the frequency domain using the generalized Gaussian function, allowing for trainable center frequency (f_mean), bandwidth (bandwidth), and shape (shape) parameters.
The filters are applied to the input signal in the frequency domain, and can be optionally transformed back to the time domain using the inverse Fourier transform.
The generalized Gaussian function in the frequency domain is defined as:
\[F(x) = \exp\left( - \left( \frac{|x - \mu|}{\alpha} \right)^{\beta} \right)\]
where:
μ (mu) is the center frequency (f_mean).
α (alpha) is the scale parameter, reparameterized in terms of the full width at half maximum (FWHM) h as:
\[\alpha = \frac{h}{2 \left( \ln(2) \right)^{1/\beta}}\]
β (beta) is the shape parameter (shape), controlling the shape of the filter.
The filters are constructed in the frequency domain to allow full control over the magnitude and phase responses.
A linear phase response is used, with an optional trainable group delay (group_delay).
Copyright (C) Cogitat, Ltd.
Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
Patent GB2609265 - Learnable filters for eeg classification
https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels. Must be a multiple of in_channels.
sequence_length (int) – Length of the input sequences (time steps).
sample_rate (float) – Sampling rate of the input signals in Hz.
inverse_fourier (bool, optional) – If True, applies the inverse Fourier transform to return to the time domain after filtering. Default is True.
affine_group_delay (bool, optional) – If True, makes the group delay parameter trainable. Default is False.
group_delay (tuple of float, optional) – Initial group delay(s) in milliseconds for the filters. Default is (20.0,).
f_mean (tuple of float, optional) – Initial center frequency (frequencies) of the filters in Hz. Default is (23.0,).
bandwidth (tuple of float, optional) – Initial bandwidth(s) (full width at half maximum) of the filters in Hz. Default is (44.0,).
shape (tuple of float, optional) – Initial shape parameter(s) of the generalized Gaussian filters. Must be >= 2.0. Default is (2.0,).
clamp_f_mean (tuple of float, optional) – Minimum and maximum allowable values for the center frequency f_mean in Hz. Specified as (min_f_mean, max_f_mean). Default is (1.0, 45.0).
Notes
The model and the module are covered by a patent [eegminercode], and the code is licensed under CC BY-NC 4.0.
Added in version 0.9.
References
[eegminer] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis, Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features of brain activity with learnable filters. Journal of Neural Engineering, 21(3), 036010.
[eegminercode] Ludwig, S., Bakas, S., Adamos, D. A., Laskaris, N., Panagakis, Y., & Zafeiriou, S. (2024). EEGMiner: discovering interpretable features of brain activity with learnable filters. SMLudwig/EEGminer. Cogitat, Ltd. “Learnable filters for EEG classification.” Patent GB2609265. https://www.ipo.gov.uk/p-ipsum/Case/ApplicationNumber/GB2113420.0
- construct_filters()[source]#
Constructs the filters in the frequency domain based on current parameters.
- Returns:
The constructed filters with shape (out_channels, freq_bins, 2).
- Return type:
torch.Tensor
- static exponential_power(x, mean, fwhm, shape)[source]#
Computes the generalized Gaussian function:
\[F(x) = \exp\left( - \left( \frac{|x - \mu|}{\alpha} \right)^{\beta} \right)\]
where:
\(\mu\) is the mean (mean).
\(\alpha\) is the scale parameter, reparameterized using the FWHM \(h\) as:
\[\alpha = \frac{h}{2 \left( \ln(2) \right)^{1/\beta}}\]
\(\beta\) is the shape parameter (shape).
- Parameters:
x (torch.Tensor) – The input tensor representing frequencies, normalized between 0 and 1.
mean (torch.Tensor) – The center frequency (f_mean), normalized between 0 and 1.
fwhm (torch.Tensor) – The full width at half maximum (bandwidth), normalized between 0 and 1.
shape (torch.Tensor) – The shape parameter (shape) of the generalized Gaussian.
- Returns:
The computed generalized Gaussian function values at frequencies x.
- Return type:
torch.Tensor
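The FWHM reparameterization can be checked numerically: at a distance of h/2 from the mean, the function should equal 0.5 for any shape parameter. The scalar function below is a pure-Python sketch of the formula, not the library's tensor implementation, and it uses frequencies in Hz rather than normalized values for readability (an assumption that does not affect the result, since the formula only depends on the ratio |x - μ| / α).

```python
import math


def exponential_power(x, mean, fwhm, shape):
    """Scalar sketch of F(x) = exp(-(|x - mean| / alpha) ** shape)."""
    # Scale parameter expressed through the full width at half maximum.
    alpha = fwhm / (2.0 * math.log(2.0) ** (1.0 / shape))
    return math.exp(-((abs(x - mean) / alpha) ** shape))


mean, fwhm = 23.0, 44.0
peak = exponential_power(mean, mean, fwhm, shape=2.0)
half_points = [
    exponential_power(mean + fwhm / 2.0, mean, fwhm, shape)
    for shape in (2.0, 3.0, 8.0)
]
```

Here `peak` is 1 and every entry of `half_points` is 0.5 up to floating-point error, because (h/2)/α raised to the power β equals ln 2.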
- forward(x)[source]#
Applies the generalized Gaussian filters to the input signal.
- Parameters:
x (torch.Tensor) – Input tensor of shape (…, in_channels, sequence_length).
- Returns:
The filtered signal. If inverse_fourier is True, returns the signal in the time domain with shape (…, out_channels, sequence_length). Otherwise, returns the signal in the frequency domain with shape (…, out_channels, freq_bins, 2).
- Return type:
torch.Tensor
braindecode.modules.layers module#
- class braindecode.modules.layers.Chomp1d(chomp_size)[source]#
Bases:
Module
- extra_repr()[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.layers.DropPath(drop_prob=None)[source]#
Bases:
Module
Drop paths, also known as Stochastic Depth, per sample, when applied in the main path of residual blocks.
- facebookresearch/vissl
All rights reserved.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- extra_repr() str [source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.layers.Ensure4d(*args, **kwargs)[source]#
Bases:
Module
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.layers.SqueezeFinalOutput[source]#
Bases:
Module
Removes the empty dimension at the end and potentially removes the empty time dimension. It does not simply use squeeze, as the first dimension must never be removed.
- Returns:
x – squeezed tensor
- Return type:
torch.Tensor
- forward(x: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.layers.TimeDistributed(module)[source]#
Bases:
Module
Apply module on multiple windows.
Apply the provided module on a sequence of windows and return their concatenation. Useful with sequence-to-prediction models (e.g. sleep stager which must map a sequence of consecutive windows to the label of the middle window in the sequence).
- Parameters:
module (nn.Module) – Module to be applied to the input windows. Must accept an input of shape (batch_size, n_channels, n_times).
- forward(x)[source]#
- Parameters:
x (torch.Tensor) – Sequence of windows, of shape (batch_size, seq_len, n_channels, n_times).
- Returns:
Shape (batch_size, seq_len, output_size).
- Return type:
torch.Tensor
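The shape handling documented for `forward` can be illustrated by folding the sequence axis into the batch axis, applying the module once, and unfolding the result. The helper `apply_time_distributed` and the toy `extractor` are hypothetical; whether the library processes windows in one batch or one at a time is not stated here.

```python
import torch
from torch import nn


def apply_time_distributed(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the documented input and output shapes."""
    batch_size, seq_len, n_channels, n_times = x.shape
    # Fold the sequence axis into the batch axis so `module` sees single windows.
    windows = x.reshape(batch_size * seq_len, n_channels, n_times)
    out = module(windows)
    # Unfold back to one feature vector per window.
    return out.reshape(batch_size, seq_len, -1)


# Toy feature extractor (assumption): flatten each window, project to 5 features.
extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 50, 5))
x = torch.randn(2, 7, 3, 50)
y = apply_time_distributed(extractor, x)
```

The result has shape `(batch_size, seq_len, output_size)`, here `(2, 7, 5)`.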
braindecode.modules.linear module#
- class braindecode.modules.linear.LinearWithConstraint(*args, max_norm=1.0, **kwargs)[source]#
Bases:
Linear
Linear layer with max-norm constraint on the weights.
- class braindecode.modules.linear.MaxNormLinear(in_features, out_features, bias=True, max_norm_val=2, eps=1e-05, **kwargs)[source]#
Bases:
Linear
Linear layer with a MaxNorm constraint on the weights.
Equivalent of Keras tf.keras.Dense(…, kernel_constraint=max_norm()) [1] and [2]. Implemented as advised in [3].
- Parameters:
References
braindecode.modules.parametrization module#
- class braindecode.modules.parametrization.MaxNorm(max_norm_val=2.0, eps=1e-05)[source]#
Bases:
Module
- forward(X: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.parametrization.MaxNormParametrize(max_norm: float = 1.0)[source]#
Bases:
Module
Enforce a max‑norm constraint on the rows of a weight tensor via parametrization.
- forward(X: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
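A row-wise max-norm constraint applied through PyTorch's parametrization mechanism might look like the sketch below. The class `RowMaxNorm` is hypothetical and its exact rescaling rule is an assumption; only `torch.nn.utils.parametrize.register_parametrization` is an existing PyTorch call.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class RowMaxNorm(nn.Module):
    """Hypothetical parametrization: rescale rows whose L2 norm exceeds max_norm."""

    def __init__(self, max_norm: float = 1.0):
        super().__init__()
        self.max_norm = max_norm

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        norms = torch.linalg.vector_norm(X, dim=1, keepdim=True).clamp(min=1e-12)
        # Rows within the limit get a factor of 1; larger rows are shrunk.
        return X * torch.clamp(self.max_norm / norms, max=1.0)


layer = nn.Linear(16, 4)
with torch.no_grad():
    layer.weight.mul_(10.0)  # inflate the weights so the constraint is active
parametrize.register_parametrization(layer, "weight", RowMaxNorm(max_norm=1.0))
row_norms = torch.linalg.vector_norm(layer.weight, dim=1)
```

After registration, every access to `layer.weight` returns the constrained tensor, so the constraint holds during both training and inference without an explicit clipping step.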
braindecode.modules.stats module#
- class braindecode.modules.stats.StatLayer(stat_fn: Callable[[...], Tensor], dim: int, keepdim: bool = True, clamp_range: tuple[float, float] | None = None, apply_log: bool = False)[source]#
Bases:
Module
Generic layer to compute a statistical function along a specified dimension.
- Parameters:
stat_fn (Callable) – A function like torch.mean, torch.std, etc.
dim (int) – Dimension along which to apply the function.
keepdim (bool, default=True) – Whether to keep the reduced dimension.
clamp_range (tuple(float, float), optional) – Used only for functions requiring clamping (e.g., log variance).
apply_log (bool, default=False) – Whether to apply log after computation (used for LogVarLayer).
- forward(x: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
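To make the roles of `stat_fn`, `clamp_range` and `apply_log` concrete, the sketch below spells out a log-variance computation as a plain function. The function `log_var`, the clamp bounds and the input shape are illustrative assumptions; the order of operations (statistic, clamp, log) is inferred from the parameter descriptions.

```python
import torch


def log_var(x: torch.Tensor, dim: int, keepdim: bool = True,
            clamp_range=(1e-6, 1e6)) -> torch.Tensor:
    """Hypothetical function form of a log-variance StatLayer configuration."""
    out = torch.var(x, dim=dim, keepdim=keepdim)  # role of stat_fn
    # Role of clamp_range: keep the variance inside a range where log is safe.
    out = torch.clamp(out, min=clamp_range[0], max=clamp_range[1])
    return torch.log(out)  # role of apply_log


x = torch.randn(8, 4, 22, 250)  # illustrative (batch, bands, channels, time)
features = log_var(x, dim=3)
```

With `keepdim=True` the reduced time axis is kept with length 1, giving a `(8, 4, 22, 1)` tensor.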
braindecode.modules.util module#
- braindecode.modules.util.aggregate_probas(logits, n_windows_stride=1)[source]#
Aggregate predicted probabilities with self-ensembling.
Aggregate window-wise predicted probabilities obtained on overlapping sequences of windows using multiplicative voting as described in [Phan2018].
- Parameters:
logits (np.ndarray) – Array of shape (n_sequences, n_classes, n_windows) containing the logits (i.e. the raw unnormalized scores for each class) for each window of each sequence.
n_windows_stride (int) – Number of windows between two consecutive sequences. Default is 1 (maximally overlapping sequences).
- Returns:
Array of shape ((n_sequences - 1) * n_windows_stride + n_windows, n_classes) containing the aggregated predicted probabilities for each window contained in the input sequences.
- Return type:
np.ndarray
References
[Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., & De Vos, M. (2018). Joint classification and prediction CNN framework for automatic sleep stage classification. IEEE Transactions on Biomedical Engineering, 66(5), 1285-1296.
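Multiplicative voting over overlapping sequences can be sketched with NumPy by summing log-probabilities for each window. The function `aggregate_log_probas` is a hypothetical illustration of the idea; it returns unnormalized summed log-probabilities, and whether the library applies further normalization is not stated here.

```python
import numpy as np


def aggregate_log_probas(logits, n_windows_stride=1):
    """Sketch of multiplicative voting: sum log-probabilities per window."""
    n_sequences, n_classes, n_windows = logits.shape
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probas = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    n_total = (n_sequences - 1) * n_windows_stride + n_windows
    out = np.zeros((n_total, n_classes))
    for i in range(n_sequences):
        start = i * n_windows_stride
        # Adding log-probabilities multiplies the probabilities of every
        # sequence that covers the same window.
        out[start:start + n_windows] += log_probas[i].T
    return out


logits = np.random.default_rng(0).normal(size=(4, 5, 3))
scores = aggregate_log_probas(logits, n_windows_stride=1)
predictions = scores.argmax(axis=1)
```

With 4 sequences of 3 windows and a stride of 1, the output covers (4 - 1) * 1 + 3 = 6 windows; the first window is covered by a single sequence, so its exponentiated scores still sum to 1.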
braindecode.modules.wrapper module#
- class braindecode.modules.wrapper.Expression(expression_fn)[source]#
Bases:
Module
Compute given expression on forward pass.
- Parameters:
expression_fn (callable) – Should accept variable number of objects of type torch.autograd.Variable to compute its output.
- forward(x: Tensor) Tensor [source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class braindecode.modules.wrapper.IntermediateOutputWrapper(to_select, model)[source]#
Bases:
Module
Wraps a network model such that the outputs of intermediate layers can be returned. forward() returns a list of the intermediate activations in the network during the forward pass.
- Parameters:
to_select (list) – list of module names for which activation should be returned
model (model object) – network model
Examples
>>> model = Deep4Net()
>>> select_modules = ['conv_spat', 'conv_2', 'conv_3', 'conv_4']  # Specify intermediate outputs
>>> model_pert = IntermediateOutputWrapper(select_modules, model)  # Wrap model
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.