braindecode.models.IFNet#

class braindecode.models.IFNet(n_chans=None, n_outputs=None, n_times=None, chs_info=None, input_window_seconds=None, sfreq=None, bands: list[tuple[float, float]] | int | None = [(4.0, 16.0), (16.0, 40.0)], n_filters_spat: int = 64, kernel_sizes: tuple[int, int] = (63, 31), stride_factor: int = 8, drop_prob: float = 0.5, linear_max_norm: float = 0.5, activation: type[nn.Module] = nn.GELU, verbose: bool = False, filter_parameters: dict | None = None)[source]#

IFNetV2 from Wang et al. (2023) [ifnet].

Figure: Overview of the Interactive Frequency Convolutional Neural Network (IFNetV2) architecture.

IFNetV2 is designed to capture spectro-spatial-temporal features for motor imagery decoding from EEG data. The model consists of three stages, outlined below and followed by a minimal usage sketch: Spectro-Spatial Feature Representation, Cross-Frequency Interactions, and Classification.

  • Spectro-Spatial Feature Representation: The raw EEG signals are filtered into two characteristic frequency bands: low (4-16 Hz) and high (16-40 Hz), covering the most relevant motor imagery bands. Spectro-spatial features are then extracted through 1D point-wise spatial convolution followed by temporal convolution.

  • Cross-Frequency Interactions: The extracted spectro-spatial features from each frequency band are combined through an element-wise summation operation, which enhances feature representation while preserving distinct characteristics.

  • Classification: The aggregated spectro-spatial features are further reduced through temporal average pooling and passed through a fully connected layer followed by a softmax operation to generate output probabilities for each class.
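A minimal end-to-end sketch of these three stages in use, assuming braindecode and torch are installed. The dimensions (22 channels, 4 classes, 1000 time samples at 250 Hz) are illustrative only, and the softmax is applied explicitly here in case the model returns raw class scores:

>>> import torch
>>> from braindecode.models import IFNet
>>> model = IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
>>> x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
>>> probs = torch.softmax(model(x), dim=1)  # per-class probabilities
>>> probs.shape
torch.Size([8, 4])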

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_times (int) – Number of time samples of the input window.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • bands (list[tuple[float, float]] or int or None, default=[(4.0, 16.0), (16.0, 40.0)]) – Frequency bands, in Hz, used by the filter bank layer to split the raw EEG before spectro-spatial feature extraction.

  • n_filters_spat (int, default=64) – Number of spatial filters, i.e. output channels of the point-wise spatial convolution applied to each frequency band.

  • kernel_sizes (tuple of int, default=(63, 31)) – Kernel sizes of the temporal convolutions.

  • stride_factor (int, default=8) – Stride factor used to segment the temporal dimension for the pooling stage; n_times should be divisible by it.

  • drop_prob (float, default=0.5) – Dropout probability.

  • linear_max_norm (float, default=0.5) – Maximum norm constraint applied to the weights of the final linear classification layer.

  • activation (nn.Module, default=nn.GELU) – Activation function after the InterFrequency Layer.

  • verbose (bool, default=False) – Controls the verbosity of the filter bank layer.

  • filter_parameters (dict, default=None) – Additional parameters passed to the filter bank layer.
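The defaults above follow the paper's two-band configuration, but each can be overridden at construction. A hedged sketch with illustrative values; as with other braindecode models, n_times may alternatively be inferred from input_window_seconds and sfreq:

>>> from braindecode.models import IFNet
>>> model = IFNet(
...     n_chans=22,
...     n_outputs=4,
...     input_window_seconds=4.0,  # n_times inferred as 4.0 * 250 = 1000
...     sfreq=250,
...     bands=[(4.0, 16.0), (16.0, 40.0)],  # low/high bands, in Hz
...     kernel_sizes=(63, 31),
...     drop_prob=0.5,
... )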

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

This implementation is not guaranteed to be correct and has not been checked by the original authors; it was reimplemented from the paper description and the official Torch source code [ifnetv2code]. Version 2 is present only in the repository; its main difference is one additional pooling layer, described in Table VII of the paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10070810

References

[ifnet]

Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive frequency convolutional neural network for enhancing motor imagery decoding from EEG. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 31, 1900-1911.

[ifnetv2code]

Wang, J., Yao, L., & Wang, Y. (2023). IFNet: An interactive frequency convolutional neural network for enhancing motor imagery decoding from EEG. Source code, GitHub repository: Jiaheng-Wang/IFNet

Methods

forward(x: Tensor) → Tensor[source]#

Forward pass of IFNet.

Parameters:

x (torch.Tensor) – Input tensor with shape (batch_size, n_chans, n_times).

Returns:

Output tensor with shape (batch_size, n_outputs).

Return type:

torch.Tensor
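A shape check matching the contract above, with illustrative sizes and random data standing in for preprocessed EEG windows:

>>> import torch
>>> from braindecode.models import IFNet
>>> model = IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250).eval()
>>> with torch.no_grad():
...     out = model(torch.randn(2, 22, 1000))
>>> out.shape
torch.Size([2, 4])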