braindecode.modules.MultiHeadAttention
- class braindecode.modules.MultiHeadAttention(emb_size, num_heads, dropout=0.0)
Multi-head self-attention block.
Uses F.scaled_dot_product_attention for optimized attention kernels (flash-attention on CUDA, memory-efficient attention on other devices).
- Parameters:
  - emb_size (int) – Dimension of the input and output embeddings.
  - num_heads (int) – Number of attention heads.
  - dropout (float, optional) – Dropout probability applied to the attention weights. Default: 0.0.
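The forward pass is equivalent in spirit to the following sketch built directly on F.scaled_dot_product_attention. The joint Q/K/V projection and the layer names are illustrative assumptions, not necessarily braindecode's exact internals:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionSketch(nn.Module):
    """Illustrative re-implementation; braindecode's internals may differ."""

    def __init__(self, emb_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert emb_size % num_heads == 0, "emb_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        self.qkv = nn.Linear(emb_size, 3 * emb_size)  # joint Q/K/V projection (assumed)
        self.proj = nn.Linear(emb_size, emb_size)     # output projection
        self.dropout = dropout

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        b, s, e = x.shape
        # Split the joint projection into Q, K, V: three (batch, seq, emb_size) tensors.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape each to (batch, num_heads, seq, head_dim), the layout SDPA expects.
        q, k, v = (
            t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        # SDPA dispatches to flash / memory-efficient kernels where available.
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # Merge heads back: (batch, num_heads, seq, head_dim) -> (batch, seq, emb_size).
        out = out.transpose(1, 2).reshape(b, s, e)
        return self.proj(out)
```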
Examples
>>> import torch
>>> from braindecode.modules import MultiHeadAttention
>>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
>>> inputs = torch.randn(2, 10, 32)
>>> outputs = module(inputs)
>>> outputs.shape
torch.Size([2, 10, 32])
Methods
- forward(x, mask=None)
Forward pass.
- Parameters:
  - x (Tensor) – Input tensor of shape (batch, seq, emb_size).
  - mask (Tensor, optional) – Attention mask following the PyTorch SDPA convention: for boolean masks, True marks positions that take part in attention (False positions are ignored); for float masks, the values are added to the attention scores before softmax. See the usage sketch below.
- Return type: Tensor
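As a usage sketch for the mask argument, assuming the mask is forwarded directly to F.scaled_dot_product_attention (so a boolean True marks attended positions), a causal mask could be applied like this:

>>> import torch
>>> from braindecode.modules import MultiHeadAttention
>>> module = MultiHeadAttention(emb_size=32, num_heads=4)
>>> inputs = torch.randn(2, 10, 32)
>>> causal = torch.tril(torch.ones(10, 10, dtype=torch.bool))
>>> outputs = module(inputs, mask=causal)
>>> outputs.shape
torch.Size([2, 10, 32])

With this mask, each position attends only to itself and earlier positions; a float mask of the same shape would instead be added to the attention scores before softmax.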