# Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
# Tonio Ball
#
# License: BSD-3
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.functional import elu
from .functions import transpose_time_to_spat, squeeze_final_output
from .modules import Expression, AvgPool2dWithConv, Ensure4d


class EEGResNet(nn.Sequential):
    """Residual Network for EEG.

    XXX missing reference

    Parameters
    ----------
    in_chans : int
        Number of input EEG channels.
    """

def __init__(self,
in_chans,
n_classes,
input_window_samples,
final_pool_length,
n_first_filters,
n_layers_per_block=2,
first_filter_length=3,
nonlinearity=elu,
split_first_layer=True,
batch_norm_alpha=0.1,
batch_norm_epsilon=1e-4,
conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0)):
super().__init__()
self.in_chans = in_chans
self.n_classes = n_classes
self.input_window_samples = input_window_samples
if final_pool_length == 'auto':
assert input_window_samples is not None
assert first_filter_length % 2 == 1
self.final_pool_length = final_pool_length
self.n_first_filters = n_first_filters
self.n_layers_per_block = n_layers_per_block
self.first_filter_length = first_filter_length
self.nonlinearity = nonlinearity
self.split_first_layer = split_first_layer
self.batch_norm_alpha = batch_norm_alpha
self.batch_norm_epsilon = batch_norm_epsilon
self.conv_weight_init_fn = conv_weight_init_fn
self.add_module("ensuredims", Ensure4d())
        # First block: either a split temporal + spatial convolution
        # (conv_time over time, then conv_spat across EEG channels) or a
        # single temporal convolution over all channels at once.
        if self.split_first_layer:
self.add_module('dimshuffle', Expression(transpose_time_to_spat))
self.add_module('conv_time', nn.Conv2d(1, self.n_first_filters,
(self.first_filter_length, 1),
stride=1,
padding=(self.first_filter_length // 2, 0)))
self.add_module('conv_spat',
nn.Conv2d(self.n_first_filters, self.n_first_filters,
(1, self.in_chans),
stride=(1, 1),
bias=False))
else:
self.add_module('conv_time',
nn.Conv2d(self.in_chans, self.n_first_filters,
(self.first_filter_length, 1),
stride=(1, 1),
padding=(self.first_filter_length // 2, 0),
bias=False,))
n_filters_conv = self.n_first_filters
self.add_module('bnorm',
nn.BatchNorm2d(n_filters_conv,
momentum=self.batch_norm_alpha,
affine=True,
eps=1e-5),)
self.add_module('conv_nonlin', Expression(self.nonlinearity))
        # Stack residual blocks; the temporal dilation doubles after every
        # block so the receptive field grows exponentially with depth.
        cur_dilation = np.array([1, 1])
n_cur_filters = n_filters_conv
i_block = 1
for i_layer in range(self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
n_out_filters = int(2 * n_cur_filters)
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_out_filters,
dilation=cur_dilation,))
n_cur_filters = n_out_filters
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
n_out_filters = int(1.5 * n_cur_filters)
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_out_filters,
dilation=cur_dilation,))
n_cur_filters = n_out_filters
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation,))
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation,))
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation,))
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
i_block += 1
cur_dilation[0] *= 2
self.add_module('res_{:d}_{:d}'.format(i_block, 0),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation,))
for i_layer in range(1, self.n_layers_per_block):
self.add_module('res_{:d}_{:d}'.format(i_block, i_layer),
_ResidualBlock(n_cur_filters, n_cur_filters,
dilation=cur_dilation))
        # Global average pooling over the remaining time dimension, followed
        # by a 1x1 convolutional classifier and log-softmax.
        if self.final_pool_length == 'auto':
self.add_module('mean_pool', nn.AdaptiveAvgPool2d((1, 1)))
else:
pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
self.add_module('mean_pool', AvgPool2dWithConv(
(self.final_pool_length, 1), (1, 1),
dilation=pool_dilation))
self.add_module('conv_classifier',
nn.Conv2d(n_cur_filters, self.n_classes,
(1, 1), bias=True))
self.add_module('softmax', nn.LogSoftmax(dim=1))
self.add_module('squeeze', Expression(squeeze_final_output))
# Initialize all weights
self.apply(lambda module: _weights_init(module, self.conv_weight_init_fn))
# Start in eval mode
self.eval()


def _weights_init(module, conv_weight_init_fn):
    """Initialize weights of convolutional and batch normalization layers."""
classname = module.__class__.__name__
if 'Conv' in classname and classname != "AvgPool2dWithConv":
conv_weight_init_fn(module.weight)
if module.bias is not None:
init.constant_(module.bias, 0)
elif 'BatchNorm' in classname:
init.constant_(module.weight, 1)
init.constant_(module.bias, 0)


class _ResidualBlock(nn.Module):
    """Residual learning building block with two stacked convolutional layers
    of kernel size (``filter_time_length``, 1), following the original
    residual-network design.
    """
def __init__(self, in_filters,
out_num_filters,
dilation,
filter_time_length=3,
nonlinearity=elu,
batch_norm_alpha=0.1, batch_norm_epsilon=1e-4):
        super().__init__()
time_padding = int((filter_time_length - 1) * dilation[0])
assert time_padding % 2 == 0
time_padding = int(time_padding // 2)
dilation = (int(dilation[0]), int(dilation[1]))
assert (out_num_filters - in_filters) % 2 == 0, (
"Need even number of extra channels in order to be able to "
"pad correctly")
self.n_pad_chans = out_num_filters - in_filters
self.conv_1 = nn.Conv2d(
in_filters, out_num_filters, (filter_time_length, 1), stride=(1, 1),
dilation=dilation,
padding=(time_padding, 0))
self.bn1 = nn.BatchNorm2d(
out_num_filters, momentum=batch_norm_alpha, affine=True,
eps=batch_norm_epsilon)
self.conv_2 = nn.Conv2d(
out_num_filters, out_num_filters, (filter_time_length, 1), stride=(1, 1),
dilation=dilation,
padding=(time_padding, 0))
self.bn2 = nn.BatchNorm2d(
out_num_filters, momentum=batch_norm_alpha,
affine=True, eps=batch_norm_epsilon)
# also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
# for resnet options as ilya used them
self.nonlinearity = nonlinearity

    def forward(self, x):
stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
stack_2 = self.bn2(self.conv_2(stack_1)) # next nonlin after sum
        if self.n_pad_chans != 0:
            # Pad the identity path with zero channels (half before, half
            # after) so it matches the channel count of the conv stack.
            zeros_for_padding = torch.zeros(
                x.size(0), self.n_pad_chans // 2, x.size(2), x.size(3),
                dtype=x.dtype, device=x.device)
            x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
out = self.nonlinearity(x + stack_2)
return out
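

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# shows how the model and the residual block might be exercised on dummy data.
# The channel count, class count, window length and filter count below are
# placeholder assumptions, not values prescribed by the authors.
if __name__ == "__main__":
    model = EEGResNet(
        in_chans=22,                 # e.g. 22 EEG channels (assumption)
        n_classes=4,                 # e.g. 4 target classes (assumption)
        input_window_samples=1000,   # samples per trial (assumption)
        final_pool_length='auto',
        n_first_filters=20,
    )
    # Dummy batch of shape (batch, channels, time); Ensure4d appends the
    # trailing singleton dimension expected by the 2d convolutions.
    x = torch.zeros(8, 22, 1000)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([8, 4]) log-probabilities

    # A single residual block doubles the channel count here while keeping
    # the time and (collapsed) spatial dimensions unchanged; the identity
    # path is zero-padded to match.
    block = _ResidualBlock(in_filters=20, out_num_filters=40, dilation=(1, 1))
    with torch.no_grad():
        h = block(torch.zeros(8, 20, 1000, 1))
    print(h.shape)  # expected: torch.Size([8, 40, 1000, 1])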