braindecode.modules.MultiHeadAttention

class braindecode.modules.MultiHeadAttention(emb_size, num_heads, dropout=0.0)

Multi-head self-attention block.

Uses torch.nn.functional.scaled_dot_product_attention, which dispatches to an optimized attention kernel when one is available (e.g. FlashAttention on supported CUDA devices) and otherwise falls back to a memory-efficient or standard implementation.

Parameters:
  • emb_size (int) – The embedding dimension.

  • num_heads (int) – Number of attention heads. Must evenly divide emb_size.

  • dropout (float, optional) – Dropout probability applied to attention weights. Default: 0.0.

Examples

>>> import torch
>>> from braindecode.modules import MultiHeadAttention
>>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
>>> inputs = torch.randn(2, 10, 32)
>>> outputs = module(inputs)
>>> outputs.shape
torch.Size([2, 10, 32])

Methods

forward(x, mask=None)

Forward pass.

Parameters:
  • x (Tensor) – Input tensor of shape (batch, seq, emb_size).

  • mask (Optional[Tensor]) – Attention mask following the PyTorch scaled_dot_product_attention convention: for boolean masks, True marks positions that take part in attention (False positions are masked out); for float masks, the values are added to the attention scores before softmax.

Return type:

Tensor
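
A minimal sketch of masked attention, assuming the mask is passed through unchanged to scaled_dot_product_attention and therefore must broadcast to (batch, num_heads, seq, seq); here a causal boolean mask is shared across the batch and all heads:

>>> import torch
>>> from braindecode.modules import MultiHeadAttention
>>> module = MultiHeadAttention(emb_size=32, num_heads=4)
>>> inputs = torch.randn(2, 10, 32)
>>> # Boolean causal mask of shape (seq, seq), broadcast over batch and heads:
>>> # True = position takes part in attention, False = masked out.
>>> causal_mask = torch.tril(torch.ones(10, 10, dtype=torch.bool))
>>> outputs = module(inputs, mask=causal_mask)
>>> outputs.shape
torch.Size([2, 10, 32])

A float mask of the same broadcastable shape can be used instead; its values are added to the attention scores before the softmax.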