# Source code for braindecode.models.tcn

# Authors: Patryk Chrabaszcz
#          Lukas Gemein <l.gemein@gmail.com>
#
# License: BSD-3

from torch import nn
from torch.nn import init
from torch.nn.utils import weight_norm

from .modules import Ensure4d, Expression
from .functions import squeeze_final_output


class TCN(nn.Module):
    """Temporal Convolutional Network (TCN) from Bai et al 2018.

    See [Bai2018]_ for details.

    Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py

    Parameters
    ----------
    n_in_chans: int
        number of input EEG channels
    n_outputs: int
        number of outputs of the decoding task (for example number of classes
        in classification)
    n_blocks: int
        number of temporal blocks in the network; block ``i`` uses dilation
        ``2 ** i``
    n_filters: int
        number of output filters of each convolution
    kernel_size: int
        kernel size of the convolutions
    drop_prob: float
        dropout probability
    add_log_softmax: bool
        whether to add a log softmax layer

    References
    ----------
    .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
       An empirical evaluation of generic convolutional and recurrent networks
       for sequence modeling.
       arXiv preprint arXiv:1803.01271.
    """

    def __init__(self, n_in_chans, n_outputs, n_blocks, n_filters, kernel_size,
                 drop_prob, add_log_softmax):
        super().__init__()
        self.ensuredims = Ensure4d()

        # Residual blocks with exponentially growing dilation; only the first
        # block sees the raw input channels, later ones see n_filters.
        blocks = nn.Sequential()
        for level in range(n_blocks):
            dilation = 2 ** level
            block = TemporalBlock(
                n_inputs=n_in_chans if level == 0 else n_filters,
                n_outputs=n_filters,
                kernel_size=kernel_size,
                stride=1,
                dilation=dilation,
                padding=(kernel_size - 1) * dilation,
                drop_prob=drop_prob,
            )
            blocks.add_module("temporal_block_{:d}".format(level), block)
        self.temporal_blocks = blocks

        self.fc = nn.Linear(in_features=n_filters, out_features=n_outputs)
        if add_log_softmax:
            self.log_softmax = nn.LogSoftmax(dim=1)
        self.squeeze = Expression(squeeze_final_output)

        # Minimum input length (receptive field of the whole network): each
        # block holds two convolutions, each widening the receptive field by
        # (kernel_size - 1) * dilation samples.
        self.min_len = sum(
            (2 * (kernel_size - 1) * 2 ** level for level in range(n_blocks)),
            1,
        )

        # start in eval mode
        self.eval()

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x: torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).
        """
        x = self.ensuredims(x)
        # x is now laid out as B x C x T x 1
        n_batch, _, n_times, _ = x.size()
        # NOTE: plain assert, so this check disappears under ``python -O``
        assert n_times >= self.min_len

        # drop the empty trailing dimension before the 1D temporal blocks
        feats = self.temporal_blocks(x.squeeze(3))
        # B x C x T -> B x T x C, then fold batch and time together so the
        # linear layer (and a dim=1 log-softmax) act on each time step
        feats = feats.transpose(1, 2).contiguous()
        scores = self.fc(feats.view(n_batch * n_times, feats.size(2)))
        if hasattr(self, "log_softmax"):
            scores = self.log_softmax(scores)
        scores = scores.view(n_batch, n_times, scores.size(1))

        # keep only the trailing time steps that have at least min_len input
        # samples behind them
        n_kept = 1 + max(0, n_times - self.min_len)
        out = scores[:, -n_kept:, :].transpose(1, 2)
        # re-add 4th dimension for compatibility with braindecode
        return self.squeeze(out[:, :, :, None])
class TemporalBlock(nn.Module):
    """Residual block of two weight-normalized dilated 1D convolutions.

    Each of the two branch convolutions pads its input by ``padding`` samples
    on both sides, then has the last ``padding`` samples trimmed again
    (``Chomp1d``). ReLU and channel-wise dropout follow. With ``stride == 1``
    and ``padding == (kernel_size - 1) * dilation`` (as used by ``TCN``), the
    block is causal and preserves the sequence length. The input is added back
    before a final ReLU, through a 1x1 convolution when the channel counts
    differ.

    Parameters
    ----------
    n_inputs: int
        number of input channels
    n_outputs: int
        number of output channels
    kernel_size: int
        kernel size of both branch convolutions
    stride: int
        stride of both branch convolutions
    dilation: int
        dilation of both branch convolutions
    padding: int
        zero padding added by each branch convolution and trimmed again on the
        right
    drop_prob: float
        dropout probability
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation,
                 padding, drop_prob):
        super().__init__()
        # ``nn.Dropout1d`` (torch >= 1.12) is the dedicated channel-wise
        # dropout for (N, C, L) tensors. ``nn.Dropout2d`` only treats 3D input
        # that way as a deprecated fallback and warns about it. Older torch
        # has no Dropout1d, so fall back to Dropout2d there.
        dropout = getattr(nn, "Dropout1d", nn.Dropout2d)

        self.conv1 = weight_norm(self._make_conv(
            n_inputs, n_outputs, kernel_size, stride, dilation, padding))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = dropout(drop_prob)
        self.conv2 = weight_norm(self._make_conv(
            n_outputs, n_outputs, kernel_size, stride, dilation, padding))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = dropout(drop_prob)
        # 1x1 convolution so the residual can be summed when channels differ
        self.downsample = (nn.Conv1d(n_inputs, n_outputs, 1)
                           if n_inputs != n_outputs else None)
        self.relu = nn.ReLU()
        if self.downsample is not None:
            init.normal_(self.downsample.weight, 0, 0.01)

    @staticmethod
    def _make_conv(in_channels, out_channels, kernel_size, stride, dilation,
                   padding):
        """Return a Conv1d whose weight is drawn from N(mean=0, std=0.01).

        The weight is initialized here, *before* the caller wraps the layer in
        ``weight_norm``. ``weight_norm`` derives its ``weight_g`` /
        ``weight_v`` parameters from the current weight and rebuilds
        ``weight`` from them on every forward pass. Initializing
        ``conv.weight`` after wrapping would therefore be silently discarded.
        """
        conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation)
        init.normal_(conv.weight, 0, 0.01)
        return conv

    def forward(self, x):
        out = self.conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout1(out)
        out = self.conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout2(out)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class Chomp1d(nn.Module):
    """Remove the last ``chomp_size`` samples along dimension 2 (time).

    Used after a padded convolution to drop the right-hand padding.

    Parameters
    ----------
    chomp_size: int
        number of trailing samples to remove; 0 leaves the input unchanged
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def extra_repr(self):
        return 'chomp_size={}'.format(self.chomp_size)

    def forward(self, x):
        if self.chomp_size == 0:
            # ``x[:, :, :-0]`` is an *empty* slice, not a no-op. This case
            # arises when the preceding convolution used padding=0
            # (e.g. kernel_size == 1 in TCN).
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()