# Source code for braindecode.models.hybrid

# Authors: Robin Schirrmeister <robintibor@gmail.com>
#
# License: BSD (3-clause)

import torch
from torch import nn
from torch.nn import ConstantPad2d

from .deep4 import Deep4Net
from .util import to_dense_prediction_model
from .shallow_fbcsp import ShallowFBCSPNet


class HybridNet(nn.Module):
    """Hybrid ConvNet model from Schirrmeister et al 2017.

    See [Schirrmeister2017]_ for details.

    The model builds a :class:`Deep4Net` and a :class:`ShallowFBCSPNet`,
    replaces the ``conv_classifier`` layer of each with a convolution that
    has 60 (deep) resp. 40 (shallow) output filters, converts both branches
    with ``to_dense_prediction_model``, concatenates their outputs along the
    channel dimension and maps them to ``n_classes`` log-probabilities with
    a final 1x1 convolution.

    Parameters
    ----------
    in_chans : int
        Number of input channels, passed to both sub-networks.
    n_classes : int
        Number of output classes.
    input_window_samples : int
        Number of time samples per input window, passed to both
        sub-networks.

    Raises
    ------
    ValueError
        If one of the sub-networks has no child named ``conv_classifier``.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding and
       visualization.
       Human Brain Mapping , Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(self, in_chans, n_classes, input_window_samples):
        super().__init__()
        # Output filters of the convolutions that replace each branch's
        # classifier; their sum is the input width of ``final_conv``, so
        # they are defined once here instead of repeated as literals.
        n_deep_filters = 60
        n_shallow_filters = 40

        deep_model = Deep4Net(
            in_chans,
            n_classes,
            n_filters_time=20,
            n_filters_spat=30,
            n_filters_2=40,
            n_filters_3=50,
            n_filters_4=60,
            input_window_samples=input_window_samples,
            final_conv_length=2,
        )
        shallow_model = ShallowFBCSPNet(
            in_chans,
            n_classes,
            input_window_samples=input_window_samples,
            n_filters_time=30,
            n_filters_spat=40,
            filter_time_length=28,
            final_conv_length=29,
        )
        # Deep branch is reduced before the shallow one, matching the
        # original layer-creation (and hence weight-initialization) order.
        reduced_deep_model = self._replace_classifier(
            deep_model, n_deep_filters, "deep_final_conv"
        )
        reduced_shallow_model = self._replace_classifier(
            shallow_model, n_shallow_filters, "shallow_final_conv"
        )
        to_dense_prediction_model(reduced_deep_model)
        to_dense_prediction_model(reduced_shallow_model)
        self.reduced_deep_model = reduced_deep_model
        self.reduced_shallow_model = reduced_shallow_model
        self.final_conv = nn.Conv2d(
            n_deep_filters + n_shallow_filters,
            n_classes,
            kernel_size=(1, 1),
            stride=1,
        )

    @staticmethod
    def _replace_classifier(model, n_filters, new_name):
        """Return ``model``'s children with ``conv_classifier`` swapped out.

        Children preceding ``conv_classifier`` are re-registered, in order and
        under their original names, in a new ``nn.Sequential`` (the module
        objects themselves are shared, not copied). ``conv_classifier`` is
        replaced by a fresh ``nn.Conv2d`` with the same input channels,
        kernel size and stride but ``n_filters`` output channels, registered
        as ``new_name``. Children after ``conv_classifier`` are dropped.

        Raises
        ------
        ValueError
            If ``model`` has no child named ``conv_classifier``.
        """
        reduced = nn.Sequential()
        for name, module in model.named_children():
            if name == "conv_classifier":
                new_conv_layer = nn.Conv2d(
                    module.in_channels,
                    n_filters,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                )
                reduced.add_module(new_name, new_conv_layer)
                return reduced
            reduced.add_module(name, module)
        # Without this, the whole model (incl. its own output layers) would
        # be kept and final_conv would fail later with a channel mismatch.
        raise ValueError(
            f"{type(model).__name__} has no 'conv_classifier' child to replace."
        )

    @staticmethod
    def _pad_start(out, n_out):
        """Zero-pad ``out`` at the start of dim 2 up to length ``n_out``."""
        n_missing = n_out - out.size(2)
        if n_missing == 0:
            return out
        # ConstantPad2d takes (left, right, top, bottom); "top" pads the
        # start of dim -2, i.e. dim 2 of the 4d branch outputs.
        return ConstantPad2d((0, 0, n_missing, 0), 0)(out)

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Log-softmax over classes (dim 1) of ``final_conv``'s output, with
            dim 3 squeezed away.
        """
        deep_out = self.reduced_deep_model(x)
        shallow_out = self.reduced_shallow_model(x)
        # The branches may emit a different number of outputs along dim 2;
        # zero-pad the shorter one at the start so both are aligned at their
        # last position before concatenating along channels.
        n_out = max(deep_out.size(2), shallow_out.size(2))
        deep_out = self._pad_start(deep_out, n_out)
        shallow_out = self._pad_start(shallow_out, n_out)
        merged_out = torch.cat((deep_out, shallow_out), dim=1)
        linear_out = self.final_conv(merged_out)
        log_probs = nn.functional.log_softmax(linear_out, dim=1)
        return log_probs.squeeze(3)