braindecode.models.SignalJEPA#
- class braindecode.models.SignalJEPA(n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, *, feature_encoder__conv_layers_spec=((8, 32, 8), (16, 2, 2), (32, 2, 2), (64, 2, 2), (64, 2, 2)), drop_prob=0.0, feature_encoder__mode='default', feature_encoder__conv_bias=False, activation=<class 'torch.nn.modules.activation.GELU'>, pos_encoder__spat_dim=30, pos_encoder__time_dim=34, pos_encoder__sfreq_features=1.0, pos_encoder__spat_kwargs=None, transformer__d_model=64, transformer__num_encoder_layers=8, transformer__num_decoder_layers=4, transformer__nhead=8)[source]#
Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel et al. (2024) [1].
Convolution Channel Foundation Model
This model is not meant for classification but for SSL pre-training. Its output shape depends on the input shape. For classification purposes, three variants of this model are available:
SignalJEPA_Contextual
SignalJEPA_PostLocal
SignalJEPA_PreLocal
The classification architectures can either be instantiated from scratch (random parameters) or from a pre-trained SignalJEPA model.
Added in version 0.9.
References
[1]Guetschel, P., Moreau, T., & Tangermann, M. (2024). S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention. In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
Methods
- forward(X, ch_idxs=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.