braindecode.regressor.EEGRegressor
- class braindecode.regressor.EEGRegressor(*args, cropped=False, callbacks=None, iterator_train__shuffle=True, aggregate_predictions=True, **kwargs)
Regressor that calls the loss function directly.
- Parameters
- module: torch module (class or instance)
A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.
- criterion: torch criterion (class, default=torch.nn.MSELoss)
Mean squared error loss.
- optimizer: torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module.
- lr: float (default=0.01)
Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.
- max_epochs: int (default=10)
The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.
- batch_size: int (default=128)
Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.
- iterator_train: torch DataLoader
The default PyTorch DataLoader used for training data.
- iterator_valid: torch DataLoader
The default PyTorch DataLoader used for validation and test data, i.e. during inference.
- dataset: torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with PyTorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.
- train_split: None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Otherwise, train_split should be a function or callable that is called with X and y data and returns the tuple (dataset_train, dataset_valid). The validation data may be None.
- callbacks: None, "disable", or list of Callback instances (default=None)
If callbacks=None, only use the default callbacks, i.e. those returned by get_default_callbacks.
If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks.
If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.
Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
- predict_nonlinearity: callable, None, or ‘auto’ (default=‘auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.
In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.
The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else; notably, the loss is unaffected by this nonlinearity.
- warm_start: bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose: int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device: str, torch.device (default=‘cpu’)
The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
- cropped: bool (default=False)
Defines whether the torch model passed to this class is cropped or not. Currently used for the callbacks definition.
- callbacks: None or list of strings or list of Callback instances (default=None)
More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from skorch.callbacks.Callback. If not None, callbacks can be a list of strings specifying sklearn scoring functions (for scoring function names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter) or a list of callbacks where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
- iterator_train__shuffle: bool (default=True)
Defines whether the train dataset will be shuffled. As skorch does not shuffle the train dataset by default, this parameter overrides that default.
- aggregate_predictions: bool (default=True)
Whether to average cropped predictions to obtain window predictions. Used only in the cropped mode.
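The callback naming rule described for the callbacks parameter can be sketched in plain Python. The helper uniquely_name and the stand-in classes EpochScoring and PrintLog below are hypothetical (they only mimic class names) and are not skorch's implementation. The sketch assumes that every conflicting, class-inferred name receives a suffix, so two EpochScoring callbacks become EpochScoring_1 and EpochScoring_2, and that a conflicting user-set name is an error.

```python
from collections import Counter


class EpochScoring:  # hypothetical stand-in; only the class name matters here
    pass


class PrintLog:  # hypothetical stand-in; only the class name matters here
    pass


def uniquely_name(callbacks):
    """Return (name, callback) pairs with unique names (simplified sketch)."""
    named = []
    for cb in callbacks:
        if isinstance(cb, tuple):
            name, obj = cb
            named.append((name, obj, True))  # name set explicitly by the user
        else:
            # Callbacks may be classes or instances; infer the name from the class.
            cls = cb if isinstance(cb, type) else type(cb)
            named.append((cls.__name__, cb, False))

    counts = Counter(name for name, _, _ in named)
    seen = Counter()
    result = []
    for name, obj, set_by_user in named:
        if counts[name] > 1:
            if set_by_user:
                raise ValueError(f"duplicate user-set callback name {name!r}")
            # Conflicting inferred names get a count suffix starting with 1.
            seen[name] += 1
            name = f"{name}_{seen[name]}"
        result.append((name, obj))
    return result


names = [n for n, _ in uniquely_name([EpochScoring(), EpochScoring(), ("my_log", PrintLog())])]
print(names)  # ['EpochScoring_1', 'EpochScoring_2', 'my_log']
```

The resulting names are what a call such as net.set_params(callbacks__<name>__<param>=...) would refer to.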
- Attributes
- prefixes_: list of str
Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).
- cuda_dependent_attributes_: list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other CUDA-dependent attributes.
- initialized_: bool
Whether the NeuralNet was initialized.
- module_: torch module (instance)
The instantiated module.
- criterion_: torch criterion (instance)
The instantiated criterion.
- callbacks_: list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules: list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria: list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers: list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
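The double-underscore convention behind prefixes_ (and behind parameters such as optimizer__lr and callbacks__print_log__keys_ignored) can be illustrated with a small routing function. route_params and the PREFIXES tuple are hypothetical: the prefix list is an assumed subset, and skorch's own routing logic differs in detail.

```python
# Assumed subset of special prefixes; a real net exposes its list as prefixes_.
PREFIXES = ("module", "criterion", "optimizer", "iterator_train",
            "iterator_valid", "dataset", "callbacks")


def route_params(params, prefixes=PREFIXES):
    """Split flat keyword arguments into per-component dicts (sketch only)."""
    top_level = {}
    routed = {prefix: {} for prefix in prefixes}
    for key, value in params.items():
        # Only the first "__" is consumed; the remainder (for example
        # "print_log__keys_ignored") is passed on for further routing.
        prefix, sep, rest = key.partition("__")
        if sep and prefix in routed:
            routed[prefix][rest] = value
        else:
            top_level[key] = value
    # In this sketch, lr is a shorthand for optimizer__lr; an explicit
    # optimizer__lr takes precedence.
    if "lr" in top_level:
        routed["optimizer"].setdefault("lr", top_level["lr"])
    return top_level, routed


top, routed = route_params({
    "max_epochs": 20,
    "lr": 0.05,
    "optimizer__momentum": 0.95,
    "callbacks__print_log__keys_ignored": ["epoch", "train_loss"],
})
print(top)                  # {'max_epochs': 20, 'lr': 0.05}
print(routed["optimizer"])  # {'momentum': 0.95, 'lr': 0.05}
print(routed["callbacks"])  # {'print_log__keys_ignored': ['epoch', 'train_loss']}
```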
Methods
- fit(X, y, **kwargs)
See NeuralNet.fit.
In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.
- get_iterator(dataset, training=False, drop_index=True)
Get an iterator that allows looping over the batches of the given data.
If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.
- Parameters
- dataset: torch Dataset (default=skorch.dataset.Dataset)
Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.
- training: bool (default=False)
Whether to use iterator_train or iterator_valid.
- Returns
- iterator
An instantiated iterator that allows looping over the mini-batches.
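A minimal sketch of the batch-size lookup that get_iterator describes, together with the batch_size=-1 convention from the parameter list. resolve_batch_size and iter_batches are hypothetical helpers operating on plain Python lists rather than a DataLoader; the validation prefix is assumed to be iterator_valid, matching the iterator_valid parameter.

```python
def resolve_batch_size(params, training, n_samples):
    """Pick a batch size following the get_iterator description (sketch).

    `params` is a plain dict standing in for the net's parameters.
    """
    key = "iterator_train__batch_size" if training else "iterator_valid__batch_size"
    batch_size = params.get(key, params["batch_size"])  # fall back to batch_size
    # batch_size == -1 means a single batch holding all of the data.
    return n_samples if batch_size == -1 else batch_size


def iter_batches(data, batch_size):
    """Yield consecutive mini-batches; the last one may be smaller."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


data = list(range(10))
params = {"batch_size": 4, "iterator_valid__batch_size": -1}
train_sizes = [len(b) for b in iter_batches(data, resolve_batch_size(params, True, len(data)))]
valid_sizes = [len(b) for b in iter_batches(data, resolve_batch_size(params, False, len(data)))]
print(train_sizes)  # [4, 4, 2]
print(valid_sizes)  # [10]
```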
- get_loss(y_pred, y_true, *args, **kwargs)
Return the loss for this batch by calling NeuralNet.get_loss.
- Parameters
- y_pred: torch tensor
Predicted target values.
- y_true: torch tensor
True target values.
- X: input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
  - numpy arrays
  - torch tensors
  - pandas DataFrame or Series
  - scipy sparse CSR matrices
  - a dictionary of the former three
  - a list/tuple of the former three
  - a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
- training: bool (default=False)
Whether train mode should be used or not.
- Returns
- loss: float
The loss value.
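With the default criterion, torch.nn.MSELoss, the value returned for a batch is the mean squared error between y_pred and y_true. The NumPy function below is a stand-in written for illustration, assuming MSELoss's default reduction="mean"; it ignores any extra handling the regressor may apply in cropped mode.

```python
import numpy as np


def mse_loss(y_pred, y_true):
    """Mean of the element-wise squared errors.

    NumPy stand-in for torch.nn.MSELoss() with reduction="mean";
    for illustration only.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    return float(np.mean((y_pred - y_true) ** 2))


# The errors are 0, -1 and 2, so the loss is (0 + 1 + 4) / 3.
loss = mse_loss([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])
print(round(loss, 4))  # 1.6667
```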
- on_batch_end(net, *batch, training=False, **kwargs)
- predict_proba(X)
Return the output of the module’s forward method as a numpy array. In case of cropped decoding, returns averaged values for each trial.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant or the module’s output for each crop is needed, consider using forward() instead.
- Parameters
- X: input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
  - numpy arrays
  - torch tensors
  - pandas DataFrame or Series
  - scipy sparse CSR matrices
  - a dictionary of the former three
  - a list/tuple of the former three
  - a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
- Returns
- y_proba: numpy ndarray
Warning
Regressors predict regression targets, so the output of this method cannot be interpreted as probabilities. We advise using the predict method instead of predict_proba.
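In cropped mode the module emits several predictions per window along a time axis, and aggregate_predictions=True averages them. The array below is made up and its layout (n_windows x n_outputs x n_preds) is an assumption; the sketch only shows the averaging step, not how braindecode collects the predictions.

```python
import numpy as np

# Hypothetical cropped output for 2 windows, 1 regression target and 3
# time-step predictions per window: shape (n_windows, n_outputs, n_preds).
crop_preds = np.array([
    [[1.0, 2.0, 3.0]],
    [[0.0, 0.5, 1.0]],
])

# Averaging over the last (time) axis leaves one value per window and target.
window_preds = crop_preds.mean(axis=-1)
print(window_preds.shape)             # (2, 1)
print(window_preds.ravel().tolist())  # [2.0, 0.5]
```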
- predict_trials(X, return_targets=True)
Create trialwise predictions and optionally also return trialwise labels from a cropped dataset.
- Parameters
- X: braindecode.datasets.BaseConcatDataset
A braindecode dataset to be predicted.
- return_targets: bool
If True, additionally returns the trial targets.
- Returns
- trial_predictions: np.ndarray
3-dimensional array (n_trials x n_classes x n_predictions), where the number of predictions depends on the chosen window size and the receptive field of the network.
- trial_labels: np.ndarray
2-dimensional array (n_trials x n_targets) where the number of targets depends on the decoding paradigm and can be either a single value, multiple values, or a sequence.
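The 3-dimensional trial_predictions layout can be pictured by concatenating the window predictions of each trial along the time axis. trial_preds_from_windows is a hypothetical, simplified helper: it assumes non-overlapping windows listed in temporal order and the same number of windows per trial. Real cropped datasets may use overlapping windows, which this sketch does not handle.

```python
import numpy as np


def trial_preds_from_windows(window_preds, trial_ids):
    """Concatenate each trial's window predictions along the time axis.

    Hypothetical sketch: assumes non-overlapping windows in temporal order
    and an equal number of windows per trial, so the result stacks into a
    (n_trials, n_outputs, n_predictions) array.
    """
    window_preds = np.asarray(window_preds)
    trial_ids = np.asarray(trial_ids)
    trials = []
    for trial in np.unique(trial_ids):  # np.unique returns sorted trial ids
        windows = window_preds[trial_ids == trial]  # (n_win, n_outputs, n_preds)
        # Join the windows end to end: (n_outputs, n_win * n_preds).
        trials.append(np.concatenate(list(windows), axis=-1))
    return np.stack(trials)


# 4 windows, 1 output, 2 predictions per window; windows 0-1 form trial 0.
preds = np.arange(8, dtype=float).reshape(4, 1, 2)
trial_preds = trial_preds_from_windows(preds, trial_ids=[0, 0, 1, 1])
print(trial_preds.shape)         # (2, 1, 4)
print(trial_preds[1].tolist())   # [[4.0, 5.0, 6.0, 7.0]]
```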
- predict_with_window_inds_and_ys(dataset)