braindecode.models.modules.FilterBankLayer

class braindecode.models.modules.FilterBankLayer(n_chans: int, sfreq: int, band_filters: List[Tuple[float, float]] | int | None = None, method: str = 'fir', filter_length: str | float | int = 'auto', l_trans_bandwidth: str | float | int = 'auto', h_trans_bandwidth: str | float | int = 'auto', phase: str = 'zero', iir_params: dict | None = None, fir_window: str = 'hamming', fir_design: str = 'firwin', verbose: bool = True)

Apply multiple band-pass filters to generate a multiview signal representation.

This layer constructs a bank of band-filtered signals for each channel. It uses MNE’s create_filter function to build the band-specific filters and applies them to multi-channel time-series data. Each filter in the bank corresponds to one frequency band and is applied to all channels of the input. When the method is FIR, filtering is performed with FFT-based convolution via the fftconvolve function from torchaudio.functional; when the method is IIR, it uses the filtfilt() function from torchaudio.functional.
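
FFT-based convolution computes an ordinary linear convolution by multiplying spectra, which is cheaper than direct convolution for long FIR filters. The NumPy sketch below illustrates the idea; it is not braindecode's or torchaudio's code. The helper name fft_convolve_same is hypothetical, and cropping the result back to the input length ("same"-style, assuming an odd-length kernel) is an assumption about how the output is aligned.

```python
import numpy as np


def fft_convolve_same(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Linear convolution of x with h along the last axis, computed via the FFT.

    Returns the centred part of the full convolution with the same length as x.
    Assumes h has odd length (as symmetric linear-phase FIR filters usually do).
    """
    n_x, n_h = x.shape[-1], h.shape[-1]
    n_full = n_x + n_h - 1  # length of the full linear convolution
    # Zero-padding both inputs to n_full avoids circular wrap-around.
    spectrum = np.fft.rfft(x, n=n_full) * np.fft.rfft(h, n=n_full)
    full = np.fft.irfft(spectrum, n=n_full)
    start = (n_h - 1) // 2  # skip half the filter length to re-centre the output
    return full[..., start:start + n_x]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 64))          # (batch, n_chans, n_times)
h = np.hamming(9) / np.hamming(9).sum()      # any odd-length symmetric kernel
y = fft_convolve_same(x, h)                  # filters every channel at once
```

Because the kernel spectrum broadcasts against the last axis, one call filters all batch items and channels, which mirrors how a single band filter is applied to all channels.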

The default configuration creates 9 non-overlapping frequency bands, each 4 Hz wide, spanning 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, …, 36-40 Hz). This setup follows the reference FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface.
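
The default band list described above can be written out explicitly. The helper default_filter_bank is a hypothetical name used for illustration, not part of braindecode's API; it simply produces (low_freq, high_freq) tuples in the format that band_filters accepts.

```python
def default_filter_bank(start: float = 4.0, stop: float = 40.0, width: float = 4.0):
    """Return non-overlapping (low, high) bands of equal width covering [start, stop]."""
    n_bands = int(round((stop - start) / width))
    return [(start + i * width, start + (i + 1) * width) for i in range(n_bands)]


bands = default_filter_bank()
print(len(bands), bands[0], bands[-1])  # 9 (4.0, 8.0) (36.0, 40.0)
```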

Parameters:
  • n_chans (int) – Number of channels in the input signal.

  • sfreq (int) – Sampling frequency of the input signal in Hz.

  • band_filters (Optional[List[Tuple[float, float]]] or int, default=None) – List of frequency bands as (low_freq, high_freq) tuples. Each tuple defines the frequency range for one filter in the bank. If not provided, defaults to 9 non-overlapping bands with 4 Hz bandwidths spanning from 4 to 40 Hz.

  • method (str, default='fir') – 'fir' uses FIR filtering; 'iir' uses IIR forward-backward filtering (via filtfilt()). For more details, please check the MNE Preprocessing Tutorial.

  • filter_length (str | int, default='auto') –

    Length of the FIR filter to use (if applicable):

    • 'auto' (default): The filter length is chosen based on the size of the transition regions (6.6 times the reciprocal of the shortest transition band for fir_window='hamming' and fir_design='firwin2', and half that for 'firwin').

    • str: A human-readable time in units of 's' or 'ms' (e.g., '10s' or '5500ms') is converted to that number of samples if phase='zero', or to the shortest power-of-two length of at least that duration for phase='zero-double'.

    • int: Specified length in samples. This should not be used with fir_design='firwin'.

  • l_trans_bandwidth (Union[str, float, int], default='auto') –

    Width of the transition band at the low cut-off frequency in Hz (high pass or cutoff 1 in bandpass). Can be “auto” (default) to use a multiple of l_freq:

    min(max(l_freq * 0.25, 2), l_freq)
    

    Only used for method='fir'.

  • h_trans_bandwidth (Union[str, float, int], default='auto') –

    Width of the transition band at the high cut-off frequency in Hz (low pass or cutoff 2 in bandpass). Can be “auto” (default) to use a multiple of h_freq:

    min(max(h_freq * 0.25, 2.), sfreq / 2. - h_freq)
    

    Only used for method='fir'.

  • phase (str, default='zero') –

    Phase of the filter. When method='fir', symmetric linear-phase FIR filters are constructed, and the phase options behave as follows:

    "zero" (default)

    The delay of this filter is compensated for, making it non-causal.

    "minimum"

    A minimum-phase filter will be constructed by decomposing the zero-phase filter into a minimum-phase and all-pass systems, and then retaining only the minimum-phase system (of the same length as the original zero-phase filter) via scipy.signal.minimum_phase().

    "zero-double"

    This is a legacy option for compatibility with MNE <= 0.13. The filter is applied twice, once forward, and once backward (also making it non-causal).

    "minimum-half"

    This is a legacy option for compatibility with MNE <= 1.6. A minimum-phase filter will be reconstructed from the zero-phase filter with half the length of the original filter.

    When method='iir', phase='zero' (default) or, equivalently, 'zero-double' constructs the IIR filter and applies it twice, once forward and once backward (making it non-causal), using filtfilt(); phase='forward' applies the filter once in the forward (causal) direction using lfilter().

    As of MNE 1.7, the behavior for phase="minimum" was fixed to use a filter of the requested length with improved suppression.

  • iir_params (Optional[dict], default=None) – Dictionary of parameters to use for IIR filtering. If iir_params=None and method="iir", a 4th-order Butterworth filter is used. For more information, see mne.filter.construct_iir_filter().

  • fir_window (str, default='hamming') – The window to use in FIR design. Can be “hamming” (default), “hann” (the default in MNE 0.13), or “blackman”.

  • fir_design (str, default='firwin') – Can be “firwin” (default) to use scipy.signal.firwin(), or “firwin2” to use scipy.signal.firwin2(). “firwin” uses a time-domain design technique that generally gives improved attenuation using fewer samples than “firwin2”.

  • pad (str, default='reflect_limited') – The type of padding to use. Supports all numpy.pad() mode options. Can also be “reflect_limited”, which pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. Only used for method='fir'.

  • verbose (bool | str | int | None, default=True) – Control verbosity of the logging output. If None, use the default verbosity level. See mne.verbose() for details. Should only be passed as a keyword argument.
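
The 'auto' rules for l_trans_bandwidth and h_trans_bandwidth are plain arithmetic on the band edges. The sketch below evaluates them for a few 4 Hz bands at an assumed sampling frequency of 250 Hz; the function names are hypothetical and the code only restates the formulas given in the parameter list, it does not call MNE.

```python
def auto_l_trans_bandwidth(l_freq: float) -> float:
    """'auto' low-edge transition bandwidth: min(max(l_freq * 0.25, 2), l_freq)."""
    return min(max(l_freq * 0.25, 2.0), l_freq)


def auto_h_trans_bandwidth(h_freq: float, sfreq: float) -> float:
    """'auto' high-edge transition bandwidth: min(max(h_freq * 0.25, 2), sfreq / 2 - h_freq)."""
    return min(max(h_freq * 0.25, 2.0), sfreq / 2.0 - h_freq)


sfreq = 250.0  # assumed sampling frequency in Hz
for l_freq, h_freq in [(4.0, 8.0), (20.0, 24.0), (36.0, 40.0)]:
    print((l_freq, h_freq), auto_l_trans_bandwidth(l_freq), auto_h_trans_bandwidth(h_freq, sfreq))
```

For the lowest 4-8 Hz band both transition bands come out at 2 Hz, so with the 'auto' filter length that band is the one that determines the longest filter.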
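
The 'auto' filter_length rule for the Hamming window can be sketched the same way: 6.6 divided by the shortest transition bandwidth gives a duration in seconds for fir_design='firwin2', and half of that for 'firwin'. The helper auto_fir_length is hypothetical, and converting seconds to samples by rounding up is an assumption; MNE may round or adjust the final length differently, so treat the numbers as approximate.

```python
import math


def auto_fir_length(sfreq: float, min_trans_bandwidth: float, fir_design: str = "firwin") -> int:
    """Approximate 'auto' FIR length in samples for fir_window='hamming'.

    Duration is 6.6 / min_trans_bandwidth seconds for "firwin2" and half that
    for "firwin"; rounding up to whole samples is an assumption.
    """
    if fir_design not in ("firwin", "firwin2"):
        raise ValueError(f"unknown fir_design: {fir_design!r}")
    factor = 6.6 if fir_design == "firwin2" else 3.3
    duration_s = factor / min_trans_bandwidth  # reciprocal of Hz is seconds
    return math.ceil(duration_s * sfreq)


print(auto_fir_length(250, 2.0, "firwin"))   # about 413 samples (~1.65 s)
print(auto_fir_length(250, 2.0, "firwin2"))  # about 825 samples (~3.3 s)
```

Narrower transition bands give longer filters, which is why low-frequency bands dominate the filter length.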
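
The phase description contrasts forward-backward filtering (phase='zero', via filtfilt()) with a single causal pass (phase='forward', via lfilter()). The NumPy sketch below shows why the first has no net delay: it uses a hypothetical first-order IIR low-pass, not the default 4th-order Butterworth filter, and it does not reproduce the padding or initial-condition handling of any library implementation.

```python
import numpy as np


def one_pole_lowpass(x, a):
    """Causal first-order IIR low-pass: y[n] = a * y[n-1] + (1 - a) * x[n]."""
    y = np.empty(len(x), dtype=float)
    prev = 0.0
    for n, xn in enumerate(x):
        prev = a * prev + (1.0 - a) * xn
        y[n] = prev
    return y


def forward_backward(x, a):
    """Filter forward, then filter the time-reversed output and reverse it back."""
    return one_pole_lowpass(one_pole_lowpass(x, a)[::-1], a)[::-1]


def center_of_mass(sig):
    """Average sample index weighted by the signal values."""
    idx = np.arange(len(sig))
    return float((idx * sig).sum() / sig.sum())


impulse = np.zeros(401)
impulse[200] = 1.0
causal = one_pole_lowpass(impulse, a=0.8)      # single forward pass, like phase='forward'
zero_phase = forward_backward(impulse, a=0.8)  # forward-backward, like phase='zero'

print(round(center_of_mass(causal), 2))      # about 204: delayed by a / (1 - a) = 4 samples
print(round(center_of_mass(zero_phase), 2))  # about 200: no net delay
```

The forward-backward response is symmetric around the impulse, which is what makes it zero-phase and also non-causal: output samples before index 200 depend on the future impulse.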

Methods

forward(x: Tensor) → Tensor

Apply the filter bank to the input signal.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, n_chans, time_points).

Returns:

Filtered output tensor of shape (batch_size, n_bands, n_chans, filtered_time_points).

Return type:

torch.Tensor
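
A rough NumPy sketch of what forward computes, assuming FIR filtering and an output time axis of the same length as the input (the documented shape calls this axis filtered_time_points, so the real length may differ). bandpass_fir, fft_convolve_same and filter_bank_forward are hypothetical helpers; the windowed-sinc design stands in for mne.filter.create_filter and is not a reproduction of it.

```python
import numpy as np


def bandpass_fir(l_freq, h_freq, sfreq, numtaps=251):
    """Hypothetical windowed-sinc band-pass FIR (Hamming window, odd numtaps).

    A stand-in for mne.filter.create_filter, not the same design procedure.
    """
    n = np.arange(numtaps) - (numtaps - 1) / 2  # symmetric tap index around 0

    def lowpass(fc):
        # Ideal low-pass impulse response with cutoff fc (Hz), using the normalised sinc
        return 2 * fc / sfreq * np.sinc(2 * fc / sfreq * n)

    # Band-pass = low-pass(h_freq) minus low-pass(l_freq), then windowed
    return (lowpass(h_freq) - lowpass(l_freq)) * np.hamming(numtaps)


def fft_convolve_same(x, h):
    """FFT-based linear convolution along the last axis, cropped to len(x)."""
    n_x, n_h = x.shape[-1], h.shape[-1]
    n_full = n_x + n_h - 1
    full = np.fft.irfft(np.fft.rfft(x, n=n_full) * np.fft.rfft(h, n=n_full), n=n_full)
    start = (n_h - 1) // 2
    return full[..., start:start + n_x]


def filter_bank_forward(x, bands, sfreq, numtaps=251):
    """Map (batch, n_chans, n_times) to (batch, n_bands, n_chans, n_times)."""
    filtered = [fft_convolve_same(x, bandpass_fir(lo, hi, sfreq, numtaps)) for lo, hi in bands]
    return np.stack(filtered, axis=1)  # insert the band axis after the batch axis


sfreq = 250
bands = [(low, low + 4) for low in range(4, 40, 4)]  # 4-8 Hz, ..., 36-40 Hz
t = np.arange(1000) / sfreq
x = np.broadcast_to(np.sin(2 * np.pi * 10 * t), (2, 3, 1000))  # 10 Hz sine on every channel

out = filter_bank_forward(x, bands, sfreq)
print(out.shape)  # (2, 9, 3, 1000)

energy = (out ** 2).sum(axis=(0, 2, 3))  # total energy per band
print(int(np.argmax(energy)))  # 1, i.e. the 8-12 Hz band
```

The band axis is placed right after the batch axis so that each band view keeps the original (n_chans, time) layout, matching the documented output shape.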