braindecode.models.FBLightConvNet#

class braindecode.models.FBLightConvNet(n_chans=None, n_outputs=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, n_bands=9, n_filters_spat=32, n_dim=3, stride_factor=4, win_len=250, heads=8, weight_softmax=True, bias=False, activation=<class 'torch.nn.modules.activation.ELU'>, verbose=False, filter_parameters=None)[source]#

LightConvNet from Ma et al. (2023) [lightconvnet].

[Figures: Convolution Filterbank; LightConvNet architecture]

A lightweight convolutional neural network incorporating temporal dependency learning and attention mechanisms. The architecture is designed to efficiently capture spatial and temporal features through specialized convolutional layers and multi-head attention.

The network architecture consists of four main modules:

  1. Spatial and Spectral Information Learning:

    Applies filterbank and spatial convolutions. This module is followed by batch normalization and an activation function to enhance feature representation.

  2. Temporal Segmentation and Feature Extraction:

    Divides the processed data into non-overlapping temporal windows. Within each window, a variance-based layer extracts discriminative features, which are then log-transformed to stabilize their scale before being passed to the attention module (see the sketch after this list).

  3. Temporal Attention Module:

    Utilizes a multi-head attention mechanism with depthwise separable convolutions to capture dependencies across different temporal segments. The attention weights are normalized using softmax and aggregated to form a comprehensive temporal representation.

  4. Final Layer:

    Flattens the aggregated features and passes them through a linear layer to integrate features across channels and generate the final output predictions.
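The windowed log-variance features described in module 2 can be illustrated with a short, self-contained sketch. This is a simplified illustration, not the braindecode implementation; the function name, shapes, and the 1e-6 stabilizer are assumptions:

    import torch

    def log_variance_windows(x: torch.Tensor, win_len: int) -> torch.Tensor:
        # x: (batch, channels, n_times); split the time axis into
        # non-overlapping windows of length win_len and return the
        # log-variance of each window: (batch, channels, n_windows).
        batch, chans, n_times = x.shape
        n_windows = n_times // win_len
        x = x[..., : n_windows * win_len]                # drop any remainder
        x = x.reshape(batch, chans, n_windows, win_len)  # segment into windows
        var = x.var(dim=-1)                              # variance per window
        return torch.log(var + 1e-6)                     # log stabilizes the scale

    x = torch.randn(8, 32, 1000)  # e.g. 8 trials, 32 spatially filtered channels
    print(log_variance_windows(x, win_len=250).shape)    # torch.Size([8, 32, 4])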

Notes

This implementation is not guaranteed to be correct and has not been checked by the original authors; it is a braindecode adaptation of the PyTorch source code [lightconvnetcode].

Parameters:
  • n_bands (int or None or list of tuple of int, default=9) – Number of frequency bands or a list of frequency band tuples. If a list of tuples is provided, each tuple defines the lower and upper bounds of a frequency band (see the example after this list).

  • n_filters_spat (int, default=32) – Number of spatial filters in the depthwise convolutional layer.

  • n_dim (int, default=3) – Number of dimensions for the temporal reduction layer.

  • stride_factor (int, default=4) – Stride factor used for reshaping the temporal dimension.

  • activation (nn.Module, default=nn.ELU) – Activation function class to apply after convolutional layers.

  • verbose (bool, default=False) – If True, enables verbose output during filter creation using mne.

  • filter_parameters (dict or None, default=None) – Additional parameters for the FilterBankLayer.

  • heads (int, default=8) – Number of attention heads in the multi-head attention mechanism.

  • weight_softmax (bool, default=True) – If True, applies softmax to the attention weights.

  • bias (bool, default=False) – If True, includes a bias term in the convolutional layers.
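As a usage sketch of these parameters, the model can be instantiated with an explicit filterbank. The channel count, window length, sampling rate, and band edges below are illustrative values, not recommendations:

    from braindecode.models import FBLightConvNet

    model = FBLightConvNet(
        n_chans=22,            # hypothetical 22-channel EEG montage
        n_outputs=4,           # e.g. four motor-imagery classes
        n_times=1000,          # 4 s windows at 250 Hz
        sfreq=250,
        n_bands=[(4, 8), (8, 12), (12, 16), (16, 20)],  # explicit (low, high) tuples
        heads=8,
    )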

References

[lightconvnet]

Ma, X., Chen, W., Pei, Z., Liu, J., Huang, B., & Chen, J. (2023). A temporal dependency learning CNN with attention mechanism for MI-EEG decoding. IEEE Transactions on Neural Systems and Rehabilitation Engineering.

[lightconvnetcode]

Link to source code: Ma-Xinzhi/LightConvNet

Methods

forward(x)[source]#

Forward pass of the FBLightConvNet model.

Parameters:
  • x (torch.Tensor) – Input tensor with shape (batch_size, n_chans, n_times).

Returns:
  Output tensor with shape (batch_size, n_outputs).

Return type:
  torch.Tensor
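Assuming the model constructed in the parameter example above, a minimal forward-pass shape check could look like this (the batch size is arbitrary):

    import torch

    x = torch.randn(16, 22, 1000)   # (batch_size, n_chans, n_times)
    out = model(x)
    print(out.shape)                # expected: torch.Size([16, 4])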