# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
from torch import nn
from torch.nn import init
from torch.nn.functional import elu

from .modules import Expression, AvgPool2dWithConv, Ensure4d
from .functions import identity, transpose_time_to_spat, squeeze_final_output
from ..util import np_to_th


class Deep4Net(nn.Sequential):
    """Deep ConvNet model from Schirrmeister et al 2017.

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    in_chans : int
        Number of EEG input channels.
    n_classes : int
        Number of classes to predict (number of output filters of last layer).
    input_window_samples : int | None
        Only used to determine the length of the last convolutional kernel if
        final_conv_length is "auto".
    final_conv_length : int | str
        Length of the final convolution layer. If set to "auto",
        input_window_samples must not be None.
    n_filters_time : int
        Number of temporal filters.
    n_filters_spat : int
        Number of spatial filters.
    filter_time_length : int
        Length of the temporal filter in layer 1.
    pool_time_length : int
        Length of temporal pooling filter.
    pool_time_stride : int
        Length of stride between temporal pooling filters.
    n_filters_2 : int
        Number of temporal filters in layer 2.
    filter_length_2 : int
        Length of the temporal filter in layer 2.
    n_filters_3 : int
        Number of temporal filters in layer 3.
    filter_length_3 : int
        Length of the temporal filter in layer 3.
    n_filters_4 : int
        Number of temporal filters in layer 4.
    filter_length_4 : int
        Length of the temporal filter in layer 4.
    first_conv_nonlin : callable
        Non-linear activation function to be used after convolution in layer 1.
    first_pool_mode : str
        Pooling mode in layer 1, "max" or "mean".
    first_pool_nonlin : callable
        Non-linear activation function to be used after pooling in layer 1.
    later_conv_nonlin : callable
        Non-linear activation function to be used after convolution in later
        layers.
    later_pool_mode : str
        Pooling mode in later layers, "max" or "mean".
    later_pool_nonlin : callable
        Non-linear activation function to be used after pooling in later
        layers.
    drop_prob : float
        Dropout probability.
    split_first_layer : bool
        Split the first layer into a temporal and a spatial convolution (True)
        or use a single temporal convolution (False). There is no
        non-linearity between the split layers.
    batch_norm : bool
        Whether to use batch normalisation.
    batch_norm_alpha : float
        Momentum for BatchNorm2d.
    stride_before_pool : bool
        Apply the temporal stride in the convolution before pooling (True)
        instead of in the pooling layer (False).

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017). Deep learning with convolutional neural networks for
       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        in_chans,
        n_classes,
        input_window_samples,
        final_conv_length,
        n_filters_time=25,
        n_filters_spat=25,
        filter_time_length=10,
        pool_time_length=3,
        pool_time_stride=3,
        n_filters_2=50,
        filter_length_2=10,
        n_filters_3=100,
        filter_length_3=10,
        n_filters_4=200,
        filter_length_4=10,
        first_conv_nonlin=elu,
        first_pool_mode="max",
        first_pool_nonlin=identity,
        later_conv_nonlin=elu,
        later_pool_mode="max",
        later_pool_nonlin=identity,
        drop_prob=0.5,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
        stride_before_pool=False,
    ):
        super().__init__()
        if final_conv_length == "auto":
            assert input_window_samples is not None
        self.in_chans = in_chans
        self.n_classes = n_classes
        self.input_window_samples = input_window_samples
        self.final_conv_length = final_conv_length
        self.n_filters_time = n_filters_time
        self.n_filters_spat = n_filters_spat
        self.filter_time_length = filter_time_length
        self.pool_time_length = pool_time_length
        self.pool_time_stride = pool_time_stride
        self.n_filters_2 = n_filters_2
        self.filter_length_2 = filter_length_2
        self.n_filters_3 = n_filters_3
        self.filter_length_3 = filter_length_3
        self.n_filters_4 = n_filters_4
        self.filter_length_4 = filter_length_4
        self.first_nonlin = first_conv_nonlin
        self.first_pool_mode = first_pool_mode
        self.first_pool_nonlin = first_pool_nonlin
        self.later_conv_nonlin = later_conv_nonlin
        self.later_pool_mode = later_pool_mode
        self.later_pool_nonlin = later_pool_nonlin
        self.drop_prob = drop_prob
        self.split_first_layer = split_first_layer
        self.batch_norm = batch_norm
        self.batch_norm_alpha = batch_norm_alpha
        self.stride_before_pool = stride_before_pool

        if self.stride_before_pool:
            conv_stride = self.pool_time_stride
            pool_stride = 1
        else:
            conv_stride = 1
            pool_stride = self.pool_time_stride
        self.add_module("ensuredims", Ensure4d())
        pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
        first_pool_class = pool_class_dict[self.first_pool_mode]
        later_pool_class = pool_class_dict[self.later_pool_mode]
        if self.split_first_layer:
            self.add_module("dimshuffle", Expression(transpose_time_to_spat))
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    1,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=1,
                ),
            )
            self.add_module(
                "conv_spat",
                nn.Conv2d(
                    self.n_filters_time,
                    self.n_filters_spat,
                    (1, self.in_chans),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_spat
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.in_chans,
                    self.n_filters_time,
                    (self.filter_time_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            n_filters_conv = self.n_filters_time
        if self.batch_norm:
            self.add_module(
                "bnorm",
                nn.BatchNorm2d(
                    n_filters_conv,
                    momentum=self.batch_norm_alpha,
                    affine=True,
                    eps=1e-5,
                ),
            )
        self.add_module("conv_nonlin", Expression(self.first_nonlin))
        self.add_module(
            "pool",
            first_pool_class(
                kernel_size=(self.pool_time_length, 1), stride=(pool_stride, 1)
            ),
        )
        self.add_module("pool_nonlin", Expression(self.first_pool_nonlin))

        def add_conv_pool_block(
            model, n_filters_before, n_filters, filter_length, block_nr
        ):
            suffix = "_{:d}".format(block_nr)
            model.add_module("drop" + suffix, nn.Dropout(p=self.drop_prob))
            model.add_module(
                "conv" + suffix,
                nn.Conv2d(
                    n_filters_before,
                    n_filters,
                    (filter_length, 1),
                    stride=(conv_stride, 1),
                    bias=not self.batch_norm,
                ),
            )
            if self.batch_norm:
                model.add_module(
                    "bnorm" + suffix,
                    nn.BatchNorm2d(
                        n_filters,
                        momentum=self.batch_norm_alpha,
                        affine=True,
                        eps=1e-5,
                    ),
                )
            model.add_module(
                "nonlin" + suffix, Expression(self.later_conv_nonlin)
            )
            model.add_module(
                "pool" + suffix,
                later_pool_class(
                    kernel_size=(self.pool_time_length, 1),
                    stride=(pool_stride, 1),
                ),
            )
            model.add_module(
                "pool_nonlin" + suffix, Expression(self.later_pool_nonlin)
            )

        add_conv_pool_block(
            self, n_filters_conv, self.n_filters_2, self.filter_length_2, 2
        )
        add_conv_pool_block(
            self, self.n_filters_2, self.n_filters_3, self.filter_length_3, 3
        )
        add_conv_pool_block(
            self, self.n_filters_3, self.n_filters_4, self.filter_length_4, 4
        )

        # self.add_module('drop_classifier', nn.Dropout(p=self.drop_prob))
        self.eval()
        if self.final_conv_length == "auto":
            out = self(
                np_to_th(
                    np.ones(
                        (1, self.in_chans, self.input_window_samples, 1),
                        dtype=np.float32,
                    )
                )
            )
            n_out_time = out.cpu().data.numpy().shape[2]
            self.final_conv_length = n_out_time
        self.add_module(
            "conv_classifier",
            nn.Conv2d(
                self.n_filters_4,
                self.n_classes,
                (self.final_conv_length, 1),
                bias=True,
            ),
        )
        self.add_module("softmax", nn.LogSoftmax(dim=1))
        self.add_module("squeeze", Expression(squeeze_final_output))

        # Initialization, xavier is same as in our paper...
        # was default from lasagne
        init.xavier_uniform_(self.conv_time.weight, gain=1)
        # maybe no bias in case of no split layer and batch norm
        if self.split_first_layer or (not self.batch_norm):
            init.constant_(self.conv_time.bias, 0)
        if self.split_first_layer:
            init.xavier_uniform_(self.conv_spat.weight, gain=1)
            if not self.batch_norm:
                init.constant_(self.conv_spat.bias, 0)
        if self.batch_norm:
            init.constant_(self.bnorm.weight, 1)
            init.constant_(self.bnorm.bias, 0)
        param_dict = dict(list(self.named_parameters()))
        for block_nr in range(2, 5):
            conv_weight = param_dict["conv_{:d}.weight".format(block_nr)]
            init.xavier_uniform_(conv_weight, gain=1)
            if not self.batch_norm:
                conv_bias = param_dict["conv_{:d}.bias".format(block_nr)]
                init.constant_(conv_bias, 0)
            else:
                bnorm_weight = param_dict["bnorm_{:d}.weight".format(block_nr)]
                bnorm_bias = param_dict["bnorm_{:d}.bias".format(block_nr)]
                init.constant_(bnorm_weight, 1)
                init.constant_(bnorm_bias, 0)

        init.xavier_uniform_(self.conv_classifier.weight, gain=1)
        init.constant_(self.conv_classifier.bias, 0)

        # Start in eval mode
        self.eval()
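

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). It builds a Deep4Net for a hypothetical 22-channel, 4-class
# setup with 1000-sample windows; these numbers are assumptions chosen for
# demonstration only. With final_conv_length="auto", __init__ runs a dummy
# forward pass to size the classifier kernel so that one prediction is
# produced per window.
if __name__ == "__main__":
    import torch

    model = Deep4Net(
        in_chans=22,  # assumed channel count, for illustration
        n_classes=4,  # assumed number of classes
        input_window_samples=1000,  # assumed window length in samples
        final_conv_length="auto",
    )
    # "auto" has been resolved to the temporal length remaining after the
    # four conv-pool blocks:
    print(model.final_conv_length)

    # Ensure4d promotes a (batch, channels, time) batch to 4D; the model is
    # constructed in eval mode, so dropout is inactive here.
    X = torch.randn(3, 22, 1000)
    with torch.no_grad():
        log_probs = model(X)
    print(log_probs.shape)  # (3, 4): one log-probability per class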