braindecode.modules.MultiHeadAttention
class braindecode.modules.MultiHeadAttention(emb_size, num_heads, dropout=0.0)
Multi-head self-attention block.
Uses `F.scaled_dot_product_attention` for optimized attention kernels (flash attention on CUDA, memory-efficient attention on other devices).

Parameters:
- emb_size (int) – Embedding dimension of the input; must be divisible by num_heads.
- num_heads (int) – Number of attention heads.
- dropout (float) – Dropout probability applied to the attention weights (default: 0.0).
Examples
>>> import torch
>>> from braindecode.modules import MultiHeadAttention
>>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
>>> inputs = torch.randn(2, 10, 32)
>>> outputs = module(inputs)
>>> outputs.shape
torch.Size([2, 10, 32])
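For readers curious what the block computes internally, the following is a minimal sketch of multi-head self-attention built directly on `F.scaled_dot_product_attention`, the kernel the module dispatches to. The QKV projection layout and the final output projection here are standard-practice assumptions, not necessarily braindecode's exact implementation.

```python
import torch
import torch.nn.functional as F

batch, seq_len, emb_size, num_heads = 2, 10, 32, 4
head_dim = emb_size // num_heads  # emb_size must divide evenly by num_heads

# Hypothetical projections mirroring a typical MHA layout.
qkv = torch.nn.Linear(emb_size, 3 * emb_size)  # fused query/key/value projection
proj = torch.nn.Linear(emb_size, emb_size)     # output projection

x = torch.randn(batch, seq_len, emb_size)
q, k, v = qkv(x).chunk(3, dim=-1)

# Split heads: (batch, seq_len, emb_size) -> (batch, num_heads, seq_len, head_dim)
q, k, v = (
    t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
    for t in (q, k, v)
)

# Fused attention kernel; PyTorch picks flash / memory-efficient / math backend.
out = F.scaled_dot_product_attention(q, k, v)

# Merge heads back and project: (batch, seq_len, emb_size)
out = proj(out.transpose(1, 2).reshape(batch, seq_len, emb_size))
print(out.shape)  # torch.Size([2, 10, 32])
```

Note that the sequence length is preserved: self-attention maps a `(batch, seq_len, emb_size)` tensor to a tensor of the same shape, which matches the doctest above.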