braindecode.models.SincShallowNet
- class braindecode.models.SincShallowNet(num_time_filters: int = 32, time_filter_len: int = 33, depth_multiplier: int = 2, activation: torch.nn.Module | None = torch.nn.ELU, drop_prob: float = 0.5, first_freq: float = 5.0, min_freq: float = 1.0, freq_stride: float = 1.0, padding: str = 'same', bandwidth: float = 4.0, pool_size: int = 55, pool_stride: int = 12, n_chans: int | None = None, n_outputs: int | None = None, n_times: int | None = None, input_window_seconds: float | None = None, sfreq: float | None = None, chs_info: List[Dict] | None = None)
Sinc-ShallowNet from Borra, D. et al. (2020) [borra2020].
The Sinc-ShallowNet architecture has the following fundamental blocks:
- Block 1: Spectral and Spatial Feature Extraction
  - Temporal Sinc-Convolutional Layer: Uses parametrized sinc functions to learn band-pass filters. Only the lower and upper cutoff frequencies of each filter are learned, which greatly reduces the number of trainable parameters.
  - Spatial Depthwise Convolutional Layer: Applies a depthwise convolution that learns spatial filters for each temporal feature map independently, further reducing the parameter count and improving interpretability.
  - Batch Normalization
- Block 2: Temporal Aggregation
  - Activation Function: ELU by default (configurable via activation)
  - Average Pooling Layer: Aggregates features by averaging over windows along the time dimension (window pool_size, stride pool_stride).
  - Dropout Layer
  - Flatten Layer
- Block 3: Classification
  - Fully Connected Layer: Maps the flattened feature vector to n_outputs.
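To make the block sequence concrete, the sketch below traces tensor sizes and a few weight counts through the three blocks. It is a back-of-the-envelope calculation, not braindecode code: the helper sinc_shallownet_sizes is hypothetical, and it assumes that the depthwise kernel spans all n_chans channels, that average pooling runs along time without padding (floor rounding), and that biases are ignored.

```python
def sinc_shallownet_sizes(n_chans, n_times, n_outputs, num_time_filters=32,
                          time_filter_len=33, depth_multiplier=2,
                          padding="same", pool_size=55, pool_stride=12):
    """Hypothetical helper: trace tensor sizes and a few weight counts."""
    # Block 1, temporal sinc conv: 'same' keeps the length, 'valid' trims it.
    if padding == "same":
        t_conv = n_times
    elif padding == "valid":
        t_conv = n_times - time_filter_len + 1
    else:
        raise ValueError("padding must be 'same' or 'valid'")
    # Block 1, depthwise spatial conv: assumed to span all channels, so the
    # channel axis collapses and each temporal map yields depth_multiplier maps.
    n_maps = num_time_filters * depth_multiplier
    # Block 2, average pooling along time (assumed unpadded, floor rounding).
    t_pool = (t_conv - pool_size) // pool_stride + 1
    n_features = n_maps * t_pool
    return {
        "time_after_conv": t_conv,
        "n_maps": n_maps,
        "time_after_pool": t_pool,
        "n_features": n_features,                 # input size of the FC layer
        "fc_weights": n_features * n_outputs,     # Block 3 weights, bias ignored
        "depthwise_weights": n_maps * n_chans,    # one n_chans-long kernel per map
        "sinc_params": 2 * num_time_filters,      # two cutoffs per sinc filter
        "free_conv_params": num_time_filters * time_filter_len,  # unconstrained kernels
    }


sizes = sinc_shallownet_sizes(n_chans=22, n_times=1000, n_outputs=4)
for name, value in sizes.items():
    print(f"{name}: {value}")
```

With 22 channels, 1000 samples and the default hyperparameters, pooling leaves 79 time steps, so the flattened vector has 64 × 79 = 5056 entries. The two learned cutoffs per sinc filter (64 values in total) compare with 1056 weights for 32 unconstrained 33-tap kernels, which is the parameter saving described for Block 1.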
Implementation Notes:
- The sinc-convolutional layer initializes its cutoff frequencies uniformly within the desired frequency range and updates them during training, while ensuring that the lower cutoff stays below the upper cutoff.
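The cutoff handling described in the note can be sketched in NumPy. This follows the SincNet-style parametrization (lower cutoff = min_freq + |learned low|, upper cutoff = lower cutoff + |learned band|), which is an assumption about how the layer works internally. The evenly spaced initialization from first_freq, freq_stride and bandwidth is likewise inferred from the parameter descriptions. The helper names init_cutoffs, constrain_cutoffs and sinc_bandpass, the Hamming window and sfreq = 250 Hz are illustrative choices, not braindecode API.

```python
import numpy as np


def init_cutoffs(num_filters, first_freq=5.0, freq_stride=1.0, bandwidth=4.0):
    """Hypothetical init: evenly spaced lower cutoffs, one shared bandwidth (Hz)."""
    low = first_freq + freq_stride * np.arange(num_filters)
    band = np.full(num_filters, bandwidth)
    return low, band


def constrain_cutoffs(low, band, min_freq=1.0, sfreq=250.0):
    """SincNet-style constraint: f1 >= min_freq, f2 = f1 + |band| capped at Nyquist."""
    f1 = min_freq + np.abs(low)
    f2 = np.minimum(f1 + np.abs(band), sfreq / 2)  # f2 >= f1 while f1 < Nyquist
    return f1, f2


def sinc_bandpass(f1, f2, filter_len=33, sfreq=250.0):
    """Hamming-windowed band-pass kernel: difference of two ideal low-pass sincs."""
    n = np.arange(filter_len) - (filter_len - 1) / 2  # time index centred on 0
    # np.sinc is the normalised sinc sin(pi x) / (pi x); cutoffs are in cycles/sample.
    low_pass_f2 = 2 * f2 / sfreq * np.sinc(2 * f2 / sfreq * n)
    low_pass_f1 = 2 * f1 / sfreq * np.sinc(2 * f1 / sfreq * n)
    return (low_pass_f2 - low_pass_f1) * np.hamming(filter_len)


low, band = init_cutoffs(num_filters=32)  # these would be trainable in a real model
f1, f2 = constrain_cutoffs(low, band)
bank = np.stack([sinc_bandpass(a, b) for a, b in zip(f1, f2)])
print(bank.shape)
```

Because only low and band are free, gradient updates can move each pass band but can never invert it, which is how the "lower cutoff below upper cutoff" guarantee is obtained in this sketch.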
- Parameters:
num_time_filters (int) – Number of temporal filters in the SincFilter layer.
time_filter_len (int) – Length of the temporal sinc filters, in samples.
depth_multiplier (int) – Depth multiplier of the depthwise spatial convolution, i.e. the number of spatial filters learned for each temporal filter.
activation (nn.Module, optional) – Activation function (module class) to use. Default is nn.ELU.
drop_prob (float, optional) – Dropout probability. Default is 0.5.
first_freq (float, optional) – Starting frequency of the first sinc filter. Default is 5.0.
min_freq (float, optional) – Minimum value allowed for the lower cutoff frequency of each filter. Default is 1.0.
freq_stride (float, optional) – Frequency stride for the sinc filters, i.e. the spacing between the frequencies of consecutive filters. Default is 1.0.
padding (str, optional) – Padding mode for convolution, either ‘same’ or ‘valid’. Default is ‘same’.
bandwidth (float, optional) – Initial bandwidth for each Sinc filter. Default is 4.0.
pool_size (int, optional) – Size of the pooling window for the average pooling layer. Default is 55.
pool_stride (int, optional) – Stride of the pooling operation. Default is 12.
n_chans (int) – Number of EEG channels.
n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.
n_times (int) – Number of time samples of the input window.
input_window_seconds (float) – Length of the input window in seconds.
sfreq (float) – Sampling frequency of the EEG recordings.
chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
- Raises:
ValueError – If some input signal-related parameters are not specified and cannot be inferred.
FutureWarning – If add_log_softmax is True, since the final LogSoftmax layer will be removed in the future.
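The ValueError above concerns the signal-related arguments. A plausible reading, assumed here rather than taken from the braindecode source, is that n_times can be derived from input_window_seconds × sfreq and n_chans from len(chs_info). The hypothetical helper infer_signal_params below only illustrates that bookkeeping; the rounding step is a choice of this sketch.

```python
def infer_signal_params(n_chans=None, n_times=None, input_window_seconds=None,
                        sfreq=None, chs_info=None):
    """Hypothetical helper: fill in n_chans / n_times from related arguments."""
    if n_chans is None and chs_info is not None:
        n_chans = len(chs_info)  # one dict per channel, as in info["chs"]
    if n_times is None and input_window_seconds is not None and sfreq is not None:
        n_times = int(round(input_window_seconds * sfreq))  # seconds * Hz = samples
    if n_chans is None or n_times is None:
        raise ValueError("n_chans and n_times must be given or inferable.")
    return n_chans, n_times


channels = [{"ch_name": f"EEG{i:02d}"} for i in range(22)]  # stand-in for info["chs"]
print(infer_signal_params(input_window_seconds=4.0, sfreq=250.0, chs_info=channels))
```

Here a 4 s window at 250 Hz gives 1000 samples and the 22 channel dicts give 22 channels; calling the helper with only n_chans raises, mirroring the documented error condition.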
Notes
This implementation is based on the implementation from [sincshallowcode].
References
[borra2020] Borra, D., Fantozzi, S., & Magosso, E. (2020). Interpretable and lightweight convolutional neural network for EEG decoding: Application to movement execution and imagination. Neural Networks, 129, 55-74.
[sincshallowcode] Sinc-ShallowNet re-implementation source code: marcellosicbaldi/SincNet-Tensorflow
Methods
- forward(x: Tensor) → Tensor
Forward pass of the model.
- Parameters:
x (torch.Tensor) – Input tensor of shape [batch_size, num_channels, num_samples].
- Returns:
Output logits of shape [batch_size, num_classes].
- Return type:
torch.Tensor