braindecode package#
- class braindecode.EEGClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.CrossEntropyLoss'>, cropped=False, callbacks=None, iterator_train__shuffle=True, iterator_train__drop_last=True, aggregate_predictions=True, **kwargs)[source]#
Bases: `_EEGNeuralNet`, `NeuralNetClassifier`
Classifier that does not assume softmax activation. Calls loss function directly without applying log or anything.
- Parameters:
  - module (str or torch Module (class or instance)) – A PyTorch `Module`. In general, the uninstantiated class should be passed, although instantiated modules will also work.
  - criterion (torch criterion (class, default=torch.nn.CrossEntropyLoss)) – The loss criterion. Note that the signature above overrides the skorch default (torch.nn.NLLLoss, a negative log likelihood loss for which the module should return probabilities, the log being applied during `get_loss`).
  - classes (None or list (default=None)) – If None, the `classes_` attribute will be inferred from the y data passed to `fit`. If a non-empty list is passed, that list will be returned as `classes_`. If the initial skorch behavior should be restored, i.e. raising an `AttributeError`, pass an empty list.
  - optimizer (torch optim (class, default=torch.optim.SGD)) – The uninitialized optimizer (update rule) used to optimize the module.
  - lr (float (default=0.01)) – Learning rate passed to the optimizer. You may use `lr` instead of `optimizer__lr`, which would result in the same outcome.
  - max_epochs (int (default=10)) – The number of epochs to train for each `fit` call. Note that you may keyboard-interrupt training at any time.
  - batch_size (int (default=128)) – Mini-batch size. Use this instead of setting `iterator_train__batch_size` and `iterator_test__batch_size`, which would result in the same outcome. If `batch_size` is -1, a single batch with all the data will be used during training and validation.
  - iterator_train (torch DataLoader) – The default PyTorch `DataLoader` used for training data.
  - iterator_valid (torch DataLoader) – The default PyTorch `DataLoader` used for validation and test data, i.e. during inference.
  - dataset (torch Dataset (default=skorch.dataset.Dataset)) – The dataset is necessary for the incoming data to work with PyTorch's `DataLoader`. It has to implement the `__len__` and `__getitem__` methods. The provided dataset should be capable of dealing with many data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized `Dataset` class and define additional arguments to X and y by prefixing them with `dataset__`. It is also possible to pass an initialized `Dataset`, in which case no additional arguments may be passed.
  - train_split (None or callable (default=skorch.dataset.ValidSplit(5))) – If `None`, there is no train/validation split. Else, `train_split` should be a function or callable that is called with X and y data and should return the tuple `dataset_train, dataset_valid`. The validation data may be `None`.
  - callbacks (None, "disable", or list of Callback instances (default=None)) – If `callbacks=None`, only use default callbacks, those returned by `get_default_callbacks`. If `callbacks="disable"`, disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks. If `callbacks` is a list of callbacks, use those callbacks in addition to the default callbacks; each callback should be an instance of `Callback`. Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. `EpochScoring_1`. Alternatively, a tuple `(name, callback)` can be passed, where `name` should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name `'print_log'`, use `net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])`).
  - predict_nonlinearity (callable, None, or 'auto' (default='auto')) – The nonlinearity to be applied to the prediction. When set to 'auto', infers the correct nonlinearity based on the criterion (softmax for `CrossEntropyLoss` and sigmoid for `BCEWithLogitsLoss`). If it cannot be inferred or if the parameter is None, just use the identity function. Don't pass a lambda function if you want the net to be pickleable. In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor. This can be useful, e.g., when `predict_proba()` should return probabilities but a criterion is used that does not expect probabilities; in that case, the module can return whatever is required by the criterion and `predict_nonlinearity` transforms this output into probabilities. The nonlinearity is applied only when calling `predict()` or `predict_proba()` but not anywhere else – notably, the loss is unaffected by this nonlinearity.
  - warm_start (bool (default=False)) – Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
  - verbose (int (default=1)) – Controls how much print output is generated by the net and its callbacks. Setting this value to 0 suppresses, e.g., the summary scores at the end of each epoch, which can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
  - device (str, torch.device, or None (default='cpu')) – The compute device to be used. If set to 'cuda' in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, all compute devices will be left unmodified.
  - compile (bool (default=False)) – If set to `True`, compile all modules using `torch.compile`. For this to work, the installed torch version has to support `torch.compile`. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere, for instance). Additional arguments for `torch.compile` can be passed using the dunder notation, e.g. initializing the net with `compile__dynamic=True` makes `torch.compile` be called with `dynamic=True`.
  - use_caching (bool or 'auto' (default='auto')) – Optionally override the caching behavior of scoring callbacks. Callbacks such as `EpochScoring` and `BatchScoring` allow caching the inference call to save time when calculating scores during training, at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows overriding their behavior globally. By default (`'auto'`), the callbacks determine whether caching is used. If this argument is set to `False`, caching is disabled on all callbacks; if set to `True`, caching is enabled on all callbacks. Implementation note: it is the job of the callbacks to honor this setting.
  - torch_load_kwargs (dict or None (default=None)) – Additional arguments passed to torch.load when loading pickled parameters. In particular, this is important because PyTorch will switch (probably in version 2.6.0) to only allowing weights to be loaded, for security reasons (i.e. weights_only switches from False to True). As a consequence, loading pickled parameters may raise an error after upgrading torch because some types that are used are considered insecure; skorch will make the same switch at the same time. To resolve the error, follow the instructions in the torch error message to designate the offending types as secure. Only do this if you trust the source of the file. If you want to keep loading non-weight types the same way as before, pass `torch_load_kwargs={'weights_only': False}`. Be aware that this is considered insecure and should only be used if you trust the source of the file; it does not, however, introduce new insecurities, but rather corresponds to the status quo from before torch made the switch. Another way to avoid this issue is to pass `use_safetensors=True` when calling `save_params` and `load_params`. This avoids pickle in favor of the safetensors format, which is secure by design.
  - module – Either the name of one of the braindecode models (see `braindecode.models.util.models_dict`) or directly a PyTorch module. When passing a torch module directly, the uninstantiated class should be preferred, although instantiated modules will also work.
  - cropped (bool (default=False)) – Defines whether the torch model passed to this class is cropped or not. Currently used for callbacks definition.
  - callbacks (None or list of strings or list of Callback instances (default=None)) – More callbacks, in addition to those returned by `get_default_callbacks`. Each callback should inherit from `skorch.callbacks.Callback`. If not `None`, callbacks can be a list of strings specifying sklearn scoring functions (for scoring function names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or a list of callbacks where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. `EpochScoring_1`. Alternatively, a tuple `(name, callback)` can be passed, where `name` should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name `'print_log'`, use `net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])`).
  - iterator_train__shuffle (bool (default=True)) – Defines whether the train dataset will be shuffled. As skorch does not shuffle the train dataset by default, this parameter overrides that behavior.
  - aggregate_predictions (bool (default=True)) – Whether to average cropped predictions to obtain window predictions. Used only in cropped mode.
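For orientation, here is a minimal usage sketch. It assumes `"ShallowFBCSPNet"` is one of the model names registered in `braindecode.models.util.models_dict` and relies on `fit` inferring the signal-related module parameters (n_chans, n_times, n_outputs) from the numpy inputs; treat it as an illustration rather than a canonical recipe.

```python
import numpy as np
import torch
from braindecode import EEGClassifier

# Synthetic data: 100 windows, 22 channels, 401 time samples, 4 classes.
X = np.random.randn(100, 22, 401).astype(np.float32)
y = np.random.randint(0, 4, size=100).astype(np.int64)

clf = EEGClassifier(
    module="ShallowFBCSPNet",            # assumed key in models_dict
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.AdamW,         # overrides the SGD default
    lr=1e-3,
    batch_size=16,
    max_epochs=2,
    device="cpu",
)
clf.fit(X, y)                 # n_chans/n_times/n_outputs inferred from X, y
labels = clf.predict(X)       # class labels, shape (100,)
proba = clf.predict_proba(X)  # per-class outputs, shape (100, 4)
```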
- prefixes_#
Contains the prefixes to special parameters. E.g., since there is the `'optimizer'` prefix, it is possible to set parameters like so: `NeuralNet(..., optimizer__momentum=0.95)`. Some prefixes are populated dynamically, based on what modules and criteria are defined.
- cuda_dependent_attributes_#
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a `NeuralNet` trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.
- module_#
The instantiated module.
- Type:
torch module (instance)
- criterion_#
The instantiated criterion.
- Type:
torch criterion (instance)
- callbacks_#
The complete set of initialized callbacks (defaults plus user-supplied ones), each as a tuple with a unique name.
- Type:
list of tuples
- _modules#
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria#
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers#
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- classes_#
A list of class labels known to the classifier.
- Type:
array, shape (n_classes, )
- doc = "Classifier that does not assume softmax activation.\n Calls loss function directly without applying log or anything.\n\n Parameters\n ----------\n module: str or torch Module (class or instance)\n Either the name of one of the braindecode models (see\n :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.\n When passing directly a torch module, uninstantiated class should be preferred,\n although instantiated modules will also work.\n\n cropped: bool (default=False)\n Defines whether torch model passed to this class is cropped or not.\n Currently used for callbacks definition.\n\n callbacks: None or list of strings or list of Callback instances (default=None)\n More callbacks, in addition to those returned by\n ``get_default_callbacks``. Each callback should inherit from\n :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a\n list of strings specifying `sklearn` scoring functions (for scoring\n functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)\n or a list of callbacks where the callback names are inferred from the\n class name. Name conflicts are resolved by appending a count suffix\n starting with 1, e.g. ``EpochScoring_1``. Alternatively,\n a tuple ``(name, callback)`` can be passed, where ``name``\n should be unique. Callbacks may or may not be instantiated.\n The callback name can be used to set parameters on specific\n callbacks (e.g., for the callback with name ``'print_log'``, use\n ``net.set_params(callbacks__print_log__keys_ignored=['epoch',\n 'train_loss'])``).\n\n iterator_train__shuffle: bool (default=True)\n Defines whether train dataset will be shuffled. As skorch does not\n shuffle the train dataset by default this one overwrites this option.\n\n aggregate_predictions: bool (default=True)\n Whether to average cropped predictions to obtain window predictions. Used only in the\n cropped mode.\n\n "#
- get_iterator(dataset, training=False, drop_index=True)[source]#
Get an iterator that allows looping over the batches of the given data.
If `self.iterator_train__batch_size` and/or `self.iterator_test__batch_size` are not set, `self.batch_size` is used instead.
- Parameters:
  - dataset (torch Dataset (default=skorch.dataset.Dataset)) – Usually, `self.dataset`, initialized with the corresponding data, is passed to `get_iterator`.
  - training (bool (default=False)) – Whether to use `iterator_train` or `iterator_test`.
- Returns:
An instantiated iterator that allows looping over the mini-batches.
- Return type:
iterator
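As a small illustration, the sketch below loops over inference batches; it assumes an already fitted `clf` (as in the example above) and uses skorch's `get_dataset` helper to wrap raw arrays in `self.dataset`.

```python
# Assumes `clf` and `X` from the usage example above.
ds = clf.get_dataset(X)  # wraps X in self.dataset (skorch helper)
for X_batch, y_batch in clf.get_iterator(ds, training=False):
    print(X_batch.shape)  # one mini-batch of at most `batch_size` windows
```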
- get_loss(y_pred, y_true, *args, **kwargs)[source]#
Return the loss for this batch by calling NeuralNet get_loss.
- Parameters:
  - y_pred (torch tensor) – Predicted target values.
  - y_true (torch tensor) – True target values.
  - X (input data, compatible with skorch.dataset.Dataset) – By default, you should be able to pass: numpy arrays, torch tensors, pandas DataFrames or Series, scipy sparse CSR matrices, a dictionary of the former three, a list/tuple of the former three, or a Dataset. If this doesn't work with your data, you have to pass a `Dataset` that can deal with the data.
  - training (bool (default=False)) – Whether train mode should be used or not.
- Returns:
  loss – The loss value.
- Return type:
  torch tensor
- predict(X)[source]#
Return class labels for samples in X.
- Parameters:
  X (input data, compatible with skorch.dataset.Dataset) – By default, you should be able to pass: numpy arrays, torch tensors, pandas DataFrames or Series, scipy sparse CSR matrices, a dictionary of the former three, a list/tuple of the former three, or a Dataset. If this doesn't work with your data, you have to pass a `Dataset` that can deal with the data.
- Returns:
  y_pred
- Return type:
  numpy ndarray
- predict_proba(X)[source]#
Return the output of the module's forward method as a numpy array. In the case of cropped decoding, averaged values are returned for each trial.
If the module's forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, or the module's output for each crop is needed, consider using `forward()` instead.
- Parameters:
  X (input data, compatible with skorch.dataset.Dataset) – By default, you should be able to pass: numpy arrays, torch tensors, pandas DataFrames or Series, scipy sparse CSR matrices, a dictionary of the former three, a list/tuple of the former three, or a Dataset. If this doesn't work with your data, you have to pass a `Dataset` that can deal with the data.
- Returns:
  y_proba
- Return type:
  numpy ndarray
- predict_trials(X, return_targets=True)[source]#
Create trialwise predictions and optionally also return trialwise labels from cropped dataset.
- Parameters:
X (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
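A hedged sketch of how such a dataset might be produced and passed in, assuming `create_from_X_y` (listed under braindecode.datasets below) accepts roughly these arguments; the `predict_trials` call is left commented because it requires a model fit in cropped mode.

```python
import numpy as np
from braindecode.datasets import create_from_X_y

# Hypothetical trial data: 10 trials, 22 channels, 1000 samples at 250 Hz.
X = np.random.randn(10, 22, 1000).astype(np.float32)
y = np.random.randint(0, 4, size=10)

windows_ds = create_from_X_y(
    X, y, drop_last_window=False, sfreq=250.0,
    window_size_samples=500, window_stride_samples=250,
)
# With a classifier fit in cropped mode on such a dataset:
# preds, targets = clf.predict_trials(windows_ds, return_targets=True)
# preds.shape -> (n_trials, n_classes, n_predictions)
```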
- set_partial_fit_request(*, classes: bool | None | str = '$UNCHANGED$') EEGClassifier [source]#
Request metadata passed to the `partial_fit` method.
Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config()`). Please see the sklearn User Guide on how the routing mechanism works.
The options for each parameter are:
- `True`: metadata is requested, and passed to `partial_fit` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `partial_fit`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a `Pipeline`. Otherwise it has no effect.
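A brief sketch of toggling this request under the standard scikit-learn (>=1.3) routing workflow:

```python
import sklearn

# Metadata routing must be enabled globally for the request to take effect.
sklearn.set_config(enable_metadata_routing=True)

# Ask meta-estimators to forward `classes` to partial_fit when provided.
clf = clf.set_partial_fit_request(classes=True)
```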
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') EEGClassifier [source]#
Request metadata passed to the `score` method.
Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config()`). Please see the sklearn User Guide on how the routing mechanism works.
The options for each parameter are:
- `True`: metadata is requested, and passed to `score` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `score`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a `Pipeline`. Otherwise it has no effect.
- class braindecode.EEGRegressor(module, *args, cropped=False, callbacks=None, iterator_train__shuffle=True, iterator_train__drop_last=True, aggregate_predictions=True, **kwargs)[source]#
Bases: `_EEGNeuralNet`, `NeuralNetRegressor`
Regressor that calls loss function directly.
- Parameters:
  - module (str or torch Module (class or instance)) – A PyTorch `Module`. In general, the uninstantiated class should be passed, although instantiated modules will also work.
  - criterion (torch criterion (class, default=torch.nn.MSELoss)) – Mean squared error loss.
  - optimizer (torch optim (class, default=torch.optim.SGD)) – The uninitialized optimizer (update rule) used to optimize the module.
  - lr (float (default=0.01)) – Learning rate passed to the optimizer. You may use `lr` instead of `optimizer__lr`, which would result in the same outcome.
  - max_epochs (int (default=10)) – The number of epochs to train for each `fit` call. Note that you may keyboard-interrupt training at any time.
  - batch_size (int (default=128)) – Mini-batch size. Use this instead of setting `iterator_train__batch_size` and `iterator_test__batch_size`, which would result in the same outcome. If `batch_size` is -1, a single batch with all the data will be used during training and validation.
  - iterator_train (torch DataLoader) – The default PyTorch `DataLoader` used for training data.
  - iterator_valid (torch DataLoader) – The default PyTorch `DataLoader` used for validation and test data, i.e. during inference.
  - dataset (torch Dataset (default=skorch.dataset.Dataset)) – The dataset is necessary for the incoming data to work with PyTorch's `DataLoader`. It has to implement the `__len__` and `__getitem__` methods. The provided dataset should be capable of dealing with many data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized `Dataset` class and define additional arguments to X and y by prefixing them with `dataset__`. It is also possible to pass an initialized `Dataset`, in which case no additional arguments may be passed.
  - train_split (None or callable (default=skorch.dataset.ValidSplit(5))) – If `None`, there is no train/validation split. Else, `train_split` should be a function or callable that is called with X and y data and should return the tuple `dataset_train, dataset_valid`. The validation data may be `None`.
  - callbacks (None, "disable", or list of Callback instances (default=None)) – If `callbacks=None`, only use default callbacks, those returned by `get_default_callbacks`. If `callbacks="disable"`, disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks. If `callbacks` is a list of callbacks, use those callbacks in addition to the default callbacks; each callback should be an instance of `Callback`. Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. `EpochScoring_1`. Alternatively, a tuple `(name, callback)` can be passed, where `name` should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name `'print_log'`, use `net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])`).
  - predict_nonlinearity (callable, None, or 'auto' (default='auto')) – The nonlinearity to be applied to the prediction. When set to 'auto', infers the correct nonlinearity based on the criterion (softmax for `CrossEntropyLoss` and sigmoid for `BCEWithLogitsLoss`). If it cannot be inferred or if the parameter is None, just use the identity function. Don't pass a lambda function if you want the net to be pickleable. In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor. This can be useful, e.g., when `predict_proba()` should return probabilities but a criterion is used that does not expect probabilities; in that case, the module can return whatever is required by the criterion and `predict_nonlinearity` transforms this output into probabilities. The nonlinearity is applied only when calling `predict()` or `predict_proba()` but not anywhere else – notably, the loss is unaffected by this nonlinearity.
  - warm_start (bool (default=False)) – Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
  - verbose (int (default=1)) – Controls how much print output is generated by the net and its callbacks. Setting this value to 0 suppresses, e.g., the summary scores at the end of each epoch, which can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
  - device (str, torch.device, or None (default='cpu')) – The compute device to be used. If set to 'cuda' in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, all compute devices will be left unmodified.
  - compile (bool (default=False)) – If set to `True`, compile all modules using `torch.compile`. For this to work, the installed torch version has to support `torch.compile`. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere, for instance). Additional arguments for `torch.compile` can be passed using the dunder notation, e.g. initializing the net with `compile__dynamic=True` makes `torch.compile` be called with `dynamic=True`.
  - use_caching (bool or 'auto' (default='auto')) – Optionally override the caching behavior of scoring callbacks. Callbacks such as `EpochScoring` and `BatchScoring` allow caching the inference call to save time when calculating scores during training, at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows overriding their behavior globally. By default (`'auto'`), the callbacks determine whether caching is used. If this argument is set to `False`, caching is disabled on all callbacks; if set to `True`, caching is enabled on all callbacks. Implementation note: it is the job of the callbacks to honor this setting.
  - torch_load_kwargs (dict or None (default=None)) – Additional arguments passed to torch.load when loading pickled parameters. In particular, this is important because PyTorch will switch (probably in version 2.6.0) to only allowing weights to be loaded, for security reasons (i.e. weights_only switches from False to True). As a consequence, loading pickled parameters may raise an error after upgrading torch because some types that are used are considered insecure; skorch will make the same switch at the same time. To resolve the error, follow the instructions in the torch error message to designate the offending types as secure. Only do this if you trust the source of the file. If you want to keep loading non-weight types the same way as before, pass `torch_load_kwargs={'weights_only': False}`. Be aware that this is considered insecure and should only be used if you trust the source of the file; it does not, however, introduce new insecurities, but rather corresponds to the status quo from before torch made the switch. Another way to avoid this issue is to pass `use_safetensors=True` when calling `save_params` and `load_params`. This avoids pickle in favor of the safetensors format, which is secure by design.
  - module – Either the name of one of the braindecode models (see `braindecode.models.util.models_dict`) or directly a PyTorch module. When passing a torch module directly, the uninstantiated class should be preferred, although instantiated modules will also work.
  - cropped (bool (default=False)) – Defines whether the torch model passed to this class is cropped or not. Currently used for callbacks definition.
  - callbacks (None or list of strings or list of Callback instances (default=None)) – More callbacks, in addition to those returned by `get_default_callbacks`. Each callback should inherit from `skorch.callbacks.Callback`. If not `None`, callbacks can be a list of strings specifying sklearn scoring functions (for scoring function names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or a list of callbacks where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. `EpochScoring_1`. Alternatively, a tuple `(name, callback)` can be passed, where `name` should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name `'print_log'`, use `net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])`).
  - iterator_train__shuffle (bool (default=True)) – Defines whether the train dataset will be shuffled. As skorch does not shuffle the train dataset by default, this parameter overrides that behavior.
  - aggregate_predictions (bool (default=True)) – Whether to average cropped predictions to obtain window predictions. Used only in cropped mode.
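A minimal regression sketch, mirroring the classifier example above; the model name is again an assumed `models_dict` key, and the usual `NeuralNetRegressor` conventions apply (2-dimensional float targets).

```python
import numpy as np
import torch
from braindecode import EEGRegressor

# Synthetic data: 100 windows, 22 channels, 401 samples, one continuous target.
X = np.random.randn(100, 22, 401).astype(np.float32)
y = np.random.randn(100, 1).astype(np.float32)  # 2D float targets

reg = EEGRegressor(
    module="ShallowFBCSPNet",    # assumed key in models_dict
    criterion=torch.nn.MSELoss,  # the documented default
    optimizer=torch.optim.AdamW,
    lr=1e-3,
    batch_size=16,
    max_epochs=2,
)
reg.fit(X, y)           # signal-related parameters inferred from X and y
y_hat = reg.predict(X)  # continuous predictions, shape (100, 1)
```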
- prefixes_#
Contains the prefixes to special parameters. E.g., since there is the `'optimizer'` prefix, it is possible to set parameters like so: `NeuralNet(..., optimizer__momentum=0.95)`. Some prefixes are populated dynamically, based on what modules and criteria are defined.
- cuda_dependent_attributes_#
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a `NeuralNet` trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.
- module_#
The instantiated module.
- Type:
torch module (instance)
- criterion_#
The instantiated criterion.
- Type:
torch criterion (instance)
- callbacks_#
The complete set of initialized callbacks (defaults plus user-supplied ones), each as a tuple with a unique name.
- Type:
list of tuples
- _modules#
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria#
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers#
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- doc = "Regressor that calls loss function directly.\n\n Parameters\n ----------\n module: str or torch Module (class or instance)\n Either the name of one of the braindecode models (see\n :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.\n When passing directly a torch module, uninstantiated class should be preferred,\n although instantiated modules will also work.\n\n cropped: bool (default=False)\n Defines whether torch model passed to this class is cropped or not.\n Currently used for callbacks definition.\n\n callbacks: None or list of strings or list of Callback instances (default=None)\n More callbacks, in addition to those returned by\n ``get_default_callbacks``. Each callback should inherit from\n :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a\n list of strings specifying `sklearn` scoring functions (for scoring\n functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)\n or a list of callbacks where the callback names are inferred from the\n class name. Name conflicts are resolved by appending a count suffix\n starting with 1, e.g. ``EpochScoring_1``. Alternatively,\n a tuple ``(name, callback)`` can be passed, where ``name``\n should be unique. Callbacks may or may not be instantiated.\n The callback name can be used to set parameters on specific\n callbacks (e.g., for the callback with name ``'print_log'``, use\n ``net.set_params(callbacks__print_log__keys_ignored=['epoch',\n 'train_loss'])``).\n\n iterator_train__shuffle: bool (default=True)\n Defines whether train dataset will be shuffled. As skorch does not\n shuffle the train dataset by default this one overwrites this option.\n\n aggregate_predictions: bool (default=True)\n Whether to average cropped predictions to obtain window predictions. Used only in the\n cropped mode.\n\n "#
- fit(X, y=None, **kwargs)[source]#
Initialize and fit the module.
If the module was already initialized, calling fit re-initializes it (unless `warm_start` is True). If possible, signal-related parameters are inferred from the data and passed to the module at initialization. Depending on the type of input passed, the following parameters are inferred:
- mne.Epochs: `n_times`, `n_chans`, `n_outputs`, `chs_info`, `sfreq`, `input_window_seconds`
- numpy array: `n_times`, `n_chans`, `n_outputs`
- WindowsDataset with `targets_from='metadata'` (or BaseConcatDataset of such datasets): `n_times`, `n_chans`, `n_outputs`
- other Dataset: `n_times`, `n_chans`
- other types: no parameters are inferred.
A sketch of the mne.Epochs inference path is shown after the parameter list below.
- Parameters:
  - X (input data, compatible with skorch.dataset.Dataset) – By default, you should be able to pass: mne.Epochs, numpy arrays, torch tensors, pandas DataFrames or Series, scipy sparse CSR matrices, a dictionary of the former three, a list/tuple of the former three, or a Dataset. If this doesn't work with your data, you have to pass a `Dataset` that can deal with the data.
  - y (target data, compatible with skorch.dataset.Dataset) – The same data types as for `X` are supported. If your X is a Dataset that contains the target, `y` may be set to None.
  - **fit_params (dict) – Additional parameters passed to the `forward` method of the module and to the `self.train_split` call.
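To illustrate the mne.Epochs inference path, a minimal sketch (channel names, sampling rate, and event coding are arbitrary placeholders):

```python
import numpy as np
import mne

# 20 one-second epochs, 3 EEG channels, sampled at 100 Hz.
info = mne.create_info(ch_names=["C3", "Cz", "C4"], sfreq=100.0, ch_types="eeg")
data = np.random.randn(20, 3, 100)
events = np.column_stack([
    np.arange(0, 2000, 100),      # sample indices of the epochs
    np.zeros(20, dtype=int),      # unused (previous event code)
    np.random.randint(0, 2, 20),  # event codes, usable as targets
])
epochs = mne.EpochsArray(data, info, events=events)

# Passing the Epochs object lets fit infer n_times, n_chans, n_outputs,
# chs_info, sfreq and input_window_seconds before building the module:
# reg.fit(epochs, y=epochs.events[:, 2])
```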
- get_iterator(dataset, training=False, drop_index=True)[source]#
Get an iterator that allows looping over the batches of the given data.
If `self.iterator_train__batch_size` and/or `self.iterator_test__batch_size` are not set, `self.batch_size` is used instead.
- Parameters:
  - dataset (torch Dataset (default=skorch.dataset.Dataset)) – Usually, `self.dataset`, initialized with the corresponding data, is passed to `get_iterator`.
  - training (bool (default=False)) – Whether to use `iterator_train` or `iterator_test`.
- Returns:
An instantiated iterator that allows looping over the mini-batches.
- Return type:
iterator
- predict_proba(X)[source]#
Return the output of the module's forward method as a numpy array. In the case of cropped decoding, averaged values are returned for each trial.
If the module's forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, or the module's output for each crop is needed, consider using `forward()` instead.
- Parameters:
  X (input data, compatible with skorch.dataset.Dataset) – By default, you should be able to pass: numpy arrays, torch tensors, pandas DataFrames or Series, scipy sparse CSR matrices, a dictionary of the former three, a list/tuple of the former three, or a Dataset. If this doesn't work with your data, you have to pass a `Dataset` that can deal with the data.
Warning
Regressors predict regression targets, so the output of this method cannot be interpreted as probabilities. We advise using the `predict` method instead of `predict_proba`.
- Returns:
y_proba
- Return type:
numpy ndarray
- predict_trials(X, return_targets=True)[source]#
Create trialwise predictions and optionally also return trialwise labels from cropped dataset.
- Parameters:
X (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
- set_partial_fit_request(*, classes: bool | None | str = '$UNCHANGED$') EEGRegressor [source]#
Request metadata passed to the `partial_fit` method.
Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config()`). Please see the sklearn User Guide on how the routing mechanism works.
The options for each parameter are:
- `True`: metadata is requested, and passed to `partial_fit` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `partial_fit`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a `Pipeline`. Otherwise it has no effect.
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') EEGRegressor [source]#
Request metadata passed to the `score` method.
Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config()`). Please see the sklearn User Guide on how the routing mechanism works.
The options for each parameter are:
- `True`: metadata is requested, and passed to `score` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `score`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a `Pipeline`. Otherwise it has no effect.
Subpackages#
- braindecode.augmentation package
AugmentedDataLoader
BandstopFilter
ChannelsDropout
ChannelsShuffle
ChannelsSymmetry
Compose
FTSurrogate
FrequencyShift
GaussianNoise
IdentityTransform
MaskEncoding
Mixup
SegmentationReconstruction
SensorsRotation
SensorsXRotation
SensorsYRotation
SensorsZRotation
SignFlip
SmoothTimeMask
TimeReverse
Transform
- Submodules
- braindecode.augmentation.base module
- braindecode.augmentation.functional module
- braindecode.augmentation.transforms module
- braindecode.datasets package
BCICompetitionIVDataset4
BIDSDataset
BIDSDataset.acquisitions
BIDSDataset.check
BIDSDataset.datatypes
BIDSDataset.descriptions
BIDSDataset.extensions
BIDSDataset.n_jobs
BIDSDataset.preload
BIDSDataset.processings
BIDSDataset.recordings
BIDSDataset.root
BIDSDataset.runs
BIDSDataset.sessions
BIDSDataset.spaces
BIDSDataset.splits
BIDSDataset.subjects
BIDSDataset.suffixes
BIDSDataset.tasks
BIDSEpochsDataset
BNCI2014001
BaseConcatDataset
BaseDataset
HGD
MOABBDataset
NMT
SleepPhysionet
SleepPhysionetChallenge2018
TUH
TUHAbnormal
WindowsDataset
create_from_X_y()
create_from_mne_epochs()
create_from_mne_raw()
- Submodules
- braindecode.datasets.base module
- braindecode.datasets.bbci module
- braindecode.datasets.bcicomp module
- braindecode.datasets.bids module
BIDSDataset
BIDSDataset.acquisitions
BIDSDataset.check
BIDSDataset.cumulative_sizes
BIDSDataset.datasets
BIDSDataset.datatypes
BIDSDataset.descriptions
BIDSDataset.extensions
BIDSDataset.n_jobs
BIDSDataset.preload
BIDSDataset.processings
BIDSDataset.recordings
BIDSDataset.root
BIDSDataset.runs
BIDSDataset.sessions
BIDSDataset.spaces
BIDSDataset.splits
BIDSDataset.subjects
BIDSDataset.suffixes
BIDSDataset.tasks
BIDSEpochsDataset
- braindecode.datasets.mne module
- braindecode.datasets.moabb module
- braindecode.datasets.nmt module
- braindecode.datasets.sleep_physio_challe_18 module
- braindecode.datasets.sleep_physionet module
- braindecode.datasets.tuh module
- braindecode.datasets.xy module
- braindecode.datautil package
- braindecode.functional package
- braindecode.models package
- Submodules
- braindecode.models.atcnet module
- braindecode.models.attentionbasenet module
- braindecode.models.base module
EEGModuleMixin
EEGModuleMixin.chs_info
EEGModuleMixin.get_output_shape()
EEGModuleMixin.get_torchinfo_statistics()
EEGModuleMixin.input_shape
EEGModuleMixin.input_window_seconds
EEGModuleMixin.load_state_dict()
EEGModuleMixin.mapping
EEGModuleMixin.n_chans
EEGModuleMixin.n_outputs
EEGModuleMixin.n_times
EEGModuleMixin.sfreq
EEGModuleMixin.to_dense_prediction_model()
deprecated_args()
- braindecode.models.biot module
- braindecode.models.contrawr module
- braindecode.models.ctnet module
- braindecode.models.deep4 module
- braindecode.models.deepsleepnet module
- braindecode.models.eegconformer module
- braindecode.models.eeginception_erp module
- braindecode.models.eeginception_mi module
- braindecode.models.eegitnet module
- braindecode.models.eegminer module
- braindecode.models.eegnet module
- braindecode.models.eegnex module
- braindecode.models.eegresnet module
- braindecode.models.eegsimpleconv module
- braindecode.models.eegtcnet module
- braindecode.models.fbcnet module
- braindecode.models.fblightconvnet module
- braindecode.models.fbmsnet module
- braindecode.models.hybrid module
- braindecode.models.ifnet module
- braindecode.models.labram module
- braindecode.models.msvtnet module
- braindecode.models.sccnet module
- braindecode.models.shallow_fbcsp module
- braindecode.models.signal_jepa module
- braindecode.models.sinc_shallow module
- braindecode.models.sleep_stager_blanco_2020 module
- braindecode.models.sleep_stager_chambon_2018 module
- braindecode.models.sleep_stager_eldele_2021 module
- braindecode.models.sparcnet module
- braindecode.models.syncnet module
- braindecode.models.tcn module
- braindecode.models.tidnet module
- braindecode.models.tsinception module
- braindecode.models.usleep module
- braindecode.models.util module
- braindecode.modules package
- Submodules
- braindecode.modules.activation module
- braindecode.modules.attention module
- braindecode.modules.blocks module
- braindecode.modules.convolution module
- braindecode.modules.filter module
- braindecode.modules.layers module
- braindecode.modules.linear module
- braindecode.modules.parametrization module
- braindecode.modules.stats module
- braindecode.modules.util module
- braindecode.modules.wrapper module
- braindecode.preprocessing package
Crop
DropChannels
Filter
Pick
Preprocessor
Resample
SetEEGReference
create_fixed_length_windows()
create_windows_from_events()
create_windows_from_target_channels()
exponential_moving_demean()
exponential_moving_standardize()
filterbank()
preprocess()
- Submodules
- braindecode.preprocessing.mne_preprocess module
- braindecode.preprocessing.preprocess module
- braindecode.preprocessing.windowers module
- braindecode.samplers package
BalancedSequenceSampler
DistributedRecordingSampler
DistributedRelativePositioningSampler
RecordingSampler
RelativePositioningSampler
SequenceSampler
- Submodules
- braindecode.samplers.base module
- braindecode.samplers.ssl module
- braindecode.training package
- braindecode.visualization package
Submodules#
braindecode.classifier module#
- class braindecode.classifier.EEGClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.CrossEntropyLoss'>, cropped=False, callbacks=None, iterator_train__shuffle=True, iterator_train__drop_last=True, aggregate_predictions=True, **kwargs)[source]#
Bases:
_EEGNeuralNet
,NeuralNetClassifier
Classifier that does not assume softmax activation. Calls loss function directly without applying log or anything.
- Parameters:
module (str or torch Module (class or instance)) – A PyTorch
Module
. In general, the uninstantiated class should be passed, although instantiated modules will also work.criterion (torch criterion (class, default=torch.nn.NLLLoss)) – Negative log likelihood loss. Note that the module should return probabilities, the log is applied during
get_loss
.classes (None or list (default=None)) – If None, the
classes_
attribute will be inferred from they
data passed tofit
. If a non-empty list is passed, that list will be returned asclasses_
. If the initial skorch behavior should be restored, i.e. raising anAttributeError
, pass an empty list.optimizer (torch optim (class, default=torch.optim.SGD)) – The uninitialized optimizer (update rule) used to optimize the module
lr (float (default=0.01)) – Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.max_epochs (int (default=10)) – The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.batch_size (int (default=128)) – Mini-batch size. Use this instead of setting
iterator_train__batch_size
anditerator_test__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.iterator_train (torch DataLoader) – The default PyTorch
DataLoader
used for training data.iterator_valid (torch DataLoader) – The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.dataset (torch Dataset (default=skorch.dataset.Dataset)) – The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initialzedDataset
, in which case no additional arguments may be passed.train_split (None or callable (default=skorch.dataset.ValidSplit(5))) –
If
None
, there is no train/validation split. Else,train_split
should be a function or callable that is called with X and y data and should return the tupledataset_train, dataset_valid
. The validation data may beNone
.If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).predict_nonlinearity (callable, None, or 'auto' (default='auto')) –
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.warm_start (bool (default=False)) – Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
verbose (int (default=1)) – This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
device (str, torch.device, or None (default='cpu')) – The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
compile (bool (default=False)) – If set to
True
, compile all modules usingtorch.compile
. For this to work, the installed torch version has to supporttorch.compile
. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere for instance). Additional arguments fortorch.compile
can be passed using the dunder notation, e.g. when initializing the net withcompile__dynamic=True
,torch.compile
will be called withdynamic=True
.use_caching (bool or 'auto' (default='auto')) – Optionally override the caching behavior of scoring callbacks. Callbacks such as
EpochScoring
andBatchScoring
allow to cache the inference call to save time when calculating scores during training at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows to override their behavior globally. By default ('auto'
), the callbacks will determine if caching is used or not. If this argument is set toFalse
, caching will be disabled on all callbacks. If set toTrue
, caching will be enabled on all callbacks. Implementation note: It is the job of the callbacks to honor this setting.torch_load_kwargs (dict or None (default=None)) –
Additional arguments that will be passed to torch.load when load pickled parameters.
In particular, this is important to because PyTorch will switch (probably in version 2.6.0) to only allow weights to be loaded for security reasons (i.e weights_only switches from False to True). As a consequence, loading pickled parameters may raise an error after upgrading torch because some types are used that are considered insecure. In skorch, we will also make that switch at the same time. To resolve the error, follow the instructions in the torch error message to designate the offending types as secure. Only do this if you trust the source of the file.
If you want to keep loading non-weight types the same way as before, please pass:
torch_load_kwargs={‘weights_only’: False}
You should be aware that this is considered insecure and should only be used if you trust the source of the file. However, this does not introduce new insecurities, it rather corresponds to the status quo from before torch made the switch.
Another way to avoid this issue is to pass use_safetensors=True when calling save_params and load_params. This avoid using pickle in favor of the safetensors format, which is secure by design.
module – Either the name of one of the braindecode models (see
braindecode.models.util.models_dict
) or directly a PyTorch module. When passing directly a torch module, uninstantiated class should be preferred, although instantiated modules will also work.cropped (bool (default=False)) – Defines whether torch model passed to this class is cropped or not. Currently used for callbacks definition.
callbacks (None or list of strings or list of Callback instances (default=None)) – More callbacks, in addition to those returned by
get_default_callbacks
. Each callback should inherit fromskorch.callbacks.Callback
. If notNone
, callbacks can be a list of strings specifying sklearn scoring functions (for scoring functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or a list of callbacks where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).iterator_train__shuffle (bool (default=True)) – Defines whether train dataset will be shuffled. As skorch does not shuffle the train dataset by default this one overwrites this option.
aggregate_predictions (bool (default=True)) – Whether to average cropped predictions to obtain window predictions. Used only in the cropped mode.
- prefixes_#
Contains the prefixes to special parameters. E.g., since there is the
'optimizer'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
. Some prefixes are populated dynamically, based on what modules and criteria are defined.
- cuda_dependent_attributes_#
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.
- module_#
The instantiated module.
- Type:
torch module (instance)
- criterion_#
The instantiated criterion.
- Type:
torch criterion (instance)
- callbacks_#
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- Type:
list of tuples
- _modules#
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria#
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers#
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- classes_#
A list of class labels known to the classifier.
- Type:
array, shape (n_classes, )
- doc = "Classifier that does not assume softmax activation.\n Calls loss function directly without applying log or anything.\n\n Parameters\n ----------\n module: str or torch Module (class or instance)\n Either the name of one of the braindecode models (see\n :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.\n When passing directly a torch module, uninstantiated class should be preferred,\n although instantiated modules will also work.\n\n cropped: bool (default=False)\n Defines whether torch model passed to this class is cropped or not.\n Currently used for callbacks definition.\n\n callbacks: None or list of strings or list of Callback instances (default=None)\n More callbacks, in addition to those returned by\n ``get_default_callbacks``. Each callback should inherit from\n :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a\n list of strings specifying `sklearn` scoring functions (for scoring\n functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)\n or a list of callbacks where the callback names are inferred from the\n class name. Name conflicts are resolved by appending a count suffix\n starting with 1, e.g. ``EpochScoring_1``. Alternatively,\n a tuple ``(name, callback)`` can be passed, where ``name``\n should be unique. Callbacks may or may not be instantiated.\n The callback name can be used to set parameters on specific\n callbacks (e.g., for the callback with name ``'print_log'``, use\n ``net.set_params(callbacks__print_log__keys_ignored=['epoch',\n 'train_loss'])``).\n\n iterator_train__shuffle: bool (default=True)\n Defines whether train dataset will be shuffled. As skorch does not\n shuffle the train dataset by default this one overwrites this option.\n\n aggregate_predictions: bool (default=True)\n Whether to average cropped predictions to obtain window predictions. Used only in the\n cropped mode.\n\n "#
- get_iterator(dataset, training=False, drop_index=True)[source]#
Get an iterator that allows to loop over the batches of the given data.
If
self.iterator_train__batch_size
and/orself.iterator_test__batch_size
are not set, useself.batch_size
instead.- Parameters:
dataset (torch Dataset (default=skorch.dataset.Dataset)) – Usually,
self.dataset
, initialized with the corresponding data, is passed toget_iterator
.training (bool (default=False)) – Whether to use
iterator_train
oriterator_test
.
- Returns:
An instantiated iterator that allows to loop over the mini-batches.
- Return type:
iterator
- get_loss(y_pred, y_true, *args, **kwargs)[source]#
Return the loss for this batch by calling NeuralNet get_loss.
- Parameters:
y_pred (torch tensor) – Predicted target values
y_true (torch tensor) – True target values.
X (input data, compatible with skorch.dataset.Dataset) –
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.training (bool (default=False)) – Whether train mode should be used or not.
- Returns:
loss – The loss value.
- Return type:
- predict(X)[source]#
Return class labels for samples in X.
- Parameters:
X (input data, compatible with skorch.dataset.Dataset) –
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
- Returns:
y_pred
- Return type:
numpy ndarray
- predict_proba(X)[source]#
Return the output of the module’s forward method as a numpy array. In the case of cropped decoding, returns values averaged over the crops of each trial.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, or the module’s output for each crop is needed, consider using forward() instead.
- Parameters:
X (input data, compatible with skorch.dataset.Dataset) –
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
- Returns:
y_proba
- Return type:
numpy ndarray
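A minimal end-to-end sketch of predict and predict_proba; the data shapes, model name, and hyperparameters are illustrative, not a recommended configuration:

    import numpy as np
    import torch
    from braindecode import EEGClassifier

    # Illustrative data: 60 trials, 8 channels, 200 time points, 2 classes.
    X_train = np.random.randn(60, 8, 200).astype(np.float32)
    y_train = np.random.randint(0, 2, size=60)
    X_test = np.random.randn(10, 8, 200).astype(np.float32)

    clf = EEGClassifier(
        "ShallowFBCSPNet",  # name resolved via braindecode.models.util.models_dict
        max_epochs=2,
        batch_size=16,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)         # class labels, shape (10,)
    y_proba = clf.predict_proba(X_test)  # per-class outputs, shape (10, n_classes)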
- predict_trials(X, return_targets=True)[source]#
Create trialwise predictions and optionally also return trialwise labels from cropped dataset.
- Parameters:
X (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
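A short sketch of trialwise predictions in cropped mode; it assumes clf was constructed with cropped=True and fit on windowed data, and that windows_ds is a braindecode BaseConcatDataset of windows (names are illustrative):

    # Per-trial predictions plus, optionally, the trial targets.
    preds, targets = clf.predict_trials(windows_ds, return_targets=True)
    print(preds.shape)  # (n_trials, n_classes, n_predictions)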
- set_partial_fit_request(*, classes: bool | None | str = '$UNCHANGED$') EEGClassifier [source]#
Request metadata passed to the partial_fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to partial_fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to partial_fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') EEGClassifier [source]#
Request metadata passed to the score method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
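A hedged sketch of how these request methods interact with scikit-learn's metadata routing; it assumes a recent scikit-learn (1.4+) where cross_validate accepts a params argument, and the weight array w is hypothetical. Details may vary with your scikit-learn version:

    import numpy as np
    import sklearn
    from sklearn.model_selection import cross_validate
    from braindecode import EEGClassifier

    sklearn.set_config(enable_metadata_routing=True)
    X = np.random.randn(60, 8, 200).astype(np.float32)
    y = np.random.randint(0, 2, size=60)
    w = np.ones(60)  # hypothetical per-sample weights

    clf = EEGClassifier("ShallowFBCSPNet", max_epochs=2, batch_size=16)
    # Declare that score should receive sample_weight when provided.
    clf = clf.set_score_request(sample_weight=True)
    # The meta-estimator now routes sample_weight to clf.score.
    results = cross_validate(clf, X, y, params={"sample_weight": w})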
braindecode.eegneuralnet module#
braindecode.regressor module#
- class braindecode.regressor.EEGRegressor(module, *args, cropped=False, callbacks=None, iterator_train__shuffle=True, iterator_train__drop_last=True, aggregate_predictions=True, **kwargs)[source]#
Bases:
_EEGNeuralNet
,NeuralNetRegressor
Regressor that calls loss function directly.
- Parameters:
module (str or torch Module (class or instance)) – A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.
criterion (torch criterion (class, default=torch.nn.MSELoss)) – Mean squared error loss.
optimizer (torch optim (class, default=torch.optim.SGD)) – The uninitialized optimizer (update rule) used to optimize the module.
lr (float (default=0.01)) – Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.
max_epochs (int (default=10)) – The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.
batch_size (int (default=128)) – Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_test__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.
iterator_train (torch DataLoader) – The default PyTorch DataLoader used for training data.
iterator_valid (torch DataLoader) – The default PyTorch DataLoader used for validation and test data, i.e. during inference.
dataset (torch Dataset (default=skorch.dataset.Dataset)) – The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.
train_split (None or callable (default=skorch.dataset.ValidSplit(5))) –
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.
If callbacks=None, only use default callbacks, those returned by get_default_callbacks.
If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks.
If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.
Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
predict_nonlinearity (callable, None, or 'auto' (default='auto')) –
The nonlinearity to be applied to the prediction. When set to 'auto', infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.
In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.
The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.
warm_start (bool (default=False)) – Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
verbose (int (default=1)) – This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
device (str, torch.device, or None (default='cpu')) – The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
compile (bool (default=False)) – If set to True, compile all modules using torch.compile. For this to work, the installed torch version has to support torch.compile. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere, for instance). Additional arguments for torch.compile can be passed using the dunder notation, e.g. when initializing the net with compile__dynamic=True, torch.compile will be called with dynamic=True.
use_caching (bool or 'auto' (default='auto')) – Optionally override the caching behavior of scoring callbacks. Callbacks such as EpochScoring and BatchScoring allow caching the inference call to save time when calculating scores during training, at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows overriding their behavior globally. By default ('auto'), the callbacks will determine if caching is used or not. If this argument is set to False, caching will be disabled on all callbacks. If set to True, caching will be enabled on all callbacks. Implementation note: It is the job of the callbacks to honor this setting.
torch_load_kwargs (dict or None (default=None)) –
Additional arguments that will be passed to torch.load when loading pickled parameters.
In particular, this is important because PyTorch will switch (probably in version 2.6.0) to allowing only weights to be loaded for security reasons (i.e. weights_only switches from False to True). As a consequence, loading pickled parameters may raise an error after upgrading torch, because some types are used that are considered insecure. In skorch, we will also make that switch at the same time. To resolve the error, follow the instructions in the torch error message to designate the offending types as secure. Only do this if you trust the source of the file.
If you want to keep loading non-weight types the same way as before, pass:
torch_load_kwargs={'weights_only': False}
You should be aware that this is considered insecure and should only be used if you trust the source of the file. However, this does not introduce new insecurities; rather, it corresponds to the status quo from before torch made the switch.
Another way to avoid this issue is to pass use_safetensors=True when calling save_params and load_params. This avoids using pickle in favor of the safetensors format, which is secure by design.
module – Either the name of one of the braindecode models (see braindecode.models.util.models_dict) or directly a PyTorch module. When passing a torch module directly, the uninstantiated class should be preferred, although instantiated modules will also work.
cropped (bool (default=False)) – Defines whether the torch model passed to this class is cropped or not. Currently used for callbacks definition.
callbacks (None or list of strings or list of Callback instances (default=None)) – More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from skorch.callbacks.Callback. If not None, callbacks can be a list of strings specifying sklearn scoring functions (for scoring function names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or a list of callbacks where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
iterator_train__shuffle (bool (default=True)) – Defines whether the train dataset will be shuffled. As skorch does not shuffle the train dataset by default, this parameter overrides that behavior.
aggregate_predictions (bool (default=True)) – Whether to average cropped predictions to obtain window predictions. Used only in cropped mode.
- prefixes_#
Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95). Some prefixes are populated dynamically, based on what modules and criteria are defined.
- cuda_dependent_attributes_#
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.
- module_#
The instantiated module.
- Type:
torch module (instance)
- criterion_#
The instantiated criterion.
- Type:
torch criterion (instance)
- callbacks_#
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- Type:
list of tuples
- _modules#
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria#
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers#
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- doc = "Regressor that calls loss function directly.\n\n Parameters\n ----------\n module: str or torch Module (class or instance)\n Either the name of one of the braindecode models (see\n :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.\n When passing directly a torch module, uninstantiated class should be preferred,\n although instantiated modules will also work.\n\n cropped: bool (default=False)\n Defines whether torch model passed to this class is cropped or not.\n Currently used for callbacks definition.\n\n callbacks: None or list of strings or list of Callback instances (default=None)\n More callbacks, in addition to those returned by\n ``get_default_callbacks``. Each callback should inherit from\n :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a\n list of strings specifying `sklearn` scoring functions (for scoring\n functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)\n or a list of callbacks where the callback names are inferred from the\n class name. Name conflicts are resolved by appending a count suffix\n starting with 1, e.g. ``EpochScoring_1``. Alternatively,\n a tuple ``(name, callback)`` can be passed, where ``name``\n should be unique. Callbacks may or may not be instantiated.\n The callback name can be used to set parameters on specific\n callbacks (e.g., for the callback with name ``'print_log'``, use\n ``net.set_params(callbacks__print_log__keys_ignored=['epoch',\n 'train_loss'])``).\n\n iterator_train__shuffle: bool (default=True)\n Defines whether train dataset will be shuffled. As skorch does not\n shuffle the train dataset by default this one overwrites this option.\n\n aggregate_predictions: bool (default=True)\n Whether to average cropped predictions to obtain window predictions. Used only in the\n cropped mode.\n\n "#
- fit(X, y=None, **kwargs)[source]#
Initialize and fit the module.
If the module was already initialized, calling fit will re-initialize it (unless warm_start is True). If possible, signal-related parameters are inferred from the data and passed to the module at initialization. Depending on the type of input passed, the following parameters are inferred:
mne.Epochs: n_times, n_chans, n_outputs, chs_info, sfreq, input_window_seconds
numpy array: n_times, n_chans, n_outputs
WindowsDataset with targets_from='metadata' (or BaseConcatDataset of such datasets): n_times, n_chans, n_outputs
other Dataset: n_times, n_chans
other types: no parameters are inferred.
- Parameters:
X (input data, compatible with skorch.dataset.Dataset) –
By default, you should be able to pass:
mne.Epochs
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
y (target data, compatible with skorch.dataset.Dataset) – The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
**fit_params (dict) – Additional parameters passed to the forward method of the module and to the self.train_split call.
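A minimal sketch of fitting on plain numpy arrays, where n_chans, n_times, and n_outputs are inferred from the data as listed above; the shapes and model name are illustrative:

    import numpy as np
    from braindecode.regressor import EEGRegressor

    X = np.random.randn(100, 22, 500).astype(np.float32)  # trials x channels x times
    y = np.random.randn(100, 1).astype(np.float32)        # one regression target per trial
    reg = EEGRegressor("ShallowFBCSPNet", max_epochs=5, batch_size=32)
    reg.fit(X, y)  # signal-related module parameters are inferred before initialization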
- get_iterator(dataset, training=False, drop_index=True)[source]#
Get an iterator that allows looping over the batches of the given data.
If self.iterator_train__batch_size and/or self.iterator_test__batch_size are not set, self.batch_size is used instead.
- Parameters:
dataset (torch Dataset (default=skorch.dataset.Dataset)) – Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.
training (bool (default=False)) – Whether to use iterator_train or iterator_test.
- Returns:
An instantiated iterator that allows looping over the mini-batches.
- Return type:
iterator
- predict_proba(X)[source]#
Return the output of the module’s forward method as a numpy array. In the case of cropped decoding, returns values averaged over the crops of each trial.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, or the module’s output for each crop is needed, consider using forward() instead.
- Parameters:
X (input data, compatible with skorch.dataset.Dataset) –
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
Warning
Regressors predict regression targets, so the output of this method can’t be interpreted as probabilities. We advise you to use the predict method instead of predict_proba.
- Returns:
y_proba
- Return type:
numpy ndarray
- predict_trials(X, return_targets=True)[source]#
Create trialwise predictions and optionally also return trialwise labels from cropped dataset.
- Parameters:
X (braindecode.datasets.BaseConcatDataset) – A braindecode dataset to be predicted.
return_targets (bool) – If True, additionally returns the trial targets.
- Returns:
trial_predictions (np.ndarray) – 3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
trial_labels (np.ndarray) – 2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
- set_partial_fit_request(*, classes: bool | None | str = '$UNCHANGED$') EEGRegressor [source]#
Request metadata passed to the partial_fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to partial_fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to partial_fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') EEGRegressor [source]#
Request metadata passed to the score method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
braindecode.util module#
- braindecode.util.corr(a, b)[source]#
Computes correlation only between terms of a and terms of b, not within a and b.
- Parameters:
a (2darray, features x samples)
b (2darray, features x samples)
- Return type:
Correlation between features in a and features in b
- braindecode.util.cov(a, b)[source]#
Computes covariance only between terms of a and terms of b, not within a and b.
- Parameters:
a (2darray, features x samples)
b (2darray, features x samples)
- Return type:
Covariance between features in a and features in b
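For illustration, a short sketch of both helpers; rows are features and columns are samples, per the documented (features x samples) layout:

    import numpy as np
    from braindecode.util import corr, cov

    rng = np.random.RandomState(0)
    a = rng.randn(3, 100)    # 3 features, 100 samples
    b = rng.randn(5, 100)    # 5 features, same 100 samples
    print(corr(a, b).shape)  # (3, 5): feature-wise cross-correlations
    print(cov(a, b).shape)   # (3, 5): feature-wise cross-covariances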
- braindecode.util.create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True, description=None, savedir=None, save_format='fif', overwrite=True, random_state=None)[source]#
Create an mne.io.RawArray with fake data, and optionally save it.
This will overwrite already existing files.
- Parameters:
n_channels (int) – Number of channels.
n_times (int) – Number of samples.
sfreq (float) – Sampling frequency.
include_anns (bool) – If True, also create annotations.
description (list | None) – List of descriptions used for creating annotations. It should contain 10 elements.
savedir (str | None) – If provided as a string, the file will be saved under that directory.
save_format (str | list) – If savedir is provided, this specifies the file format the data should be saved to. Can be ‘fif’ or ‘hdf5’, or a list containing both.
random_state (int | RandomState) – Random state for the generation of random data.
- Returns:
raw (mne.io.Raw) – The created Raw object.
save_fname (dict | None) – Dictionary containing the filenames the raw data was saved to; None if the data was not saved.
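A minimal sketch; with savedir=None nothing is written to disk and the returned filename dictionary is None:

    from braindecode.util import create_mne_dummy_raw

    raw, fnames = create_mne_dummy_raw(
        n_channels=4, n_times=10000, sfreq=250.0,
        include_anns=True, savedir=None, random_state=87,
    )
    print(raw.info["sfreq"], raw.n_times)  # 250.0 10000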
- braindecode.util.get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None)[source]#
Create indices for batches balanced in size (batches will have a maximum size difference of 1). Supply either a batch size or a number of batches. The resulting batches will not have exactly the given batch size, but rather the next largest size that allows splitting the set into balanced batches (maximum size difference of 1).
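A short sketch of the balancing behavior (batch sizes differ by at most one):

    import numpy as np
    from braindecode.util import get_balanced_batches

    rng = np.random.RandomState(0)
    batches = get_balanced_batches(n_trials=10, rng=rng, shuffle=True, n_batches=3)
    print([len(b) for b in batches])  # e.g. [4, 3, 3]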
- braindecode.util.np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs)[source]#
Convenience function to transform a numpy array to a torch.Tensor.
Converts X to an ndarray using asarray if necessary.
- braindecode.util.read_all_file_names(directory, extension)[source]#
Read all file names with the specified extension from the given directory and return them sorted.
- braindecode.util.set_random_seeds(seed, cuda, cudnn_benchmark=None)[source]#
Set seeds for the Python random module, numpy.random, and torch.
For more details about reproducibility in pytorch see https://pytorch.org/docs/stable/notes/randomness.html
- Parameters:
seed (int) – Random seed.
cuda (bool) – Whether to set cuda seed with torch.
cudnn_benchmark (bool (default=None)) – Whether pytorch will use cudnn benchmark. When set to None, torch.backends.cudnn.benchmark is not modified (a warning is displayed about a possible lack of reproducibility). When set to True, results may not be reproducible (no warning displayed). When set to False, computations may be slowed down.
Notes
In some cases setting environment variable PYTHONHASHSEED may be needed before running a script to ensure full reproducibility. See https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14
Using this function may not ensure full reproducibility of the results as we do not set torch.use_deterministic_algorithms(True).
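Typical usage before building and training a model:

    import torch
    from braindecode.util import set_random_seeds

    # Seed the python random module, numpy.random and torch (incl. CUDA, if available).
    set_random_seeds(seed=20200220, cuda=torch.cuda.is_available())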
- braindecode.util.th_to_np(var: Tensor)[source]#
Convenience function to transform a torch.Tensor to a numpy array.
Works for both CPU and GPU tensors.
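A round-trip sketch of the two conversion helpers:

    import numpy as np
    from braindecode.util import np_to_th, th_to_np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    t = np_to_th(x, requires_grad=False)  # torch.Tensor on CPU
    x_back = th_to_np(t)                  # back to numpy; also handles GPU tensors
    assert np.array_equal(x, x_back)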