braindecode.models.MEDFormer

class braindecode.models.MEDFormer(n_chans=None, n_outputs=None, n_times=None, chs_info=None, input_window_seconds=None, sfreq=None, patch_len_list=None, embed_dim=128, num_heads=8, drop_prob=0.1, no_inter_attn=False, num_layers=6, dim_feedforward=256, activation_trans=<class 'torch.nn.modules.activation.ReLU'>, single_channel=False, output_attention=True, activation_class=<class 'torch.nn.modules.activation.GELU'>)

Medformer from Wang et al. (2024) [Medformer2024].

Categories: Convolution, Large Brain Model

MEDFormer Architecture.

a) Workflow. b) For the input sample \(\mathbf{x}_{\text{in}}\), the authors apply \(n\) different patch lengths in parallel to create patched features \(\mathbf{x}_p^{(i)}\), where \(i\) ranges from 1 to \(n\). Each patch length represents a different granularity. These patched features are linearly transformed into \(\mathbf{x}_e^{(i)}\) and augmented into \(\tilde{\mathbf{x}}_e^{(i)}\). c) The final patch embedding \(\mathbf{x}^{(i)}\) fuses the augmented \(\tilde{\mathbf{x}}_e^{(i)}\) with the positional embedding \(\mathbf{W}_{\text{pos}}\) and the granularity embedding \(\mathbf{W}_{\text{gr}}^{(i)}\). Each granularity employs a router \(\mathbf{u}^{(i)}\) to capture aggregated information. Intra-granularity attention operates within each granularity, and inter-granularity attention uses the routers to integrate information across granularities.

The MedFormer is a multi-granularity patching transformer tailored to medical time-series (MedTS) classification, with an emphasis on EEG and ECG signals. It captures local temporal dynamics, inter-channel correlations, and multi-scale temporal structure through cross-channel patching, multi-granularity embeddings, and two-stage attention [Medformer2024].

Architecture Overview

MedFormer integrates three mechanisms to enhance representation learning [Medformer2024]:

  1. Cross-channel patching. Leverages inter-channel correlations by forming patches across multiple channels and timestamps, capturing multi-timestamp and cross-channel patterns.

  2. Multi-granularity embedding. Extracts features at different temporal scales from patch_len_list, emulating frequency-band behavior without hand-crafted filters.

  3. Two-stage multi-granularity self-attention. Learns intra- and inter-granularity correlations to fuse information across temporal scales.

Macro Components

MEDFormer.enc_embedding (Embedding Layer)

Operations. _ListPatchEmbedding implements cross-channel multi-granularity patching. For each patch length \(L_i\), the input \(\mathbf{x}_{\text{in}} \in \mathbb{R}^{T \times C}\) is segmented into \(N_i\) cross-channel non-overlapping patches \(\mathbf{x}_p^{(i)} \in \mathbb{R}^{N_i \times (L_i \cdot C)}\), where \(N_i = \lceil T/L_i \rceil\). Each patch is linearly projected via _CrossChannelTokenEmbedding to obtain \(\mathbf{x}_e^{(i)} \in \mathbb{R}^{N_i \times D}\). Data augmentations (masking, jittering) produce augmented embeddings \(\tilde{\mathbf{x}}_e^{(i)}\). The final embedding combines augmented patches, fixed positional embeddings (_PositionalEmbedding), and learnable granularity embeddings \(\mathbf{W}_{\text{gr}}^{(i)}\):

\[\mathbf{x}^{(i)} = \tilde{\mathbf{x}}_e^{(i)} + \mathbf{W}_{\text{pos}}[1:N_i] + \mathbf{W}_{\text{gr}}^{(i)}\]

Additionally, a router token is initialized for each granularity:

\[\mathbf{u}^{(i)} = \mathbf{W}_{\text{pos}}[N_i+1] + \mathbf{W}_{\text{gr}}^{(i)}\]

Role. Converts raw input into granularity-specific patch embeddings \(\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(n)}\}\) and router embeddings \(\{\mathbf{u}^{(1)}, \ldots, \mathbf{u}^{(n)}\}\) for multi-scale processing.
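The embedding equations above can be made concrete with a small NumPy sketch. This is not braindecode's _ListPatchEmbedding; it is a hypothetical illustration under several assumptions: the input is padded at the end by repeating the last time step so that exactly \(N_i = \lceil T/L_i \rceil\) patches exist (the library may pad differently), the fixed positional table is the standard sine/cosine encoding, the learned projection and granularity embedding are replaced by random stand-ins, and augmentation is omitted. The names embed_granularities and sinusoidal_positions are hypothetical.

```python
import math

import numpy as np

rng = np.random.default_rng(0)


def sinusoidal_positions(n_pos: int, dim: int) -> np.ndarray:
    """Fixed sine/cosine positional table of shape (n_pos, dim); dim must be even."""
    pos = np.arange(n_pos)[:, None]
    freq = np.exp(-math.log(10000.0) * np.arange(0, dim, 2) / dim)
    table = np.zeros((n_pos, dim))
    table[:, 0::2] = np.sin(pos * freq)
    table[:, 1::2] = np.cos(pos * freq)
    return table


def embed_granularities(x_in: np.ndarray, patch_lens: list[int], dim: int):
    """Cross-channel multi-granularity patch embedding, augmentation omitted.

    x_in has shape (T, C). Returns one (x_i, u_i) pair per patch length, with
    x_i of shape (N_i, dim) and router u_i of shape (1, dim).
    """
    T, C = x_in.shape
    n_max = max(math.ceil(T / L) for L in patch_lens)
    w_pos = sinusoidal_positions(n_max + 1, dim)  # one extra row for the router
    pairs = []
    for L in patch_lens:
        n_i = math.ceil(T / L)
        # Assumption: repeat the last time step so that n_i * L samples are available.
        pad = n_i * L - T
        x_pad = np.concatenate([x_in, np.repeat(x_in[-1:], pad, axis=0)], axis=0)
        # Each patch holds L consecutive time steps of all C channels (L * C values).
        x_p = x_pad.reshape(n_i, L * C)
        # Random stand-ins for the learned projection and granularity embedding.
        w_proj = rng.standard_normal((L * C, dim)) / math.sqrt(L * C)
        w_gr = rng.standard_normal(dim)
        x_e = x_p @ w_proj                      # (N_i, dim)
        x_i = x_e + w_pos[:n_i] + w_gr          # W_pos[1:N_i] in 1-based notation
        u_i = (w_pos[n_i] + w_gr)[None, :]      # W_pos[N_i + 1] in 1-based notation
        pairs.append((x_i, u_i))
    return pairs


x_in = np.random.default_rng(1).standard_normal((256, 4))  # T=256 samples, C=4 channels
for x_i, u_i in embed_granularities(x_in, [14, 44, 45], dim=16):
    print(x_i.shape, u_i.shape)
# (19, 16) (1, 16)
# (6, 16) (1, 16)
# (6, 16) (1, 16)
```

Flattening L time steps across all C channels into one token is what makes the patching cross-channel: a single linear projection sees every channel at once, rather than embedding each channel separately.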

MEDFormer.encoder (Transformer Encoder Stack)

Operations. A stack of _EncoderLayer modules, each containing a _MedformerLayer that implements two-stage self-attention. The two-stage mechanism splits self-attention into:

(a) Intra-Granularity Self-Attention. For granularity \(i\), the patch embedding \(\mathbf{x}^{(i)} \in \mathbb{R}^{N_i \times D}\) and router embedding \(\mathbf{u}^{(i)} \in \mathbb{R}^{1 \times D}\) are concatenated:

\[\mathbf{z}^{(i)} = [\mathbf{x}^{(i)} \| \mathbf{u}^{(i)}] \in \mathbb{R}^{(N_i+1) \times D}\]

Self-attention is applied to update both embeddings:

\[\begin{split}\mathbf{x}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{x}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})\\ \mathbf{u}^{(i)} &\leftarrow \text{Attn}_{\text{intra}}(\mathbf{u}^{(i)}, \mathbf{z}^{(i)}, \mathbf{z}^{(i)})\end{split}\]

This captures temporal features within each granularity independently.
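A minimal NumPy sketch of stage (a) for one granularity follows. It is a simplification under stated assumptions: a single head with no learned query/key/value projections, and no residual connections, normalization, or feed-forward block, all of which a real encoder layer would add. The helper names attention and intra_granularity are hypothetical, not braindecode functions.

```python
import numpy as np


def attention(q, k, v):
    """Single-head scaled dot-product attention without learned projections."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


def intra_granularity(x_i, u_i):
    """Update patches x_i (N_i, D) and router u_i (1, D) by attending over z_i."""
    z_i = np.concatenate([x_i, u_i], axis=0)  # z_i = [x_i || u_i], shape (N_i + 1, D)
    x_new, _ = attention(x_i, z_i, z_i)
    u_new, _ = attention(u_i, z_i, z_i)
    return x_new, u_new


rng = np.random.default_rng(0)
x_i, u_i = rng.standard_normal((19, 16)), rng.standard_normal((1, 16))
x_new, u_new = intra_granularity(x_i, u_i)
print(x_new.shape, u_new.shape)  # (19, 16) (1, 16)
```

Because each query row attends independently, the two calls give the same result as one attention pass with queries, keys, and values all equal to z_i, split afterwards into the patch rows and the router row.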

(b) Inter-Granularity Self-Attention. All router embeddings are concatenated:

\[\mathbf{U} = [\mathbf{u}^{(1)} \| \mathbf{u}^{(2)} \| \cdots \| \mathbf{u}^{(n)}] \in \mathbb{R}^{n \times D}\]

Self-attention among routers exchanges information across granularities:

\[\mathbf{u}^{(i)} \leftarrow \text{Attn}_{\text{inter}}(\mathbf{u}^{(i)}, \mathbf{U}, \mathbf{U})\]
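Stage (b) can be sketched the same way, again assuming single-head attention without learned projections. The hypothetical inter_granularity function below takes a no_inter_attn flag that loosely mirrors the MEDFormer argument of the same name: when it is True, the routers are returned unchanged, so only the intra-granularity stage would have acted on them.

```python
import numpy as np


def attention(q, k, v):
    """Single-head scaled dot-product attention without learned projections."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def inter_granularity(routers, no_inter_attn=False):
    """Exchange information among the n routers, each of shape (1, D)."""
    if no_inter_attn:
        return list(routers)  # stage (b) skipped; routers keep their intra-stage values
    U = np.concatenate(routers, axis=0)  # U = [u_1 || ... || u_n], shape (n, D)
    return [attention(u_i, U, U) for u_i in routers]


rng = np.random.default_rng(0)
routers = [rng.standard_normal((1, 16)) for _ in range(3)]  # n = 3 granularities
updated = inter_granularity(routers)
print([u.shape for u in updated])  # [(1, 16), (1, 16), (1, 16)]
```

Only the n router tokens take part in this stage; the patch tokens receive cross-granularity information indirectly, through the routers in the next layer's intra-granularity attention.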

Role. Learns representations and correlations within and across temporal scales while reducing complexity from \(O((\sum_i N_i)^2)\) to \(O(\sum_i N_i^2 + n^2)\) through the router mechanism.
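To make the complexity claim concrete, the hypothetical helper below counts query-key pairs for the two big-O terms, using the default patch lengths [14, 44, 45] and an illustrative window of n_times = 256 samples (an assumed value, not a library default). The counts follow the asymptotic terms only and ignore the extra router token inside each intra-granularity block.

```python
import math


def attention_pair_counts(n_times, patch_lens):
    """Count query-key pairs for joint attention vs. router-based two-stage attention."""
    n_patches = [math.ceil(n_times / L) for L in patch_lens]
    joint = sum(n_patches) ** 2                                          # (sum_i N_i)^2
    two_stage = sum(n ** 2 for n in n_patches) + len(patch_lens) ** 2  # sum_i N_i^2 + n^2
    return n_patches, joint, two_stage


print(attention_pair_counts(256, [14, 44, 45]))
# ([19, 6, 6], 961, 442)
```

Here the routers cut the pair count by more than half, because patches of different granularities never attend to each other directly.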

Temporal, Spatial, and Spectral Encoding

  • Temporal: Multiple patch lengths in patch_len_list capture features at several temporal granularities, while intra-granularity attention supports long-range temporal dependencies.

  • Spatial: Cross-channel patching embeds inter-channel dependencies by applying kernels that span every input channel.

  • Spectral: Differing patch lengths simulate multiple sampling frequencies analogous to clinically relevant bands (e.g., alpha, beta, gamma).
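The spectral analogy can be quantified: a patch of \(L\) samples becomes one token, so at sampling frequency sfreq each granularity produces tokens at an effective rate of sfreq / L. The sketch below assumes an illustrative sfreq of 250 Hz with the default patch lengths; token_rates is a hypothetical helper.

```python
def token_rates(sfreq, patch_lens):
    """Return (patch length, effective token rate in Hz, patch duration in s)."""
    return [(L, sfreq / L, L / sfreq) for L in patch_lens]


for L, rate_hz, dur_s in token_rates(250.0, [14, 44, 45]):
    print(f"L={L:>2}: one token every {dur_s * 1000:.0f} ms -> {rate_hz:.2f} tokens/s")
# L=14: one token every 56 ms -> 17.86 tokens/s
# L=44: one token every 176 ms -> 5.68 tokens/s
# L=45: one token every 180 ms -> 5.56 tokens/s
```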

Additional Mechanisms

  • Granularity router: Each granularity \(i\) receives a dedicated router token \(\mathbf{u}^{(i)}\). Intra-attention updates the token, and inter-attention exchanges aggregated information across scales.

  • Complexity: Router-mediated two-stage attention keeps the overall cost at \(O(T^2)\) when the patch lengths are chosen suitably (e.g., a power series such as 1, 2, 4, 8, and so on), so modeling several granularities costs no more asymptotically than a standard transformer over \(T\) tokens.
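A quick numeric check of the power-series case, assuming patch lengths \(L_k = 2^k\): then \(N_k = T/2^k\) and \(\sum_k N_k^2 = T^2 \sum_k 4^{-k} < \tfrac{4}{3}T^2\), so the two-stage cost stays within a constant factor of \(T^2\). The window length T = 1024 and the helper two_stage_cost are illustrative assumptions.

```python
import math


def two_stage_cost(n_times, patch_lens):
    """Query-key pairs for two-stage attention: sum_i N_i^2 + n^2."""
    n_patches = [math.ceil(n_times / L) for L in patch_lens]
    return sum(n ** 2 for n in n_patches) + len(patch_lens) ** 2


T = 1024
powers_of_two = [2 ** k for k in range(6)]  # patch lengths 1, 2, 4, 8, 16, 32
cost = two_stage_cost(T, powers_of_two)
print(cost)               # 1397796
print(cost / T ** 2 < 4 / 3)  # True: the ratio stays below the geometric bound 4/3
```

The \(n^2\) router term is negligible next to \(T^2\), which is why the bound is dominated by the geometric series.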

Parameters:
  • n_chans (int) – Number of EEG channels.

  • n_outputs (int) – Number of outputs of the model. This is the number of classes in the case of classification.

  • n_times (int) – Number of time samples of the input window.

  • chs_info (list of dict) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

  • input_window_seconds (float) – Length of the input window in seconds.

  • sfreq (float) – Sampling frequency of the EEG recordings.

  • patch_len_list (list of int, optional) – Patch lengths for multi-granularity patching; each entry selects a temporal scale. If None, the default [14, 44, 45] is used.

  • embed_dim (int, optional) – Embedding dimensionality. The default is 128.

  • num_heads (int, optional) – Number of attention heads; must divide embed_dim. The default is 8.

  • drop_prob (float, optional) – Dropout probability. The default is 0.1.

  • no_inter_attn (bool, optional) – If True, disables inter-granularity attention. The default is False.

  • num_layers (int, optional) – Number of encoder layers. The default is 6.

  • dim_feedforward (int, optional) – Feedforward dimensionality. The default is 256.

  • activation_trans (nn.Module, optional) – Activation module used in transformer encoder layers. The default is nn.ReLU.

  • single_channel (bool, optional) – If True, processes each channel independently, increasing capacity and cost. The default is False.

  • output_attention (bool, optional) – If True, returns attention weights for interpretability. The default is True.

  • activation_class (nn.Module, optional) – Activation used in the final classification layer. The default is nn.GELU.

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

  • MedFormer outperforms strong baselines across six metrics on five MedTS datasets in a subject-independent evaluation [Medformer2024].

  • Cross-channel patching provides the largest F1 improvement in ablation studies (average +6.10%), highlighting its importance for MedTS tasks [Medformer2024].

  • Setting no_inter_attn to True disables inter-granularity attention while retaining intra-granularity attention.

References

[Medformer2024]

Wang, Y., Huang, N., Li, T., Yan, Y., & Zhang, X. (2024). Medformer: A Multi-Granularity Patching Transformer for Medical Time-Series Classification. In A. Globerson, L. Mackey, D. Belgrave, A. Fan, U. Paquet, J. Tomczak, & C. Zhang (Eds.), Advances in Neural Information Processing Systems (Vol. 37, pp. 36314-36341). doi:10.52202/079017-1145.

Methods

forward(x)

Forward pass of the Medformer model.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, n_chans, n_times).

Returns:

Output tensor of shape (batch_size, n_outputs).

Return type:

torch.Tensor