# Source code for braindecode.models.shallow_fbcsp

# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
from torch import nn
from torch.nn import init

from ..util import np_to_th
from .modules import Expression, Ensure4d
from .functions import (
    safe_log, square, transpose_time_to_spat, squeeze_final_output
)


class ShallowFBCSPNet(nn.Sequential):
    """Shallow ConvNet from Schirrmeister et al. 2017.

    The architecture is described in [Schirrmeister2017]_. Layers are
    registered on this ``nn.Sequential`` in the order they are applied.

    Parameters
    ----------
    in_chans : int
        Number of EEG input channels.
    n_classes : int
        Number of classes to predict (number of output filters of the last
        layer).
    input_window_samples : int | None
        Only used to determine the length of the last convolutional kernel
        when ``final_conv_length`` is ``"auto"``.
    n_filters_time : int
        Number of temporal filters.
    filter_time_length : int
        Length of the temporal filter.
    n_filters_spat : int
        Number of spatial filters.
    pool_time_length : int
        Length of the temporal pooling filter.
    pool_time_stride : int
        Stride between temporal pooling filters.
    final_conv_length : int | str
        Length of the final convolution layer. If set to ``"auto"``,
        ``input_window_samples`` must not be None.
    conv_nonlin : callable
        Non-linear function applied after the convolution layers.
    pool_mode : str
        Pooling method, ``"max"`` or ``"mean"``.
    pool_nonlin : callable
        Non-linear function applied after the pooling layer.
    split_first_layer : bool
        Split the first layer into a temporal and a spatial convolution
        (True) or use only a temporal convolution (False). There is no
        non-linearity between the split layers.
    batch_norm : bool
        Whether to use batch normalisation.
    batch_norm_alpha : float
        Momentum for ``BatchNorm2d``.
    drop_prob : float
        Dropout probability.

    Notes
    -----
    ``eval()`` is called on the model during construction, so a freshly
    built instance is in evaluation mode.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T.,
       Fiederer, L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M.,
       Hutter, F. & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        in_chans,
        n_classes,
        input_window_samples=None,
        n_filters_time=40,
        filter_time_length=25,
        n_filters_spat=40,
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length=30,
        conv_nonlin=square,
        pool_mode="mean",
        pool_nonlin=safe_log,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        drop_prob=0.5,
    ):
        super().__init__()
        if final_conv_length == "auto":
            assert input_window_samples is not None

        # Keep every constructor argument available on the instance.
        self.in_chans = in_chans
        self.n_classes = n_classes
        self.input_window_samples = input_window_samples
        self.n_filters_time = n_filters_time
        self.filter_time_length = filter_time_length
        self.n_filters_spat = n_filters_spat
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.final_conv_length = final_conv_length
        self.conv_nonlin = conv_nonlin
        self.pool_mode = pool_mode
        self.pool_nonlin = pool_nonlin
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.drop_prob = drop_prob

        self.add_module("ensuredims", Ensure4d())

        # Resolve the pooling layer type up front; an unknown mode raises
        # KeyError before any convolution is built.
        pooling_layers = {"max": nn.MaxPool2d, "mean": nn.AvgPool2d}
        pool_class = pooling_layers[pool_mode]

        # The convolution directly in front of the batch norm layer is
        # created without a bias whenever batch norm is enabled.
        bias_before_bnorm = not batch_norm

        # --- first block: temporal (and optionally spatial) convolution ---
        if split_first_layer:
            self.add_module("dimshuffle", Expression(transpose_time_to_spat))
            temporal_conv = nn.Conv2d(
                1, n_filters_time, (filter_time_length, 1), stride=1
            )
            self.add_module("conv_time", temporal_conv)
            spatial_conv = nn.Conv2d(
                n_filters_time,
                n_filters_spat,
                (1, in_chans),
                stride=1,
                bias=bias_before_bnorm,
            )
            self.add_module("conv_spat", spatial_conv)
            n_filters_conv = n_filters_spat
        else:
            temporal_conv = nn.Conv2d(
                in_chans,
                n_filters_time,
                (filter_time_length, 1),
                stride=1,
                bias=bias_before_bnorm,
            )
            self.add_module("conv_time", temporal_conv)
            n_filters_conv = n_filters_time

        if batch_norm:
            bnorm = nn.BatchNorm2d(
                n_filters_conv, momentum=batch_norm_alpha, affine=True
            )
            self.add_module("bnorm", bnorm)

        # --- non-linearity, pooling, dropout ---
        self.add_module("conv_nonlin_exp", Expression(conv_nonlin))
        pooling = pool_class(
            kernel_size=(pool_time_length, 1),
            stride=(pool_time_stride, 1),
        )
        self.add_module("pool", pooling)
        self.add_module("pool_nonlin_exp", Expression(pool_nonlin))
        self.add_module("drop", nn.Dropout(p=drop_prob))

        self.eval()
        if self.final_conv_length == "auto":
            # Push an all-ones window through the layers registered so far
            # (the classifier is not added yet) and use the size of output
            # dimension 2 as the classifier kernel length.
            dummy_input = np_to_th(
                np.ones(
                    (1, in_chans, input_window_samples, 1),
                    dtype=np.float32,
                )
            )
            n_out_time = self(dummy_input).cpu().data.numpy().shape[2]
            self.final_conv_length = n_out_time

        # --- classifier head ---
        classifier = nn.Conv2d(
            n_filters_conv,
            n_classes,
            (self.final_conv_length, 1),
            bias=True,
        )
        self.add_module("conv_classifier", classifier)
        self.add_module("softmax", nn.LogSoftmax(dim=1))
        self.add_module("squeeze", Expression(squeeze_final_output))

        # --- initialization (xavier, same as in the paper) ---
        # Convolutions are visited in registration order so the random
        # draws happen in the sequence conv_time, conv_spat, conv_classifier.
        conv_layers = [self.conv_time]
        if split_first_layer:
            conv_layers.append(self.conv_spat)
        conv_layers.append(self.conv_classifier)
        for conv in conv_layers:
            init.xavier_uniform_(conv.weight, gain=1)
            # A convolution built with bias=False has ``bias`` set to None.
            if conv.bias is not None:
                init.constant_(conv.bias, 0)

        if batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)