# Source code for braindecode.models.util

# Authors: Robin Schirrmeister <robintibor@gmail.com>
#          Hubert Banville <hubert.jbanville@gmail.com>
#
# License: BSD (3-clause)

import torch
import numpy as np
from scipy.special import log_softmax


def to_dense_prediction_model(model, axis=(2, 3)):
    """
    Convert a strided model into one that produces dense predictions.

    Along the selected axes every stride is reset to 1, and each module that
    has a ``dilation`` attribute receives, as its new dilation, the product
    of the strides of all modules visited before it. The model is modified
    in-place; nothing is returned.

    Parameters
    ----------
    model: torch.nn.Module
        Model whose modules will be modified.
    axis: int or (int,int)
        Axis to transform (in terms of intermediate output axes),
        can either be 2, 3, or (2,3).

    Notes
    -----
    Does not yet work correctly for average pooling.
    Prior to version 0.1.7, there had been a bug that could move strides
    backwards one layer.

    """
    axes = axis if hasattr(axis, "__len__") else [axis]
    assert all(ax in [2, 3] for ax in axes), "Only 2 and 3 allowed for axis"
    # Output axes 2 and 3 correspond to positions 0 and 1 of the two-element
    # stride / dilation tuples.
    target_dims = np.array(axes) - 2
    # Product of the strides of all modules seen so far, tracked for both
    # positions regardless of which axes are being converted.
    accumulated_stride = np.array([1, 1])
    for layer in model.modules():
        if hasattr(layer, "dilation"):
            assert layer.dilation == 1 or (layer.dilation == (1, 1)), (
                "Dilation should equal 1 before conversion, maybe the model is "
                "already converted?"
            )
            dilation = [1, 1]
            for dim in target_dims:
                dilation[dim] = int(accumulated_stride[dim])
            layer.dilation = tuple(dilation)
        if hasattr(layer, "stride"):
            # Scalar strides are expanded to a pair before accumulating.
            if not hasattr(layer.stride, "__len__"):
                layer.stride = (layer.stride, layer.stride)
            accumulated_stride *= np.array(layer.stride)
            # The accumulation above must use the original stride; only now
            # is it neutralised along the converted axes.
            stride = list(layer.stride)
            for dim in target_dims:
                stride[dim] = 1
            layer.stride = tuple(stride)


def get_output_shape(model, in_chans, input_window_samples):
    """Returns shape of neural network output for batch size equal 1.

    An all-ones dummy input of shape ``(1, in_chans, input_window_samples)``
    is passed through ``model`` with gradient tracking disabled. The dummy
    input takes its dtype and device from the model's first parameter; a
    model without any parameters gets an input with torch's default dtype
    and device instead. The model's train/eval mode is not changed here.

    Parameters
    ----------
    model: torch.nn.Module
        Network to probe. It is called once on the dummy input.
    in_chans: int
        Size of the second dimension of the dummy input.
    input_window_samples: int
        Size of the third dimension of the dummy input.

    Returns
    -------
    output_shape: tuple
        shape of the network output for `batch_size==1` (1, ...)
    """
    # Look the first parameter up once; the default avoids a StopIteration
    # for parameter-free models (e.g. identity or pooling-only networks).
    first_param = next(model.parameters(), None)
    tensor_kwargs = {}
    if first_param is not None:
        tensor_kwargs = {
            "dtype": first_param.dtype,
            "device": first_param.device,
        }
    with torch.no_grad():
        dummy_input = torch.ones(
            1, in_chans, input_window_samples, **tensor_kwargs
        )
        output_shape = model(dummy_input).shape
    return output_shape
def _pad_shift_array(x, stride=1):
    """Zero-pad and shift rows of a 3D array.

    Row ``i`` of the result holds row ``i`` of ``x`` moved ``i * stride``
    positions to the right along the last axis, with zeros everywhere else.
    E.g., used to align predictions of corresponding windows in
    sequence-to-sequence models.

    Parameters
    ----------
    x : np.ndarray
        Array of shape (n_rows, n_classes, n_windows).
    stride : int
        Number of non-overlapping elements between two consecutive sequences.

    Returns
    -------
    np.ndarray :
        Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows)
        where each row is obtained by zero-padding the corresponding row in
        ``x`` before and after in the last dimension. The result is a strided
        view onto an internal padded copy of ``x``.

    Raises
    ------
    NotImplementedError
        If ``x`` is not 3-dimensional.
    """
    if x.ndim != 3:
        raise NotImplementedError(
            f'x must be of shape (n_rows, n_classes, n_windows), got {x.shape}'
        )

    # Pad only at the end of the last axis, by the largest shift any row needs.
    tail_padding = (x.shape[0] - 1) * stride
    padded = np.pad(x, ((0, 0), (0, 0), (0, tail_padding)))

    # Shortening the row stride by `stride` elements makes every row start
    # `stride` elements earlier than the one before it, i.e. inside the zero
    # padding that precedes it in memory, which yields the right shift.
    row_step, class_step, window_step = padded.strides
    shifted_strides = (
        row_step - stride * window_step,
        class_step,
        window_step,
    )
    return np.lib.stride_tricks.as_strided(padded, strides=shifted_strides)


def aggregate_probas(logits, n_windows_stride=1):
    """Aggregate window-wise predictions with self-ensembling.

    Combine the predictions obtained on overlapping sequences of windows
    using multiplicative voting as described in [Phan2018]_. The voting is
    carried out in the log domain: the logits are turned into
    log-probabilities with a log-softmax over the class axis, aligned in
    time, and summed over the sequences.

    Parameters
    ----------
    logits : np.ndarray
        Array of shape (n_sequences, n_classes, n_windows) containing the
        logits (i.e. the raw unnormalized scores for each class) for each
        window of each sequence.
    n_windows_stride : int
        Number of windows between two consecutive sequences. Default is 1
        (maximally overlapping sequences).

    Returns
    -------
    np.ndarray :
        Array of shape
        ((n_sequences - 1) * n_windows_stride + n_windows, n_classes)
        containing, for each window contained in the input sequences, the
        sum over sequences of the per-class log-probabilities (the log of
        the product of the predicted probabilities). The values are not
        re-normalized; sequences that do not cover a window contribute 0.

    References
    ----------
    .. [Phan2018] Phan, H., Andreotti, F., Cooray, N., Chén, O. Y., &
        De Vos, M. (2018). Joint classification and prediction CNN framework
        for automatic sleep stage classification. IEEE Transactions on
        Biomedical Engineering, 66(5), 1285-1296.
    """
    log_probas = log_softmax(logits, axis=1)
    aligned = _pad_shift_array(log_probas, stride=n_windows_stride)
    # Summing log-probabilities over sequences equals multiplying the
    # probabilities; transpose to (n_total_windows, n_classes).
    return aligned.sum(axis=0).T