braindecode.models.EMG2QwertyNet

class braindecode.models.EMG2QwertyNet(n_outputs=99, n_chans=32, sfreq=2000.0, num_bands=2, electrodes_per_band=16, n_fft=64, hop_length=16, log_eps=1e-06, mlp_features=(384, ), rotation_offsets=(-1, 0, 1), pooling='mean', block_channels=(24, 24, 24, 24), kernel_width=32, log_softmax=False, activation=<class 'torch.nn.modules.activation.ReLU'>, drop_prob=0.0, spec_augment=False, n_time_masks=3, time_mask_param=25, n_freq_masks=2, freq_mask_param=4, spec_augment_prob=1.0, return_feature=False, n_times=None, input_window_seconds=None, chs_info=None)

Decoder mapping surface electromyography (sEMG) to keystrokes (emg2qwerty) [emg2qwerty2024].

Convolution

Time-Depth-Separable (TDS) [hannun2019tds] convolutional encoder followed by a Connectionist Temporal Classification (CTC) head [graves2006ctc]. Takes raw 32-channel sEMG (2 wristbands × 16 electrodes) at 2 kHz and emits per-frame scores over the 99-class typing vocabulary (98 keys + 1 CTC blank).

Pipeline

  1. Log-spectrogram front-end: per-channel Short-Time Fourier Transform (stft()) with Hann window, center=False, then squared magnitude and log10(p + log_eps). With the defaults (n_fft=64, hop_length=16, sfreq=2000) the output frame rate is 125 Hz. No trainable parameters.

  2. Spectrogram BatchNorm: BatchNorm2d over the (batch, freq, time) slice for each of the num_bands × electrodes_per_band channels.

  3. Per-band rotation-invariant multi-layer perceptron (MLP): for each band, a shared MLP is applied to circular rolls of the electrode axis (rotation_offsets, by default -1, 0, +1), and the results are pooled across rolls (pooling, mean by default).

  4. TDS convolutional encoder: stack of len(block_channels) TDS conv blocks interleaved with feedforward blocks. No temporal padding, so each conv block strips kernel_width - 1 frames.

  5. Linear classification head: Linear projecting to n_outputs, optionally followed by log_softmax() (off by default; braindecode models return logits).
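Stage 3 can be illustrated with a small pure-Python sketch. It simplifies each band to one scalar per electrode (the model itself works on electrode × frequency features per frame) and uses a hypothetical one-unit toy_mlp in place of the learned MLP; roll follows the shift direction of torch.roll. This shows the roll, shared-function, and mean-pool pattern only, not the library's implementation.

```python
from statistics import fmean
from typing import Callable, List, Sequence


def roll(values: Sequence[float], shift: int) -> List[float]:
    """Circular shift of a 1-D sequence, same direction as torch.roll."""
    shift %= len(values)  # maps -1 to len-1; shift 0 returns a plain copy
    return list(values[-shift:]) + list(values[:-shift])


def rotation_pooled(
    electrodes: Sequence[float],
    shared_mlp: Callable[[List[float]], List[float]],
    offsets: Sequence[int] = (-1, 0, 1),
) -> List[float]:
    # Apply the same function to every circular roll of the electrode axis...
    outputs = [shared_mlp(roll(electrodes, k)) for k in offsets]
    # ...then average element-wise across the rolls ("mean" pooling).
    return [fmean(column) for column in zip(*outputs)]


# Hypothetical weights for a toy band with 4 electrodes.
weights = [1.0, 2.0, 3.0, 4.0]


def toy_mlp(x: List[float]) -> List[float]:
    # Stand-in for the shared MLP: one linear unit followed by ReLU.
    return [max(0.0, sum(w * v for w, v in zip(weights, x)))]


pooled = rotation_pooled([1.0, 0.0, 0.0, 0.0], toy_mlp)
print(pooled)  # ≈ [2.333], the mean of 4.0 (shift -1), 1.0 (shift 0), 2.0 (shift +1)
```

Because the same MLP sees every rolled copy, a one-position rotation of the wristband changes the pooled output far less than it would change a single unpooled MLP, which is why the invariance is only approximate.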

Output

Returns (batch, T_out, n_outputs). With n_times=8000 and defaults, T_out=373. For CTCLoss, transpose to (T_out, batch, n_outputs); use compute_output_lengths() for emission lengths. Pass return_features=True to return the pre-classifier encoder representation as a {"features": (batch, T_out, num_features), "cls_token": None} dict, matching the BIOT / signal-JEPA convention used by downstream wrappers (e.g. neuroai’s DownstreamWrapperModel).

Paper training recipe

  • Loss: CTCLoss on log-softmax outputs.

  • Vocabulary: 98 keys + 1 blank (n_outputs = 99).

  • Optimizer: Adam, lr 1e-3, weight decay 0.

  • Schedule: 10-epoch linear warmup from lr 1e-8, then cosine annealing to 1e-6 over 150 epochs. The slow warmup is required: without it, CTC collapses to all-blank outputs within one epoch (a trivial local minimum).

  • Augmentation: per-band electrode rotations by -1/0/+1 positions, ±60-sample temporal jitter, and SpecAugment [park2019specaug] on the log-spectrogram. SpecAugment is built into the model (enable it with spec_augment=True) and only fires in training mode; the electrode rotations and temporal jitter are dataset-side augmentations.

  • Decoding: greedy CTC. Upstream also reports a 6-gram KenLM beam decoder, not ported here.
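The schedule bullet can be written out as a small function. This sketch assumes the 150 epochs are the total run length, so the cosine phase covers epochs 10 to 150, and that the rate is updated once per epoch; the upstream scheduler may differ in step granularity. lr_at_epoch is a hypothetical helper name.

```python
import math


def lr_at_epoch(
    epoch: float,
    base_lr: float = 1e-3,
    warmup_start_lr: float = 1e-8,
    eta_min: float = 1e-6,
    warmup_epochs: int = 10,
    max_epochs: int = 150,
) -> float:
    """Linear warmup followed by cosine annealing (illustrative sketch)."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_start_lr up to base_lr.
        frac = epoch / warmup_epochs
        return warmup_start_lr + frac * (base_lr - warmup_start_lr)
    # Cosine decay from base_lr down to eta_min over the remaining epochs.
    progress = min(1.0, (epoch - warmup_epochs) / (max_epochs - warmup_epochs))
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


for e in (0, 5, 10, 80, 150):
    print(e, f"{lr_at_epoch(e):.3e}")
```

In PyTorch, a function like this could drive torch.optim.lr_scheduler.LambdaLR by returning lr_at_epoch(epoch) / base_lr as the multiplicative factor, with the Adam learning rate set to 1e-3.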

Warning

The rotation-invariant MLP assumes circular adjacency of the electrodes within each band (the wristband geometry of the paper, electrodes_per_band=16). For arbitrary EEG montages the symmetry does not hold and this model should not be used as-is.

Warning

License: noncommercial use only. This module is a derivative of Meta’s reference implementation released under CC BY-NC 4.0, the same license as the upstream repository. Not covered by braindecode’s BSD-3 license. Must not be used in commercial products or services. Pretrained weights from the upstream release carry the same restriction.

Added in version 1.5.

Parameters:
  • n_outputs (int) – Vocabulary size for CTC, including the blank class. Defaults to 99 (98 keys + 1 blank).

  • n_chans (int) – Number of EMG channels. Must equal num_bands * electrodes_per_band (default 32 = 2 * 16).

  • sfreq (float) – Sampling frequency in Hz. Defaults to 2000. n_fft and hop_length defaults are calibrated for this rate; pass matching values when changing sfreq.

  • num_bands (int) – Number of EMG bands (e.g. one per wristband). Defaults to 2.

  • electrodes_per_band (int) – Number of electrodes per band. Defaults to 16. The rotation-invariant MLP assumes circular adjacency along this axis.

  • n_fft (int) – STFT window size in samples. Defaults to 64.

  • hop_length (int) – STFT hop in samples. Defaults to 16.

  • log_eps (float) – Floor added inside log10(power + log_eps) to keep the log finite at silent samples. Defaults to 1e-6.

  • mlp_features (Sequence[int]) – Hidden sizes of the rotation-invariant MLP. Output dim per band is mlp_features[-1]. Defaults to (384,).

  • rotation_offsets (Sequence[int]) – Circular electrode offsets used to enforce approximate rotation invariance. Defaults to (-1, 0, 1).

  • pooling (str) – Pool reduction across the rotation rolls. Defaults to "mean".

  • block_channels (Sequence[int]) – Channel count per TDS convolutional block. The model’s internal num_features = num_bands * mlp_features[-1] must be evenly divisible by each entry. Defaults to (24, 24, 24, 24).

  • kernel_width (int) – Temporal kernel size of each TDS convolutional block. Defaults to 32.

  • log_softmax (bool) – If True, apply log_softmax() to the emissions. Disabled by default (braindecode models return logits).

  • activation (type[Module]) – Activation class used inside the rotation-invariant MLP and the TDS blocks. Defaults to ReLU (matches upstream emg2qwerty). Pass any non-parametrized activation class (GELU, SiLU, …) for ablations.

  • drop_prob (float) – Dropout probability applied inside each TDS feedforward block, once after the activation between the two Linear layers and again after the second Linear. Default 0.0 matches the upstream paper recipe (no dropout). Set > 0 for regularized training.

  • spec_augment (bool) – If True, apply SpecAugment [park2019specaug] time/frequency masking on the log-spectrogram during training only. Disabled in eval mode and absent from the parameter / state-dict count. Defaults to False; set to True to match the upstream emg2qwerty paper recipe.

  • n_time_masks (int) – Maximum number of time masks applied per call. Each forward pass samples a uniform integer in [0, n_time_masks]. Defaults to 3 (Sivakumar et al. Sec 5.2).

  • time_mask_param (int) – Maximum time-mask width in spectrogram frames. Defaults to 25.

  • n_freq_masks (int) – Maximum number of frequency masks applied per call. Each forward pass samples a uniform integer in [0, n_freq_masks]. Defaults to 2.

  • freq_mask_param (int) – Maximum frequency-mask width in STFT bins. Defaults to 4.

  • spec_augment_prob (float) – Probability of running SpecAugment on a given training batch (Bernoulli gate before sampling mask counts). Defaults to 1.0.

  • return_feature (bool) – If True, forward returns a tuple (emissions, features) instead of just emissions, following the braindecode.models.BIOT-style legacy feature path. This lets configuration-driven downstream wrappers (e.g. neuroai’s DownstreamWrapperModel with model_output_key=1) pick up the encoder representation without passing a runtime kwarg. Defaults to False. Compatible with the runtime return_features (plural) flag, which takes precedence when set to True.

  • n_times (int | None) – Number of time samples of the input window.

  • input_window_seconds (float | None) – Length of the input window in seconds.

  • chs_info (list | None) – Information about each individual EEG channel. This should be filled with info["chs"]. Refer to mne.Info for more details.
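The SpecAugment parameters above can be read as the following pure-Python sketch, applied to one channel’s log-spectrogram stored as a [freq_bins][frames] list. It assumes masked cells are set to 0.0 and that mask widths are drawn uniformly from [0, time_mask_param] or [0, freq_mask_param], in the style of torchaudio’s TimeMasking and FrequencyMasking; the model’s tensor implementation may differ in these details. spec_augment here is a hypothetical function name.

```python
import random
from typing import List, Optional


def spec_augment(
    spec: List[List[float]],  # [freq_bins][frames] for one channel
    n_time_masks: int = 3,
    time_mask_param: int = 25,
    n_freq_masks: int = 2,
    freq_mask_param: int = 4,
    prob: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    rng = rng or random.Random()
    out = [row[:] for row in spec]  # copy so the input is never mutated
    if rng.random() >= prob:  # Bernoulli gate (spec_augment_prob)
        return out
    n_freq, n_frames = len(out), len(out[0])
    # Time masks: zero a run of frames across all frequency bins.
    for _ in range(rng.randint(0, n_time_masks)):
        width = rng.randint(0, min(time_mask_param, n_frames))
        start = rng.randint(0, n_frames - width)
        for row in out:
            for t in range(start, start + width):
                row[t] = 0.0
    # Frequency masks: zero a run of bins across all frames.
    for _ in range(rng.randint(0, n_freq_masks)):
        width = rng.randint(0, min(freq_mask_param, n_freq))
        start = rng.randint(0, n_freq - width)
        for f in range(start, start + width):
            out[f] = [0.0] * n_frames
    return out


# 33 one-sided STFT bins for n_fft=64; 50 frames of constant log-power.
spec = [[1.0] * 50 for _ in range(33)]
masked = spec_augment(spec, rng=random.Random(0))
```

As in the model, a mask count of 0 is a valid draw, so some training batches pass through unmasked even when spec_augment_prob is 1.0.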

Raises:

ValueError – If some input signal-related parameters are not specified and cannot be inferred.

Notes

If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.

Examples

Build the model and run a forward pass on a 4-second sEMG batch:

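import torch
from braindecode.models import EMG2QwertyNet

# 32 channels = 2 wristbands x 16 electrodes; 8000 samples = 4 s at 2 kHz
model = EMG2QwertyNet(
    n_outputs=99, n_chans=32, n_times=8000, sfreq=2000,
)
x = torch.randn(2, 32, 8000)  # (batch, n_chans, n_times)
emissions = model(x)  # logits of shape (2, 373, 99) = (batch, T_out, n_outputs)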

Compute a CTC loss on the emissions:

import torch.nn as nn
import torch.nn.functional as F

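# CTCLoss expects log-probabilities shaped (T_out, batch, n_outputs)
log_probs = F.log_softmax(emissions, dim=-1).transpose(0, 1)
# Emission lengths per sample: 373 for each 8000-sample input
input_lengths = model.compute_output_lengths(
    torch.tensor([8000, 8000])
)
# Padded targets in [0, 97]; index 98 is reserved for the CTC blank
targets = torch.randint(0, 98, (2, 20), dtype=torch.long)
# Only the first 15 labels of the second row are used
target_lengths = torch.tensor([20, 15], dtype=torch.int32)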
loss = nn.CTCLoss(blank=98, zero_infinity=True)(
    log_probs, targets, input_lengths, target_lengths,
)
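Greedy (best-path) CTC decoding, the decoder listed in the training recipe, turns per-frame scores into a label sequence: take the arg-max class per frame, collapse consecutive repeats, then drop blanks. The sketch below works on plain nested lists; greedy_ctc_decode is a hypothetical helper, not a braindecode function.

```python
from typing import List, Sequence


def greedy_ctc_decode(emissions: Sequence[Sequence[float]], blank: int = 98) -> List[int]:
    """Best-path CTC decoding of one (T_out, n_outputs) score sequence."""
    # 1) Most likely class per frame.
    best = [max(range(len(frame)), key=frame.__getitem__) for frame in emissions]
    decoded: List[int] = []
    prev = None
    for label in best:
        # 2) Collapse consecutive repeats, 3) drop blanks.
        #    A blank between two equal labels keeps both of them.
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded


# Toy vocabulary of 3 classes with blank=2; frame arg-maxes are 0,0,2,1,1,2,0.
frames = [
    [0.9, 0.05, 0.05], [0.8, 0.1, 0.1], [0.1, 0.1, 0.8],
    [0.1, 0.7, 0.2], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8], [0.7, 0.2, 0.1],
]
print(greedy_ctc_decode(frames, blank=2))  # [0, 1, 0]
```

With the tensors from the example above, greedy_ctc_decode(emissions[0].tolist(), blank=98) would decode the first sequence. Logits and log-probabilities give the same result, since log_softmax preserves the per-frame arg-max. Mapping indices back to keys requires the dataset’s own charset.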

References

[emg2qwerty2024]

Sivakumar, V., Seely, J., Du, A., Bittner, S. R., Berenzweig, A., Bolarinwa, A., Gramfort, A., Mandel, M. I., 2024. emg2qwerty: A Large Dataset with Baselines for Touch Typing using Surface Electromyography. Advances in Neural Information Processing Systems 37, 91373-91389.

[graves2006ctc]

Graves, A., Fernandez, S., Gomez, F., Schmidhuber, J., 2006. Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks. Proc. ICML, 369-376.

[hannun2019tds]

Hannun, A., Lee, A., Xu, Q., Collobert, R., 2019. Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions. arXiv:1904.02619.

[park2019specaug]

Park, D. S. et al., 2019. SpecAugment: a simple data augmentation method for automatic speech recognition. Proc. Interspeech, 2613-2617.

Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

pip install braindecode[hub]

Pushing a model to the Hub:

from braindecode.models import EMG2QwertyNet

# Train your model
model = EMG2QwertyNet(n_chans=32, n_outputs=99, n_times=8000, sfreq=2000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-emg2qwertynet-model",
    commit_message="Initial model upload",
)

Loading a model from the Hub:

from braindecode.models import EMG2QwertyNet

# Load pretrained model
model = EMG2QwertyNet.from_pretrained("username/my-emg2qwertynet-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = EMG2QwertyNet.from_pretrained("username/my-emg2qwertynet-model", n_outputs=4)

Extracting features and replacing the head:

import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)

Saving and restoring full configuration:

import json

config = model.get_config()            # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = EMG2QwertyNet.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See Loading and Adapting Pretrained Foundation Models for a complete tutorial.

Methods

compute_output_lengths(input_lengths)

Map per-sample input lengths to CTC emission lengths.

T_out = (T - n_fft) // hop_length + 1 - n_conv_blocks * (kernel_width - 1), clamped to zero.

Parameters:

input_lengths (Tensor) – Per-sample input lengths T in raw samples, one entry per batch element (e.g. torch.tensor([8000, 8000])).

Return type:

Tensor
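The formula above, together with the receptive-field bound on n_times quoted for forward(), can be checked with plain integer arithmetic. This sketch mirrors the documented formula with the default hyperparameters (n_conv_blocks = len(block_channels) = 4); it is not the method’s tensor implementation, and ctc_output_length and min_n_times are hypothetical helper names.

```python
def ctc_output_length(
    n_times: int,
    n_fft: int = 64,
    hop_length: int = 16,
    kernel_width: int = 32,
    n_conv_blocks: int = 4,
) -> int:
    """T_out = (T - n_fft) // hop_length + 1 - n_conv_blocks * (kernel_width - 1), clamped at 0."""
    frames = (n_times - n_fft) // hop_length + 1  # STFT frames with center=False
    # Each unpadded TDS conv block strips kernel_width - 1 frames.
    return max(0, frames - n_conv_blocks * (kernel_width - 1))


def min_n_times(
    n_fft: int = 64, hop_length: int = 16, kernel_width: int = 32, n_conv_blocks: int = 4
) -> int:
    """Smallest input length that yields at least one emission frame."""
    return n_fft + n_conv_blocks * (kernel_width - 1) * hop_length


print([ctc_output_length(t) for t in (8000, 2048, 2047, 1000)])  # [373, 1, 0, 0]
print(min_n_times())  # 2048
```

At the defaults, anything shorter than 2048 samples (about 1 s at 2 kHz) produces no emissions, so CTC targets can only be aligned on longer windows.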

forward(x, return_features=False)

Run the full pipeline.

Parameters:
  • x (Tensor) – Raw EMG of shape (batch, n_chans=32, n_times). n_times must be at least the encoder’s receptive field, n_fft + n_conv_blocks * (kernel_width - 1) * hop_length.

  • return_features (bool) – If True, return a dict with the encoder representation instead of the classification emissions. The encoder is the full TDS-Conv stack up to (but not including) self.final_layer, i.e. what downstream wrappers want when they apply their own probe or aggregation. Matches the BIOT / signal-JEPA convention, so the same neuroai DownstreamWrapperModel(model_output_key="features") can consume it. Takes precedence over the constructor-time return_feature flag when set.

Returns:

Default (return_features=False, init return_feature=False): torch.Tensor of shape (batch, T_out, n_outputs). Log-probabilities if log_softmax=True, otherwise logits.

If runtime return_features=True: dict with "features" (shape (batch, T_out, num_features), where num_features = num_bands * mlp_features[-1]) and "cls_token" (always None — TDS-Conv has no [CLS]).

If init return_feature=True and runtime return_features=False: tuple (emissions, features) where features has shape (batch, T_out, num_features). Same layout BIOT exposes for configuration-driven feature extraction (e.g. neuroai’s model_output_key=1).

Return type:

Tensor | dict[str, Tensor | None] | tuple[Tensor, Tensor]

get_output_shape()

Shape of forward output for a batch of size 1.

Uses the user-supplied n_times so the reported shape is consistent with what forward() accepts. If the configured n_times is below the encoder’s receptive field, forward() would raise, and this method mirrors that behavior. Falls back to the receptive-field minimum only when n_times was not set at construction time.

Returns:

output_shape – shape of the network output for batch_size==1 (1, …)

Return type:

tuple[int, ...]

reset_head(n_outputs)

Replace the classification head for a new vocabulary size.

The replacement Linear inherits the existing head’s dtype and device so a subsequent forward() does not crash after model.double() or model.to(device). The captured init config (get_config()) is also kept in sync so save/load round-trips rebuild the new head.

Parameters:

n_outputs (int) – New number of output classes.

Return type:

None

Examples

>>> from braindecode.models import BENDR
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
>>> model.reset_head(10)
>>> model.n_outputs
10

Added in version 1.4.