# Source code for braindecode.models.tidnet

from math import ceil

import torch
from torch import nn
from torch.nn import init
from torch.nn.utils import weight_norm

from .modules import Ensure4d, Expression


class _BatchNormZG(nn.BatchNorm2d):
    """BatchNorm2d whose affine scale (gamma) and shift (beta) start at zero.

    ``_ConvBlock2D`` uses this layer as the last step of its residual branch.
    A zero scale and shift make the branch output zeros at initialisation,
    so ``input + branch`` initially equals ``input`` (identity mapping).
    """

    def reset_parameters(self):
        """Reset running statistics and zero-initialise the affine parameters."""
        # reset_running_stats() zeroes running_mean, fills running_var with 1 and
        # also zeroes num_batches_tracked; it does nothing when
        # track_running_stats is False.
        self.reset_running_stats()
        if self.affine:
            init.zeros_(self.weight)
            init.zeros_(self.bias)


class _ConvBlock2D(nn.Module):
    """Convolution block applying, in order: convolution, dropout, activation, batch-norm.

    Parameters
    ----------
    in_filters : int
        Number of input feature maps.
    out_filters : int
        Number of output feature maps.
    kernel : int | tuple
        Convolution kernel size (also stored as ``self.kernel``).
    stride, padding, dilation, groups :
        Forwarded unchanged to ``nn.Conv2d``.
    drop_prob : float
        Probability used by ``nn.Dropout2d``.
    batch_norm : bool
        Whether to normalise the output with ``nn.BatchNorm2d``. The
        convolution only gets a bias term when this is False.
    activation : callable
        Zero-argument factory returning the activation module.
    residual : bool
        If True, the block input is added to its output. In this mode a
        zero-initialised ``_BatchNormZG`` is always used (even when
        ``batch_norm`` is False), so the block starts out as the identity.
        NOTE(review): the addition assumes the convolution preserves the input
        shape (``in_filters == out_filters`` and "same" spatial size).
    """

    def __init__(self, in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1,
                 groups=1, drop_prob=0.5, batch_norm=True, activation=nn.LeakyReLU, residual=False):
        super().__init__()
        self.kernel = kernel
        self.activation = activation()
        self.residual = residual

        self.conv = nn.Conv2d(in_filters, out_filters, kernel, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=not batch_norm)
        self.dropout = nn.Dropout2d(p=drop_prob)
        # Same precedence as before: residual mode wins over the batch_norm flag.
        # nn.Identity (rather than a lambda) keeps the block a proper, picklable
        # nn.Module without adding anything to the state_dict.
        if residual:
            self.batch_norm = _BatchNormZG(out_filters)
        elif batch_norm:
            self.batch_norm = nn.BatchNorm2d(out_filters)
        else:
            self.batch_norm = nn.Identity()

    def forward(self, input):
        res = input
        out = self.conv(input)
        out = self.dropout(out)
        out = self.activation(out)
        out = self.batch_norm(out)
        return out + res if self.residual else out


class _DenseFilter(nn.Module):
    """DenseNet-style bottleneck layer that appends ``growth_rate`` feature maps.

    The input goes through BN -> activation -> 1x1 conv (to
    ``bottleneck * growth_rate`` maps) -> BN -> activation -> 1-D conv of
    length ``filter_len`` along ``dim`` -> Dropout2d. The result is then
    concatenated to the input along axis 1.

    Parameters
    ----------
    in_features : int
        Number of input feature maps (axis 1).
    growth_rate : int
        Number of feature maps appended by this layer.
    filter_len : int
        Length of the 1-D filter. Must be a positive odd integer, because the
        ``filter_len // 2`` padding only preserves the input length for odd
        lengths.
    drop_prob : float
        Probability used by ``nn.Dropout2d``.
    bottleneck : int
        Width multiplier of the intermediate 1x1 convolution.
    activation : callable
        Zero-argument factory returning the activation module.
    dim : int
        Axis of a 4-D input along which the filter slides; must resolve to 2
        or 3 (``-2`` or ``-1``).

    Raises
    ------
    ValueError
        If ``dim`` is not one of the last two axes, or ``filter_len`` is not a
        positive odd integer. With an even length the conv output would be one
        sample longer than the input, and the concatenation in ``forward``
        would fail.
    """

    def __init__(self, in_features, growth_rate, filter_len=5, drop_prob=0.5, bottleneck=2,
                 activation=nn.LeakyReLU, dim=-2):
        super().__init__()
        dim = dim if dim > 0 else dim + 4
        if dim < 2 or dim > 3:
            raise ValueError('Only last two dimensions supported')
        if filter_len < 1 or filter_len % 2 == 0:
            raise ValueError(f'filter_len must be a positive odd integer, got {filter_len}')
        kernel = (filter_len, 1) if dim == 2 else (1, filter_len)
        # "Same" padding; exact because every kernel length here is odd.
        padding = tuple(k // 2 for k in kernel)
        hidden = bottleneck * growth_rate

        self.net = nn.Sequential(
            nn.BatchNorm2d(in_features),
            activation(),
            nn.Conv2d(in_features, hidden, 1),
            nn.BatchNorm2d(hidden),
            activation(),
            nn.Conv2d(hidden, growth_rate, kernel, padding=padding),
            nn.Dropout2d(drop_prob),
        )

    def forward(self, x):
        return torch.cat((x, self.net(x)), dim=1)


class _DenseSpatialFilter(nn.Module):
    """Stack of ``_DenseFilter`` layers, optionally followed by a channel collapse.

    Each of the ``depth`` dense layers adds ``growth`` feature maps, so the
    stack outputs ``in_ch + growth * depth`` maps. When ``collapse`` is True, a
    ``_ConvBlock2D`` with kernel ``(in_chans, 1)`` runs over axis -2, and that
    axis is then squeezed away.
    """

    def __init__(self, in_chans, growth, depth, in_ch=1, bottleneck=4, drop_prob=0.0,
                 activation=nn.LeakyReLU, collapse=True):
        super().__init__()
        layers = []
        for level in range(depth):
            # Layer ``level`` sees the original maps plus everything added so far.
            layers.append(_DenseFilter(in_ch + growth * level, growth, bottleneck=bottleneck,
                                       drop_prob=drop_prob, activation=activation))
        self.net = nn.Sequential(*layers)
        total_maps = in_ch + growth * depth
        self.collapse = collapse
        if collapse:
            self.channel_collapse = _ConvBlock2D(total_maps, total_maps, (in_chans, 1),
                                                 drop_prob=0)

    def forward(self, x):
        # A 3-D input gets a singleton axis at position 1 and its last two
        # axes swapped, giving the 4-D layout the dense layers expect.
        if x.dim() < 4:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
        features = self.net(x)
        if not self.collapse:
            return features
        return self.channel_collapse(features).squeeze(-2)


class _TemporalFilter(nn.Module):
    """Stack of dilated, weight-normalised temporal convolutions.

    Layer ``i`` (0-based) applies a ``(1, temp_len)`` convolution with
    dilation ``depth - i``, followed by the activation and ``nn.Dropout2d``.
    The padding keeps the length of the last axis unchanged.

    Parameters
    ----------
    in_chans : int
        Number of input feature maps (axis 1 of the input).
    filters : int
        Number of filters produced by each layer.
    depth : int
        Number of convolutional layers.
    temp_len : int
        Temporal kernel length. Even values are rounded up to the next odd
        value so that the "same" padding is exact.
    drop_prob : float
        Probability used by ``nn.Dropout2d``.
    activation : callable
        Zero-argument factory returning the activation module.
    residual : str
        ``'netwise'`` (case-insensitive): the layers are chained, and the
        output is ``net(x) + conv1x1(x)`` with ``filters`` maps.
        ``'dense'``: each layer's output is concatenated to its input, giving
        ``in_chans + depth * filters`` output maps.

    Raises
    ------
    ValueError
        If ``residual`` is neither ``'netwise'`` nor ``'dense'``.
    """

    def __init__(self, in_chans, filters, depth, temp_len, drop_prob=0., activation=nn.LeakyReLU,
                 residual='netwise'):
        super().__init__()
        # Force an odd kernel so ``dil * (temp_len - 1) // 2`` is exact "same" padding.
        temp_len = temp_len + 1 - temp_len % 2
        self.residual_style = str(residual)
        style = self.residual_style.lower()
        if style not in ('netwise', 'dense'):
            raise ValueError(f"residual must be 'netwise' or 'dense', got {residual!r}")
        dense = style == 'dense'

        net = list()
        for i in range(depth):
            dil = depth - i
            if dense:
                # Each dense layer sees the input plus all previously appended maps.
                layer_in = in_chans + i * filters
            else:
                layer_in = in_chans if i == 0 else filters
            conv = weight_norm(nn.Conv2d(layer_in, filters,
                                         kernel_size=(1, temp_len), dilation=dil,
                                         padding=(0, dil * (temp_len - 1) // 2)))
            net.append(nn.Sequential(
                conv,
                activation(),
                nn.Dropout2d(drop_prob)
            ))

        if dense:
            # ModuleList registers the layers, so their parameters are trained,
            # moved by .to() and saved in the state_dict.
            self.net = nn.ModuleList(net)
        else:
            self.net = nn.Sequential(*net)
            self.residual = nn.Conv2d(in_chans, filters, (1, 1))

    def forward(self, x):
        if self.residual_style.lower() == 'dense':
            for layer in self.net:
                x = torch.cat((x, layer(x)), dim=1)
            return x
        return self.net(x) + self.residual(x)


def _permute_to_channels_first(x):
    """Permute data from ``(batch, chans, time, 1)`` to ``(batch, 1, chans, time)``.

    Defined at module level rather than nested inside
    ``_TIDNetFeatures.__init__``: a nested function cannot be pickled by
    reference, which would prevent serialising the model.
    """
    return x.permute([0, 3, 1, 2])


class _TIDNetFeatures(nn.Module):
    """TIDNet feature extractor: temporal filtering, dense spatial filtering, summary pooling.

    Pipeline:

    1. ``Ensure4d`` followed by a permutation to ``(batch, 1, chans, time)``.
       NOTE(review): this assumes ``Ensure4d`` yields ``(batch, chans, time, 1)``
       for ``(batch, chans, time)`` input; confirm against its definition.
    2. ``_TemporalFilter`` (``t_filters`` maps), max pooling of width and
       stride ``pooling`` over time, then ``nn.Dropout2d``.
    3. ``_DenseSpatialFilter``, which grows the maps to
       ``t_filters + s_growth * spat_layers`` and collapses the channel axis.
    4. ``AdaptiveAvgPool1d(summary)`` followed by flattening.

    The flattened feature size is exposed as ``num_features``.

    Raises
    ------
    ValueError
        If the resolved ``summary`` is smaller than 1, e.g. when
        ``input_window_samples < pooling`` with ``summary=-1``.
    """

    def __init__(self, s_growth, t_filters, in_chans, input_window_samples, drop_prob, pooling,
                 temp_layers, spat_layers, temp_span, bottleneck, summary):
        super().__init__()
        self.in_chans = in_chans
        self.input_window_samples = input_window_samples
        # Historical misspelling, kept as an alias for backward compatibility.
        self.input_windows_samples = input_window_samples
        self.temp_len = ceil(temp_span * input_window_samples)

        self.temporal = nn.Sequential(
            Ensure4d(),
            Expression(_permute_to_channels_first),
            _TemporalFilter(1, t_filters, depth=temp_layers, temp_len=self.temp_len),
            nn.MaxPool2d((1, pooling)),
            nn.Dropout2d(drop_prob),
        )
        # Resolve to an int once, so the pooling size and num_features agree.
        summary = int(input_window_samples // pooling if summary == -1 else summary)
        if summary < 1:
            raise ValueError(
                f'summary must resolve to a positive integer, got {summary} '
                f'(input_window_samples={input_window_samples}, pooling={pooling})')

        self.spatial = _DenseSpatialFilter(in_chans, s_growth, spat_layers, in_ch=t_filters,
                                           drop_prob=drop_prob, bottleneck=bottleneck)
        self.extract_features = nn.Sequential(
            nn.AdaptiveAvgPool1d(summary),
            nn.Flatten(start_dim=1)
        )

        self._num_features = (t_filters + s_growth * spat_layers) * summary

    @property
    def num_features(self):
        """Length of the flattened feature vector returned by ``forward``."""
        return self._num_features

    def forward(self, x):
        x = self.temporal(x)
        x = self.spatial(x)
        return self.extract_features(x)


class TIDNet(nn.Module):
    """Thinker Invariance DenseNet model from Kostas et al 2020.

    See [TIDNet]_ for details.

    Parameters
    ----------
    in_chans : int
        Number of EEG channels.
    n_classes : int
        Number of classes.
    input_window_samples : int
        Number of samples.
    s_growth : int
        DenseNet-style growth factor (added filters per DenseFilter).
    t_filters : int
        Number of temporal filters.
    drop_prob : float
        Dropout probability.
    pooling : int
        Max temporal pooling (width and stride).
    temp_layers : int
        Number of temporal layers.
    spat_layers : int
        Number of DenseFilters.
    temp_span : float
        Fraction of input_window_samples that defines the temporal filter length:
        temp_len = ceil(temp_span * input_window_samples).
        E.g. a value of 0.05 for temp_span with 1500 input_window_samples yields a
        temporal filter of length 75. Even lengths are rounded up to the next odd
        value by the temporal filter.
    bottleneck : int
        Bottleneck factor within DenseFilter.
    summary : int
        Output size of the AdaptiveAvgPool1d layer. If set to -1, the value is
        calculated automatically (input_window_samples // pooling).

    Notes
    -----
    Code adapted from: https://github.com/SPOClab-ca/ThinkerInvariance/

    References
    ----------
    .. [TIDNet] Kostas, D. & Rudzicz, F.
        Thinker invariance: enabling deep neural networks for BCI across more
        people.
        J. Neural Eng. 17, 056008 (2020).
        doi: 10.1088/1741-2552/abb7a7.
    """

    def __init__(self, in_chans, n_classes, input_window_samples, s_growth=24, t_filters=32,
                 drop_prob=0.4, pooling=15, temp_layers=2, spat_layers=2, temp_span=0.05,
                 bottleneck=3, summary=-1):
        super().__init__()
        self.n_classes = n_classes
        self.in_chans = in_chans
        self.input_window_samples = input_window_samples
        # Stored for introspection; the feature extractor computes the same value itself.
        self.temp_len = ceil(temp_span * input_window_samples)

        self.dscnn = _TIDNetFeatures(s_growth=s_growth, t_filters=t_filters, in_chans=in_chans,
                                     input_window_samples=input_window_samples,
                                     drop_prob=drop_prob, pooling=pooling,
                                     temp_layers=temp_layers, spat_layers=spat_layers,
                                     temp_span=temp_span, bottleneck=bottleneck,
                                     summary=summary)

        self._num_features = self.dscnn.num_features

        self.classify = self._create_classifier(self.num_features, n_classes)

    def _create_classifier(self, incoming, n_classes):
        """Build the output head: flatten -> linear -> log-softmax.

        The linear layer gets Xavier-normal weights and a zero bias.
        """
        classifier = nn.Linear(incoming, n_classes)
        init.xavier_normal_(classifier.weight)
        init.zeros_(classifier.bias)
        return nn.Sequential(nn.Flatten(start_dim=1), classifier, nn.LogSoftmax(dim=-1))

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Log-probabilities of shape (batch_size, n_classes).
        """
        x = self.dscnn(x)
        return self.classify(x)

    @property
    def num_features(self):
        """Size of the flattened feature vector fed to the classifier."""
        return self._num_features