braindecode.models.EMG2QwertyNet#

class braindecode.models.EMG2QwertyNet(n_outputs=99, n_chans=32, sfreq=2000.0, num_bands=2, electrodes_per_band=16, n_fft=64, hop_length=16, log_eps=1e-06, mlp_features=(384, ), rotation_offsets=(-1, 0, 1), pooling='mean', block_channels=(24, 24, 24, 24), kernel_width=32, log_softmax=False, activation=<class 'torch.nn.modules.activation.ReLU'>, drop_prob=0.0, n_times=None, input_window_seconds=None, chs_info=None)[source]#

Decoder mapping surface electromyography (sEMG) to keystrokes (emg2qwerty) [emg2qwerty2024].

Convolution

Time-Depth-Separable (TDS) [hannun2019tds] convolutional encoder followed by a Connectionist Temporal Classification (CTC) head [graves2006ctc]. Takes raw 32-channel sEMG (2 wristbands × 16 electrodes) at 2 kHz and emits per-frame scores over the 99-class typing vocabulary (98 keys + 1 CTC blank).

Pipeline

  1. Log-spectrogram front-end: per-channel Short-Time Fourier Transform (stft()) with a Hann window and center=False, followed by the squared magnitude and log10(power + log_eps). With the defaults (n_fft=64, hop_length=16, sfreq=2000) the output frame rate is 2000 / 16 = 125 Hz. No trainable parameters. A standalone sketch of this front-end follows the list.

  2. Spectrogram BatchNorm: BatchNorm2d over the (batch, freq, time) slice for each of the num_bands × electrodes_per_band channels.

  3. Per-band rotation-invariant multi-layer perceptron (MLP): for each band, a shared MLP is applied to circular rolls (-1, 0, +1) of the electrode axis, then mean-pooled.

  4. TDS convolutional encoder: stack of len(block_channels) TDS conv blocks interleaved with feedforward blocks. No temporal padding, so each conv block strips kernel_width - 1 frames.

  5. Linear classification head: Linear projecting to n_outputs, optionally followed by log_softmax() (off by default; braindecode models return logits).
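
A minimal sketch of the step-1 front-end under the defaults. Variable names are illustrative and the module's internal implementation may differ in details; this only mirrors the description above:

import torch

n_fft, hop_length, log_eps = 64, 16, 1e-6
x = torch.randn(2, 32, 8000)  # (batch, n_chans, n_times), raw sEMG at 2 kHz

# Per-channel STFT with a Hann window and center=False, as in step 1.
spec = torch.stft(
    x.flatten(0, 1),                      # torch.stft expects (batch, time)
    n_fft=n_fft,
    hop_length=hop_length,
    window=torch.hann_window(n_fft),
    center=False,
    return_complex=True,
)
power = spec.abs() ** 2                   # squared magnitude
logspec = torch.log10(power + log_eps)    # log-power with a numerical floor
logspec = logspec.unflatten(0, (2, 32))   # (batch, 32, n_fft // 2 + 1, frames)
# frames = (8000 - 64) // 16 + 1 = 497, i.e. a 125 Hz frame rate.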

Output

Returns (batch, T_out, n_outputs). With n_times=8000 and defaults, T_out=373. For CTCLoss, transpose to (T_out, batch, n_outputs); use compute_output_lengths() for emission lengths.
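
The value T_out = 373 follows from the length formula documented under compute_output_lengths() below:

n_times, n_fft, hop_length = 8000, 64, 16
n_conv_blocks, kernel_width = 4, 32

frames = (n_times - n_fft) // hop_length + 1         # 497 STFT frames (125 Hz)
t_out = frames - n_conv_blocks * (kernel_width - 1)  # each block strips 31 frames
print(t_out)  # 373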

Paper training recipe

  • Loss: CTCLoss on log-softmax outputs.

  • Vocabulary: 98 keys + 1 blank (n_outputs = 99).

  • Optimizer: Adam, lr 1e-3, weight decay 0.

  • Schedule: 10-epoch linear warmup from lr 1e-8, then cosine annealing to 1e-6 over the 150-epoch run. The slow warmup is essential: without it, CTC collapses to all-blank (a trivial local minimum) within one epoch. A scheduler sketch follows this list.

  • Augmentation: per-band electrode rotations by -1/0/+1 positions, ±60-sample temporal jitter, and SpecAugment [park2019specaug] on the log-spectrogram.

  • Decoding: greedy CTC (a decoding snippet is included under Examples below). Upstream also reports a 6-gram KenLM beam-search decoder, which is not ported here.
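
A sketch of this optimizer and schedule with standard PyTorch schedulers. It assumes model is an EMG2QwertyNet instance (as in the Examples below), one scheduler step per epoch, and a 140-epoch cosine phase after the 10-epoch warmup; verify these details against the upstream emg2qwerty config before relying on the exact values:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)

# 10-epoch linear warmup from lr 1e-8 up to the base lr of 1e-3 ...
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8 / 1e-3, total_iters=10
)
# ... then cosine annealing down to 1e-6 for the remaining epochs.
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=140, eta_min=1e-6
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[10]
)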

Warning

The rotation-invariant MLP assumes circular adjacency of the electrodes within each band (the wristband geometry of the paper, electrodes_per_band=16). For arbitrary EEG montages the symmetry does not hold and this model should not be used as-is.
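
To make the circular-adjacency assumption concrete: the rolls in pipeline step 3 wrap the electrode axis around, which is only physically meaningful for a closed ring of electrodes (illustrative snippet, not the module's internals):

import torch

band = torch.arange(16)                        # electrode indices of one band
rolls = [torch.roll(band, k) for k in (-1, 0, 1)]
# torch.roll wraps around: electrode 0 becomes a neighbor of electrode 15.
# That adjacency holds for a wristband, but not for a flat EEG montage.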

Warning

License: noncommercial use only. This module is a derivative of Meta’s reference implementation and is released under CC BY-NC 4.0, the same license as the upstream repository. It is not covered by braindecode’s BSD-3 license and must not be used in commercial products or services. Pretrained weights from the upstream release carry the same restriction.

Added in version 1.5.

Parameters:
  • n_outputs (int) – Vocabulary size for CTC, including the blank class. Defaults to 99 (98 keys + 1 blank).

  • n_chans (int) – Number of EMG channels. Must equal num_bands * electrodes_per_band (default 32 = 2 * 16).

  • sfreq (float) – Sampling frequency in Hz. Defaults to 2000. n_fft and hop_length defaults are calibrated for this rate; pass matching values when changing sfreq.

  • num_bands (int) – Number of EMG bands (e.g. one per wristband). Defaults to 2.

  • electrodes_per_band (int) – Number of electrodes per band. Defaults to 16. The rotation-invariant MLP assumes circular adjacency along this axis.

  • n_fft (int) – STFT window size in samples. Defaults to 64.

  • hop_length (int) – STFT hop size in samples. Defaults to 16.

  • log_eps (float) – Floor added inside log10(power + log_eps) to keep the log finite at silent samples. Defaults to 1e-6.

  • mlp_features (Sequence[int]) – Hidden sizes of the rotation-invariant MLP. Output dim per band is mlp_features[-1].

  • rotation_offsets (Sequence[int]) – Circular electrode offsets used to enforce approximate rotation invariance. Defaults to (-1, 0, 1).

  • pooling (str) – Pool reduction across the rotation rolls. Defaults to "mean".

  • block_channels (Sequence[int]) – Channel count per TDS convolutional block. The model’s internal num_features = num_bands * mlp_features[-1] must be evenly divisible by each entry (with the defaults, num_features = 2 * 384 = 768, which is divisible by 24).

  • kernel_width (int) – Temporal kernel size of each TDS convolutional block.

  • log_softmax (bool) – If True, apply log_softmax() to the emissions. Disabled by default (braindecode models return logits).

  • activation (type[Module]) – Activation class used inside the rotation-invariant MLP and the TDS blocks. Defaults to ReLU (matches upstream emg2qwerty). Pass any non-parametrized activation class (GELU, SiLU, …) for ablations.

  • drop_prob (float) – Dropout probability applied inside each TDS feedforward block, once after the activation between the two Linear layers and again after the second Linear. Default 0.0 matches the upstream paper recipe (no dropout). Set > 0 for regularized training.

  • n_times (int | None) – Number of time samples of the input window.

  • input_window_seconds (float | None) – Length of the input window in seconds.

  • chs_info (list | None) – Information about each individual channel. This should be filled with info["chs"]. Refer to mne.Info for more details.

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
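
For instance, n_times can be left unset and inferred from input_window_seconds and sfreq:

from braindecode.models import EMG2QwertyNet

# n_times is inferred as input_window_seconds * sfreq = 4.0 * 2000 = 8000.
model = EMG2QwertyNet(n_outputs=99, n_chans=32,
                      input_window_seconds=4.0, sfreq=2000)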

Examples

Build the model and run a forward pass on a 4-second sEMG batch:

import torch
from braindecode.models import EMG2QwertyNet

model = EMG2QwertyNet(
    n_outputs=99, n_chans=32, n_times=8000, sfreq=2000,
)
x = torch.randn(2, 32, 8000)
emissions = model(x)
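# The emission sequence is shorter than the input: (batch, T_out, n_outputs).
print(emissions.shape)  # torch.Size([2, 373, 99])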

Compute a CTC loss on the emissions:

import torch.nn as nn
import torch.nn.functional as F

log_probs = F.log_softmax(emissions, dim=-1).transpose(0, 1)
input_lengths = model.compute_output_lengths(
    torch.tensor([8000, 8000])
)
targets = torch.randint(0, 98, (2, 20), dtype=torch.long)
target_lengths = torch.tensor([20, 15], dtype=torch.int32)
loss = nn.CTCLoss(blank=98, zero_infinity=True)(
    log_probs, targets, input_lengths, target_lengths,
)
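
The paper recipe decodes greedily; a minimal sketch of greedy CTC decoding for the emissions above (argmax per frame, collapse repeats, drop the blank index 98, matching the loss):

pred = emissions.argmax(dim=-1)            # (batch, T_out) best class per frame
for seq in pred:
    seq = torch.unique_consecutive(seq)    # collapse repeated frames
    keys = seq[seq != 98]                  # remove the CTC blank
    print(keys.tolist())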

References

[emg2qwerty2024]

Sivakumar, V., Seely, J., Du, A., Bittner, S. R., Berenzweig, A., Bolarinwa, A., Gramfort, A., Mandel, M. I., 2024. emg2qwerty: A Large Dataset with Baselines for Touch Typing using Surface Electromyography. Advances in Neural Information Processing Systems 37, 91373-91389.

[graves2006ctc]

Graves, A., Fernandez, S., Gomez, F., Schmidhuber, J., 2006. Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks. Proc. ICML, 369-376.

[hannun2019tds]

Hannun, A., Lee, A., Xu, Q., Collobert, R., 2019. Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions. arXiv:1904.02619.

[park2019specaug]

Park, D. S. et al., 2019. SpecAugment: a simple data augmentation method for automatic speech recognition. Proc. Interspeech, 2613-2617.

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]

Pushing a model to the Hub:

from braindecode.models import EMG2QwertyNet

# Train your model
model = EMG2QwertyNet(n_chans=32, n_outputs=99, n_times=8000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-emg2qwertynet-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import EMG2QwertyNet

# Load pretrained model
model = EMG2QwertyNet.from_pretrained("username/my-emg2qwertynet-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = EMG2QwertyNet.from_pretrained("username/my-emg2qwertynet-model", n_outputs=4)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = EMG2QwertyNet.from_config(config)    # reconstruct (no weights)

All model parameters (both signal-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

compute_output_lengths(input_lengths)[source]#

Map per-sample input lengths to CTC emission lengths.

T_out = (T - n_fft) // hop_length + 1 - n_conv_blocks * (kernel_width - 1), clamped to zero.

Parameters:

input_lengths (Tensor) – Per-sample lengths of the raw inputs in samples, shape (batch,).

Return type:

Tensor
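
For example, with the default configuration (and model as built in the Examples above):

import torch

lengths = torch.tensor([8000, 4000, 2048])
# 2048 samples is exactly the receptive-field minimum: one emission frame.
print(model.compute_output_lengths(lengths))  # tensor([373, 123, 1])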

forward(x)[source]#

Run the full pipeline.

Parameters:

x (Tensor) – Raw EMG of shape (batch, n_chans=32, n_times). n_times must be at least the encoder’s receptive field, n_fft + n_conv_blocks * (kernel_width - 1) * hop_length (with the defaults, 64 + 4 * 31 * 16 = 2048 samples, just over one second at 2 kHz).

Returns:

emissions – Shape (batch, T_out, n_outputs). Log-probabilities if log_softmax=True, otherwise logits.

Return type:

Tensor

get_output_shape()[source]#

Shape of forward output for a batch of size 1.

Uses the user-supplied n_times so this method’s reported shape is consistent with what forward() accepts. If the configured n_times is below the encoder’s receptive field, forward() would raise, and this method mirrors that behavior. It falls back to the receptive-field minimum only when n_times was not set at construction time.

Returns:

output_shape – shape of the network output for batch_size==1 (1, …)

Return type:

tuple[int, ...]
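
For the default configuration with n_times=8000 (the model from the Examples above):

print(model.get_output_shape())  # (1, 373, 99)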

reset_head(n_outputs)[source]#

Replace the classification head for a new vocabulary size.

The replacement Linear inherits the existing head’s dtype and device so a subsequent forward() does not crash after model.double() or model.to(device). The captured init config (get_config()) is also kept in sync so save/load round-trips rebuild the new head.

Parameters:

n_outputs (int) – New number of output classes.

Return type:

None

Examples

>>> from braindecode.models import EMG2QwertyNet
>>> model = EMG2QwertyNet(n_chans=32, n_times=8000, n_outputs=99)
>>> model.reset_head(10)
>>> model.n_outputs
10

Added in version 1.4.