braindecode.modules.PatchTokenizer
- class braindecode.modules.PatchTokenizer(patch_size, n_times, emb_dim=None, learnable=False)
Tokenize an EEG signal into non-overlapping temporal patches.
Transforms (batch, n_chans, n_times) into (batch, n_chans, n_patches, patch_dim) by splitting the time axis into non-overlapping patches of patch_size samples. This is the shared patching (“tokenization”) step used by transformer EEG foundation models (e.g. LaBraM, CBraMod, EEG-DINO).
As in the filter-bank models (FBCNet, FBMSNet), when n_times is not a multiple of patch_size the input is right zero-padded (a warning is emitted at construction); it is never an error. The last patch then contains trailing zeros, and n_patches = ceil(n_times / patch_size).
Two modes:
- non-learnable (learnable=False, default): a pure reshape, so patch_dim == patch_size and the raw samples of each patch are kept (the patch embedding, if any, lives in the model).
- learnable (learnable=True): a strided Conv1d (kernel and stride equal to patch_size, applied per channel) maps each patch to emb_dim features, so patch_dim == emb_dim.
- Parameters:
patch_size (int) – Number of time samples per patch.
n_times (int) – Number of time samples of the input, used to set up the right-padding when n_times is not a multiple of patch_size.
emb_dim (int, optional) – Output features per patch in learnable mode. Defaults to patch_size. Ignored when learnable=False.
learnable (bool, default=False) – Whether the tokenizer is a learned convolution or a fixed reshape.
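The learnable mode and the emb_dim default can be illustrated with a hypothetical module, ConvPatchTokenizerSketch. It assumes that "applied per channel" means a single Conv1d with one input channel whose weights are shared across EEG channels, implemented by folding the channel axis into the batch. braindecode may organize this differently.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class ConvPatchTokenizerSketch(nn.Module):
    """Hypothetical sketch of the learnable mode (not braindecode's code)."""

    def __init__(self, patch_size: int, n_times: int, emb_dim: Optional[int] = None):
        super().__init__()
        # Right zero-padding needed so the time axis is a multiple of patch_size.
        self.pad = (-n_times) % patch_size
        # emb_dim falls back to patch_size, mirroring the documented default.
        self.emb_dim = patch_size if emb_dim is None else emb_dim
        # Kernel == stride == patch_size: each output step sees exactly one patch.
        self.proj = nn.Conv1d(1, self.emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_chans, _ = x.shape
        if self.pad:
            x = F.pad(x, (0, self.pad))
        # Fold channels into the batch so the same conv runs on each channel separately.
        z = self.proj(x.reshape(batch * n_chans, 1, -1))  # (batch * n_chans, emb_dim, n_patches)
        # Restore the channel axis and put features last: (batch, n_chans, n_patches, emb_dim).
        return z.reshape(batch, n_chans, self.emb_dim, -1).transpose(2, 3)


tok = ConvPatchTokenizerSketch(patch_size=200, n_times=1000, emb_dim=64)
print(tok(torch.randn(2, 19, 1000)).shape)  # torch.Size([2, 19, 5, 64])
```

In this sketch, folding channels into the batch keeps the number of parameters independent of n_chans, and no information is mixed across channels at the tokenization stage.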
Examples
>>> import torch
>>> from braindecode.modules import PatchTokenizer
>>> tokenizer = PatchTokenizer(patch_size=200, n_times=1000)
>>> tokenizer(torch.randn(2, 19, 1000)).shape
torch.Size([2, 19, 5, 200])
Methods
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling forward() directly, since calling the instance runs the registered hooks while calling forward() silently ignores them.
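The difference between calling the instance and calling forward() can be checked with a forward hook. The sketch uses a hypothetical stand-in module, ReshapeTokenizer (which assumes n_times is already a multiple of patch_size), rather than PatchTokenizer itself; the hook behavior comes from torch.nn.Module.register_forward_hook and applies to any Module.

```python
import torch
from torch import nn


class ReshapeTokenizer(nn.Module):
    """Hypothetical stand-in: non-learnable patching, assuming no padding is needed."""

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, t = x.shape
        return x.reshape(b, c, t // self.patch_size, self.patch_size)


calls = []
tok = ReshapeTokenizer(patch_size=200)
# A forward hook is called as hook(module, inputs, output) after forward() returns.
handle = tok.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(2, 19, 1000)
tok(x)          # calling the instance runs forward() and then the hook
tok.forward(x)  # calling forward() directly skips the hook
print(len(calls))  # 1
handle.remove()
```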