# Source code for braindecode.models.eegitnet

# Authors: Ghaith Bouallegue <ghaithbouallegue@gmail.com>
#
# License: BSD-3
import torch
from torch import nn

from .modules import Ensure4d, Expression


def _permute(x):
    """Permute data.

    Input dimensions: (batch, channels, time, 1)
    Output dimensions: (batch, 1, channels, time)
    """
    return x.permute([0, 3, 1, 2])


class _DepthwiseConv2d(torch.nn.Conv2d):
    def __init__(
            self,
            in_channels,
            depth_multiplier=2,
            kernel_size=3,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
            padding_mode="zeros",
    ):
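        # With groups=in_channels, each input channel is convolved with its
        # own set of ``depth_multiplier`` filters (depthwise convolution), so
        # the layer outputs in_channels * depth_multiplier feature maps.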
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )


class _InceptionBlock(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
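        # Apply every branch to the same input and concatenate the resulting
        # feature maps along the channel dimension (dim=1).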
        return torch.cat([branch(x) for branch in self.branches], 1)


class _TCBlock(nn.Module):
    def __init__(self, in_ch, kernel_length, dilation, padding, drop_prob=0.4):
        super().__init__()
        self.pad = padding
        self.tc1 = nn.Sequential(
            _DepthwiseConv2d(
                in_ch,
                kernel_size=(1, kernel_length),
                depth_multiplier=1,
                dilation=(1, dilation),
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(in_ch),
            nn.ELU(),
            nn.Dropout(drop_prob),
        )

        self.tc2 = nn.Sequential(
            _DepthwiseConv2d(
                in_ch,
                kernel_size=(1, kernel_length),
                depth_multiplier=1,
                dilation=(1, dilation),
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(in_ch),
            nn.ELU(),
            nn.Dropout(drop_prob),
        )

    def forward(self, x):
        residual = x
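        # torch.nn.functional.pad expects (left, right) pairs starting from
        # the last dimension, so this left-pads only the time axis by
        # ``self.pad`` samples: causal padding that keeps the sequence length
        # unchanged through the dilated convolutions.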
        paddings = (self.pad, 0, 0, 0, 0, 0, 0, 0)
        x = nn.functional.pad(x, paddings)
        x = self.tc1(x)
        x = nn.functional.pad(x, paddings)
        x = self.tc2(x) + residual
        return x


class EEGITNet(nn.Sequential):
    """EEG-ITNet: An Explainable Inception Temporal Convolutional Network
    for motor imagery classification from Salami et al. 2022.

    See [Salami2022]_ for details.

    Code adapted from https://github.com/abbassalami/eeg-itnet

    Parameters
    ----------
    n_classes : int
        Number of outputs of the decoding task (for example, the number of
        classes in classification).
    in_channels : int
        Number of input EEG channels.
    input_window_samples : int
        Number of time samples.
    drop_prob : float
        Dropout probability.

    References
    ----------
    .. [Salami2022] A. Salami, J. Andreu-Perez and H. Gillmeister, "EEG-ITNet:
       An Explainable Inception Temporal Convolutional Network for motor
       imagery classification," in IEEE Access,
       doi: 10.1109/ACCESS.2022.3161489.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it was only reimplemented from the paper
    based on the authors' implementation.
    """

    def __init__(self, n_classes, in_channels, input_window_samples,
                 drop_prob=0.4):
        super().__init__()
        # ======== Handling EEG input ========================
        self.add_module(
            "input_preprocess", nn.Sequential(Ensure4d(), Expression(_permute))
        )
        # ======== Inception branches ========================
        block11 = self._get_inception_branch(
            in_channels=in_channels, out_channels=2, kernel_length=16
        )
        block12 = self._get_inception_branch(
            in_channels=in_channels, out_channels=4, kernel_length=32
        )
        block13 = self._get_inception_branch(
            in_channels=in_channels, out_channels=8, kernel_length=64
        )
        self.add_module(
            "inception_block", _InceptionBlock((block11, block12, block13))
        )
        # The three branches yield 2 + 4 + 8 = 14 feature maps in total.
        self.add_module(
            "pooling",
            nn.Sequential(nn.AvgPool2d(kernel_size=(1, 4)), nn.Dropout(drop_prob)),
        )
        # =========== TC blocks =====================
        self.add_module(
            "TC_block1",
            _TCBlock(in_ch=14, kernel_length=4, dilation=1, padding=3,
                     drop_prob=drop_prob),
        )
        # ================================
        self.add_module(
            "TC_block2",
            _TCBlock(in_ch=14, kernel_length=4, dilation=2, padding=6,
                     drop_prob=drop_prob),
        )
        # ================================
        self.add_module(
            "TC_block3",
            _TCBlock(in_ch=14, kernel_length=4, dilation=4, padding=12,
                     drop_prob=drop_prob),
        )
        # ================================
        self.add_module(
            "TC_block4",
            _TCBlock(in_ch=14, kernel_length=4, dilation=8, padding=24,
                     drop_prob=drop_prob),
        )
        # ============= Dimensionality reduction ===================
        self.add_module(
            "dim_reduction",
            nn.Sequential(
                nn.Conv2d(14, 28, kernel_size=(1, 1)),
                nn.BatchNorm2d(28),
                nn.ELU(),
                nn.AvgPool2d((1, 4)),
                nn.Dropout(drop_prob),
            ),
        )
        # ============== Classifier ==================
        # The two (1, 4) average poolings shrink the time axis by a factor of
        # 16 and the 1x1 convolution leaves 28 feature maps, hence the input
        # size of the linear layer below.
        self.add_module(
            "classifier",
            nn.Sequential(
                torch.nn.Flatten(),
                nn.Linear(int(int(input_window_samples / 4) / 4) * 28,
                          n_classes),
                nn.Softmax(dim=1),
            ),
        )

    @staticmethod
    def _get_inception_branch(in_channels, out_channels, kernel_length,
                              depth_multiplier=1):
        return nn.Sequential(
            nn.Conv2d(
                1,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            _DepthwiseConv2d(
                out_channels,
                kernel_size=(in_channels, 1),
                depth_multiplier=depth_multiplier,
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
        )
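

# Minimal usage sketch (not part of the original module): the class, channel
# and sample counts below are illustrative assumptions, not values prescribed
# by EEG-ITNet. It only demonstrates instantiating the model and running a
# dummy forward pass.
if __name__ == "__main__":
    model = EEGITNet(n_classes=4, in_channels=22, input_window_samples=1000)
    dummy_eeg = torch.randn(8, 22, 1000)  # (batch, EEG channels, time samples)
    with torch.no_grad():
        probabilities = model(dummy_eeg)
    print(probabilities.shape)  # expected: torch.Size([8, 4])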