braindecode.models.EMG2QwertyNet
- class braindecode.models.EMG2QwertyNet(n_outputs=99, n_chans=32, sfreq=2000.0, num_bands=2, electrodes_per_band=16, n_fft=64, hop_length=16, log_eps=1e-06, mlp_features=(384,), rotation_offsets=(-1, 0, 1), pooling='mean', block_channels=(24, 24, 24, 24), kernel_width=32, log_softmax=False, activation=<class 'torch.nn.modules.activation.ReLU'>, drop_prob=0.0, n_times=None, input_window_seconds=None, chs_info=None)
Decoder mapping surface electromyography (sEMG) to keystrokes (emg2qwerty) [emg2qwerty2024].
Convolution
Time-Depth-Separable (TDS) [hannun2019tds] convolutional encoder followed by a Connectionist Temporal Classification (CTC) head [graves2006ctc]. Takes raw 32-channel sEMG (2 wristbands × 16 electrodes) at 2 kHz and emits per-frame scores over the 99-class typing vocabulary (98 keys + 1 CTC blank).
Pipeline
1. Log-spectrogram front-end: per-channel Short-Time Fourier Transform (`stft()`) with Hann window, `center=False`, then squared magnitude and `log10(p + log_eps)`. With the defaults (`n_fft=64`, `hop_length=16`, `sfreq=2000`) the output frame rate is 125 Hz (see the sketch after this list). No trainable parameters.
2. Spectrogram BatchNorm: `BatchNorm2d` over the `(batch, freq, time)` slice for each of the `num_bands × electrodes_per_band` channels.
3. Per-band rotation-invariant multi-layer perceptron (MLP): for each band, a shared MLP is applied to circular rolls `(-1, 0, +1)` of the electrode axis, then mean-pooled.
4. TDS convolutional encoder: stack of `len(block_channels)` TDS conv blocks interleaved with feedforward blocks. No temporal padding, so each conv block strips `kernel_width - 1` frames.
5. Linear classification head: `Linear` projecting to `n_outputs`, optionally followed by `log_softmax()` (off by default; braindecode models return logits).
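A minimal sketch of the front-end arithmetic with the default settings; this mirrors the description above but is not the module's internal code:

```python
import torch

# Per-channel log-spectrogram with the default front-end settings.
# torch.stft expects (batch, time), so channels are folded into the batch.
x = torch.randn(2, 32, 8000)                      # (batch, n_chans, n_times)
n_fft, hop_length, log_eps = 64, 16, 1e-6
spec = torch.stft(
    x.reshape(-1, 8000),                          # (batch * n_chans, n_times)
    n_fft=n_fft,
    hop_length=hop_length,
    window=torch.hann_window(n_fft),
    center=False,
    return_complex=True,
)
logspec = torch.log10(spec.abs() ** 2 + log_eps)  # squared magnitude -> log10
print(logspec.shape)  # (64, 33, 497): n_fft//2 + 1 bins, 497 frames
# 497 frames over 4 s of 2 kHz signal, i.e. the stated 125 Hz frame rate
```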
Output

Returns `(batch, T_out, n_outputs)`. With `n_times=8000` and defaults, `T_out = 373`. For `CTCLoss`, transpose to `(T_out, batch, n_outputs)`; use `compute_output_lengths()` for emission lengths.
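Worked through with the defaults (the same arithmetic as `compute_output_lengths()` below):

```python
# (n_times - n_fft) // hop_length + 1 STFT frames, then each of the four
# TDS conv blocks strips kernel_width - 1 frames:
t_stft = (8000 - 64) // 16 + 1   # 497
t_out = t_stft - 4 * (32 - 1)    # 497 - 124 = 373
```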
Paper training recipe

- Loss: `CTCLoss` on log-softmax outputs.
- Vocabulary: 98 keys + 1 blank (`n_outputs = 99`).
- Optimizer: `Adam`, lr 1e-3, weight decay 0.
- Schedule: 10-epoch linear warmup from lr 1e-8, then cosine annealing to 1e-6 over 150 epochs. The slow warmup is required; without it, CTC collapses to all-blank within one epoch (a trivial local minimum). A sketch of this schedule follows the list.
- Augmentation: per-band electrode rotations by -1/0/+1 positions, ±60-sample temporal jitter, and SpecAugment [park2019specaug] on the log-spectrogram.
- Decoding: greedy CTC. Upstream also reports a 6-gram KenLM beam decoder, not ported here.
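A minimal sketch of this recipe with stock PyTorch schedulers; the 10 + 140 epoch split and the per-epoch `scheduler.step()` are illustrative assumptions, not the upstream training code:

```python
import torch

from braindecode.models import EMG2QwertyNet

model = EMG2QwertyNet(n_outputs=99, n_chans=32, n_times=8000, sfreq=2000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)

# 10-epoch linear warmup from 1e-8 to 1e-3, then cosine annealing to 1e-6
# over the remaining 140 of the 150 epochs (the split is an assumption).
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8 / 1e-3, total_iters=10
)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=140, eta_min=1e-6
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[10]
)

for epoch in range(150):
    # ... one epoch of CTC training ...
    scheduler.step()
```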
Warning

The rotation-invariant MLP assumes circular adjacency of the electrodes within each band (the wristband geometry of the paper, `electrodes_per_band=16`). For arbitrary EEG montages the symmetry does not hold and this model should not be used as-is.

Warning
License: noncommercial use only. This module is a derivative of Meta's reference implementation released under CC BY-NC 4.0, the same license as the upstream repository. It is not covered by braindecode's BSD-3 license and must not be used in commercial products or services. Pretrained weights from the upstream release carry the same restriction.
Added in version 1.5.
- Parameters:
  - n_outputs (int) – Vocabulary size for CTC, including the blank class. Defaults to `99` (98 keys + 1 blank).
  - n_chans (int) – Number of EMG channels. Must equal `num_bands * electrodes_per_band` (default `32 = 2 * 16`).
  - sfreq (float) – Sampling frequency in Hz. Defaults to `2000`. The `n_fft` and `hop_length` defaults are calibrated for this rate; pass matching values when changing `sfreq`.
  - num_bands (int) – Number of EMG bands (e.g. one per wristband). Defaults to `2`.
  - electrodes_per_band (int) – Number of electrodes per band. Defaults to `16`. The rotation-invariant MLP assumes circular adjacency along this axis.
  - n_fft (int) – STFT window size in samples.
  - hop_length (int) – STFT hop in samples.
  - log_eps (float) – Floor added inside `log10(power + log_eps)` to keep the log finite at silent samples. Defaults to `1e-6`.
  - mlp_features (Sequence[int]) – Hidden sizes of the rotation-invariant MLP. Output dim per band is `mlp_features[-1]`.
  - rotation_offsets (Sequence[int]) – Circular electrode offsets used to enforce approximate rotation invariance. Defaults to `(-1, 0, 1)`.
  - pooling (str) – Pool reduction across the rotation rolls. Defaults to `"mean"`.
  - block_channels (Sequence[int]) – Channel count per TDS convolutional block. The model's internal `num_features = num_bands * mlp_features[-1]` must be evenly divisible by each entry (see the example after this list).
  - kernel_width (int) – Temporal kernel size of each TDS convolutional block.
  - log_softmax (bool) – If `True`, apply `log_softmax()` to the emissions. Disabled by default (braindecode models return logits).
  - activation (type[Module]) – Activation class used inside the rotation-invariant MLP and the TDS blocks. Defaults to `ReLU` (matches upstream emg2qwerty). Pass any non-parametrized activation class (`GELU`, `SiLU`, …) for ablations.
  - drop_prob (float) – Dropout probability applied inside each TDS feedforward block, once after the activation between the two `Linear` layers and again after the second `Linear`. Default `0.0` matches the upstream paper recipe (no dropout). Set `> 0` for regularized training.
  - n_times (int | None) – Number of time samples of the input window.
  - input_window_seconds (float | None) – Length of the input window in seconds.
  - chs_info (list | None) – Information about each individual EEG channel. This should be filled with `info["chs"]`. Refer to `mne.Info` for more details.
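A hypothetical non-default configuration illustrating the `block_channels` divisibility constraint; the values below are made up for the example:

```python
from braindecode.models import EMG2QwertyNet

# num_features = num_bands * mlp_features[-1] = 2 * 256 = 512, which is
# divisible by every block_channels entry (512 / 32 = 16).
model = EMG2QwertyNet(
    n_outputs=99,
    n_chans=32,
    n_times=8000,
    sfreq=2000,
    mlp_features=(256,),
    block_channels=(32, 32, 32, 32),
)
```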
- Raises:
  - ValueError – If some input signal-related parameters are not specified and cannot be inferred.
Notes
If some input signal-related parameters are not specified, there will be an attempt to infer them from the other parameters.
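For instance, `n_times` can be left unset when `input_window_seconds` and `sfreq` are given; a sketch, assuming the standard braindecode relation `n_times = input_window_seconds * sfreq`:

```python
from braindecode.models import EMG2QwertyNet

# n_times is inferred as input_window_seconds * sfreq = 4.0 * 2000 = 8000
model = EMG2QwertyNet(
    n_outputs=99,
    n_chans=32,
    input_window_seconds=4.0,
    sfreq=2000,
)
```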
Examples
Build the model and run a forward pass on a 4-second sEMG batch:
```python
import torch

from braindecode.models import EMG2QwertyNet

model = EMG2QwertyNet(
    n_outputs=99,
    n_chans=32,
    n_times=8000,
    sfreq=2000,
)
x = torch.randn(2, 32, 8000)
emissions = model(x)
```
Compute a CTC loss on the emissions:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# CTCLoss expects (T_out, batch, n_outputs) log-probabilities
log_probs = F.log_softmax(emissions, dim=-1).transpose(0, 1)
input_lengths = model.compute_output_lengths(torch.tensor([8000, 8000]))
targets = torch.randint(0, 98, (2, 20), dtype=torch.long)
target_lengths = torch.tensor([20, 15], dtype=torch.int32)
loss = nn.CTCLoss(blank=98, zero_infinity=True)(
    log_probs, targets, input_lengths, target_lengths
)
```
References
[emg2qwerty2024]Sivakumar, V., Seely, J., Du, A., Bittner, S. R., Berenzweig, A., Bolarinwa, A., Gramfort, A., Mandel, M. I., 2024. emg2qwerty: A Large Dataset with Baselines for Touch Typing using Surface Electromyography. Advances in Neural Information Processing Systems 37, 91373-91389.
[graves2006ctc]Graves, A., Fernandez, S., Gomez, F., Schmidhuber, J., 2006. Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks. Proc. ICML, 369-376.
[hannun2019tds]Hannun, A., Lee, A., Xu, Q., Collobert, R., 2019. Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions. arXiv:1904.02619.
[park2019specaug]Park, D. S. et al., 2019. SpecAugment: a simple data augmentation method for automatic speech recognition. Proc. Interspeech, 2613-2617.
Hugging Face Hub integration
When the optional `huggingface_hub` package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with:

```bash
pip install braindecode[hub]
```
Pushing a model to the Hub:
```python
from braindecode.models import EMG2QwertyNet

# Train your model (n_chans must equal num_bands * electrodes_per_band = 32)
model = EMG2QwertyNet(n_chans=32, n_outputs=99, n_times=8000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-emg2qwertynet-model",
    commit_message="Initial model upload",
)
```
Loading a model from the Hub:
```python
from braindecode.models import EMG2QwertyNet

# Load pretrained model
model = EMG2QwertyNet.from_pretrained("username/my-emg2qwertynet-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = EMG2QwertyNet.from_pretrained(
    "username/my-emg2qwertynet-model", n_outputs=4
)
```
Extracting features and replacing the head:
```python
import torch

x = torch.randn(1, model.n_chans, model.n_times)

# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)
```
Saving and restoring full configuration:
```python
import json

config = model.get_config()  # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = EMG2QwertyNet.from_config(config)  # reconstruct (no weights)
```
All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
See Loading and Adapting Pretrained Foundation Models for a complete tutorial.
Methods
- compute_output_lengths(input_lengths)
Map per-sample input lengths to CTC emission lengths.
`T_out = (T - n_fft) // hop_length + 1 - n_conv_blocks * (kernel_width - 1)`, clamped to a minimum of zero.

- Parameters:
  - input_lengths (Tensor) – Per-sample input lengths in samples.
- Return type:
Tensor
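For instance, with the defaults and the model built in the Examples section above:

```python
import torch

# With the defaults: (8000 - 64) // 16 + 1 - 4 * (32 - 1) = 373, and
# 2048 samples is the receptive-field minimum that still yields one frame.
lengths = model.compute_output_lengths(torch.tensor([8000, 2048]))
print(lengths)  # tensor([373,   1])
```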
- forward(x)
Run the full pipeline.
- Parameters:
  - x (Tensor) – Raw EMG of shape `(batch, n_chans=32, n_times)`. `n_times` must be at least the encoder's receptive field, `n_fft + n_conv_blocks * (kernel_width - 1) * hop_length`.
- Returns:
  - emissions – Shape `(batch, T_out, n_outputs)`. Log-probabilities if `log_softmax=True`, otherwise logits.
- Return type:
Tensor
- get_output_shape()
Shape of `forward` output for a batch of size 1.

Uses the user-supplied `n_times` so this method's reported shape is consistent with what `forward()` accepts. If the configured `n_times` is below the encoder's receptive field, `forward()` would raise; we mirror that here. Falls back to the receptive-field minimum only when `n_times` was not set at construction time.
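A quick check, assuming the model from the Examples section above (`n_times=8000`, default vocabulary):

```python
print(model.get_output_shape())  # (1, 373, 99): batch 1, T_out frames, vocab
```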
- reset_head(n_outputs)
Replace the classification head for a new vocabulary size.
The replacement `Linear` inherits the existing head's dtype and device so a subsequent `forward()` does not crash after `model.double()` or `model.to(device)`. The captured init config (`get_config()`) is also kept in sync so save/load round-trips rebuild the new head.

Examples
```python
>>> from braindecode.models import BENDR
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
>>> model.reset_head(10)
>>> model.n_outputs
10
```
Added in version 1.4.