braindecode.models.modules.FilterBankLayer
- class braindecode.models.modules.FilterBankLayer(n_chans: int, sfreq: int, band_filters: List[Tuple[float, float]] | int | None = None, method: str = 'fir', filter_length: str | float | int = 'auto', l_trans_bandwidth: str | float | int = 'auto', h_trans_bandwidth: str | float | int = 'auto', phase: str = 'zero', iir_params: dict | None = None, fir_window: str = 'hamming', fir_design: str = 'firwin', verbose: bool = True)
Apply multiple band-pass filters to generate multiview signal representation.
This layer constructs a bank of signals filtered in specific bands for each channel. It uses MNE's create_filter function to create the band-specific filters and applies them to multi-channel time-series data. Each filter in the bank corresponds to a specific frequency band and is applied to all channels of the input data. The filtering is performed using FFT-based convolution via the fftconvolve() function from torchaudio.functional if the method is FIR, and via the filtfilt() function from torchaudio.functional if the method is IIR.
The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth, spanning 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, …, 36-40 Hz). This setup is based on the reference: FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer Interface.
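The FIR and IIR paths described above can be sketched outside the layer with NumPy and SciPy. This is not braindecode's implementation: as assumptions, scipy.signal.firwin (Hamming window, a fixed 251 taps) stands in for MNE's create_filter, scipy.signal.fftconvolve stands in for torchaudio's fftconvolve, and scipy.signal.sosfiltfilt with a 4th-order Butterworth design stands in for the IIR filtfilt path. The helpers default_bands and filter_bank are hypothetical names.

```python
import numpy as np
from scipy.signal import butter, fftconvolve, firwin, sosfiltfilt


def default_bands(low=4.0, high=40.0, width=4.0):
    """Hypothetical helper: non-overlapping bands 4-8, 8-12, ..., 36-40 Hz."""
    return [(float(f), float(f + width)) for f in np.arange(low, high, width)]


def filter_bank(x, sfreq, bands=None, method="fir", numtaps=251):
    """Hypothetical filter-bank sketch.

    x has shape (batch_size, n_chans, n_times); the result has shape
    (batch_size, n_bands, n_chans, n_times).
    """
    bands = default_bands() if bands is None else bands
    outputs = []
    for l_freq, h_freq in bands:
        if method == "fir":
            # Linear-phase band-pass FIR designed with a Hamming window
            # (fixed length here; the real layer lets MNE pick the length).
            taps = firwin(numtaps, [l_freq, h_freq], pass_zero=False,
                          window="hamming", fs=sfreq)
            # FFT convolution along time; the kernel broadcasts over batch and
            # channels. mode="same" with an odd, symmetric kernel removes the
            # (numtaps - 1) / 2 sample delay, i.e. zero phase.
            y = fftconvolve(x, taps[np.newaxis, np.newaxis, :],
                            mode="same", axes=-1)
        elif method == "iir":
            # 4th-order Butterworth band-pass as second-order sections, applied
            # forward and backward (zero phase, non-causal).
            sos = butter(4, [l_freq, h_freq], btype="bandpass",
                         output="sos", fs=sfreq)
            y = sosfiltfilt(sos, x, axis=-1)
        else:
            raise ValueError(f"unknown method: {method!r}")
        outputs.append(y)
    # Stack the bands on a new axis right after the batch axis.
    return np.stack(outputs, axis=1)


# Usage: a 10 Hz sine on 2 trials x 3 channels, sampled at 250 Hz.
sfreq = 250.0
t = np.arange(1000) / sfreq
x = np.tile(np.sin(2 * np.pi * 10.0 * t), (2, 3, 1))  # (2, 3, 1000)
out = filter_bank(x, sfreq)
print(out.shape)  # (2, 9, 3, 1000)
```

In this sketch the time length is preserved because of mode="same"; most of the sine's energy ends up in the 8-12 Hz band (index 1).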
- Parameters:
n_chans (int) – Number of channels in the input signal.
sfreq (int) – Sampling frequency of the input signal in Hz.
band_filters (Optional[List[Tuple[float, float]]] or int, default=None) – List of frequency bands as (low_freq, high_freq) tuples. Each tuple defines the frequency range for one filter in the bank. If not provided, defaults to 9 non-overlapping bands with 4 Hz bandwidths spanning from 4 to 40 Hz.
method (str, default='fir') – 'fir' will use FIR filtering; 'iir' will use IIR forward-backward filtering (via filtfilt()). For more details, please check the MNE Preprocessing Tutorial.
filter_length (Union[str, float, int], default='auto') – Length of the FIR filter to use (if applicable):
"auto" (default): The filter length is chosen based on the size of the transition regions (6.6 times the reciprocal of the shortest transition band, in seconds, for fir_window="hamming" and fir_design="firwin2", and half that for "firwin").
str: A human-readable time in units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted to that number of samples if phase="zero", or to the shortest power-of-two length at least that duration for phase="zero-double".
int: Specified length in samples. For fir_design="firwin", this should not be used.
l_trans_bandwidth (Union[str, float, int], default='auto') – Width of the transition band at the low cut-off frequency in Hz (high pass or cutoff 1 in bandpass). Can be "auto" (default) to use a multiple of l_freq (the lower edge of each band): min(max(l_freq * 0.25, 2), l_freq). Only used for method='fir'.
h_trans_bandwidth (Union[str, float, int], default='auto') – Width of the transition band at the high cut-off frequency in Hz (low pass or cutoff 2 in bandpass). Can be "auto" (default) to use a multiple of h_freq (the upper edge of each band): min(max(h_freq * 0.25, 2.), sfreq / 2. - h_freq), where sfreq is the sampling frequency of the input. Only used for method='fir'.
phase (str, default='zero') – Phase of the filter. When method='fir', symmetric linear-phase FIR filters are constructed with the following behaviors:
"zero" (default): The delay of this filter is compensated for, making it non-causal.
"minimum": A minimum-phase filter will be constructed by decomposing the zero-phase filter into minimum-phase and all-pass systems, and then retaining only the minimum-phase system (of the same length as the original zero-phase filter) via scipy.signal.minimum_phase().
"zero-double": This is a legacy option for compatibility with MNE <= 0.13. The filter is applied twice, once forward and once backward (also making it non-causal).
"minimum-half": This is a legacy option for compatibility with MNE <= 1.6. A minimum-phase filter will be reconstructed from the zero-phase filter with half the length of the original filter.
When method='iir', phase='zero' (default) or equivalently 'zero-double' constructs and applies the IIR filter twice, once forward and once backward (making it non-causal), using filtfilt(); phase='forward' applies the filter once in the forward (causal) direction using lfilter().
Note: in MNE, the behavior for phase="minimum" was fixed to use a filter of the requested length, with improved suppression.
iir_params (Optional[dict], default=None) – Dictionary of parameters to use for IIR filtering. If iir_params=None and method="iir", a 4th-order Butterworth filter will be used. For more information, see mne.filter.construct_iir_filter().
fir_window (str, default='hamming') – The window to use in FIR design. Can be "hamming" (default), "hann" (the default in MNE 0.13), or "blackman".
fir_design (str, default='firwin') – Can be "firwin" (default) to use scipy.signal.firwin(), or "firwin2" to use scipy.signal.firwin2(). "firwin" uses a time-domain design technique that generally gives improved attenuation using fewer samples than "firwin2".
pad (str, default='reflect_limited') – The type of padding to use. Supports all numpy.pad() mode options. Can also be "reflect_limited", which pads with a reflected version of each vector mirrored on the first and last values of the vector, followed by zeros. Only used for method='fir'.
verbose (bool | str | int | None, default=True) – Control verbosity of the logging output. If None, use the default verbosity level. See mne.verbose() for details. Should only be passed as a keyword argument.
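The "auto" rules quoted for l_trans_bandwidth, h_trans_bandwidth, and filter_length can be worked through per band. The sketch below is a hypothetical re-derivation for the Hamming window only, using the 6.6 (firwin2) and 3.3 (firwin) factors stated above; rounding to whole samples and bumping even lengths to odd are our simplifications, so MNE's exact length may differ by a sample. The helper names are hypothetical.

```python
def auto_trans_bandwidths(l_freq, h_freq, sfreq):
    """Hypothetical helper mirroring the documented 'auto' rules (in Hz)."""
    l_trans = min(max(l_freq * 0.25, 2.0), l_freq)
    # Capped so the upper transition band never crosses Nyquist (sfreq / 2).
    h_trans = min(max(h_freq * 0.25, 2.0), sfreq / 2.0 - h_freq)
    return l_trans, h_trans


def auto_filter_length(l_freq, h_freq, sfreq, fir_design="firwin"):
    """Hypothetical 'auto' FIR length in samples (Hamming window only)."""
    l_trans, h_trans = auto_trans_bandwidths(l_freq, h_freq, sfreq)
    # 6.6 / shortest transition band (seconds) for firwin2, half that for firwin.
    factor = 6.6 if fir_design == "firwin2" else 3.3
    n_samples = max(int(round(factor / min(l_trans, h_trans) * sfreq)), 1)
    # Simplification: bump even lengths to odd so the filter has one centre tap.
    return n_samples + 1 if n_samples % 2 == 0 else n_samples


# Default bands at 250 Hz: low bands hit the 2 Hz floor,
# higher bands use 25% of the cut-off frequency.
print(auto_trans_bandwidths(4.0, 8.0, 250.0))    # (2.0, 2.0)
print(auto_trans_bandwidths(36.0, 40.0, 250.0))  # (9.0, 10.0)
```

Because the shortest transition band sets the length, the narrow low-frequency bands need the longest filters, and "firwin2" roughly doubles the length compared with "firwin".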
Methods
- forward(x: Tensor) → Tensor
Apply the filter bank to the input signal.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, n_chans, time_points).
- Returns:
Filtered output tensor of shape (batch_size, n_bands, n_chans, filtered_time_points).
- Return type: