braindecode.samplers.DistributedRecordingSampler

class braindecode.samplers.DistributedRecordingSampler(metadata, random_state=None, **kwargs)

Base sampler that simplifies sampling from recordings in a distributed setting.

Parameters:
  • metadata (pd.DataFrame) –

    DataFrame with at least one of {subject, session, run} columns for each window in the BaseConcatDataset to sample examples from. Normally obtained with BaseConcatDataset.get_metadata(). For instance, metadata.head() might look like this:

       i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject    session    run
    0                  0                 0              500      -1        4  session_T  run_0
    1                  1               500             1000      -1        4  session_T  run_0
    2                  2              1000             1500      -1        4  session_T  run_0
    3                  3              1500             2000      -1        4  session_T  run_0
    4                  4              2000             2500      -1        4  session_T  run_0
    

  • random_state (np.random.RandomState | int | None) – Random state.

info

DataFrame with a MultiIndex that holds the subject, session, run and window index information in an easily accessible structure for quick sampling of windows.

Type:

pd.DataFrame

n_recordings

Number of recordings available.

Type:

int
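
As a rough illustration of how a per-recording structure and a recording count could be derived from the metadata, the sketch below groups a toy frame by the recording-identifying columns using plain pandas. It is not braindecode's implementation: the `group_recordings` helper, the `window_indices` column name and the toy values are all hypothetical.

```python
import pandas as pd

# Toy metadata mimicking BaseConcatDataset.get_metadata(); values are made up.
metadata = pd.DataFrame({
    "i_window_in_trial": [0, 1, 2, 0, 1],
    "i_start_in_trial": [0, 500, 1000, 0, 500],
    "i_stop_in_trial": [500, 1000, 1500, 500, 1000],
    "target": [-1, -1, -1, -1, -1],
    "subject": [4, 4, 4, 4, 4],
    "session": ["session_T"] * 5,
    "run": ["run_0", "run_0", "run_0", "run_1", "run_1"],
})


def group_recordings(metadata):
    """Hypothetical helper: one row per recording, listing its window indices."""
    keys = [k for k in ("subject", "session", "run") if k in metadata.columns]
    if not keys:
        raise ValueError("metadata needs at least one of subject, session, run")
    # Expose each window's position in the dataset as a column named "index".
    windows = metadata.reset_index(drop=True).reset_index()
    return windows.groupby(keys)["index"].agg(list).to_frame("window_indices")


info = group_recordings(metadata)
n_recordings = len(info)
print(info)
print("n_recordings:", n_recordings)
```

Here the two distinct (subject, session, run) combinations yield two recordings, each carrying the dataset indices of its windows.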

kwargs

Additional keyword arguments to pass to torch DistributedSampler. See https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

Type:

dict

Methods

sample_recording()

Return a random recording index. The candidates come from super().__iter__(), which yields the dataset indices assigned to the current process by the DistributedSampler.
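
To make the per-process behaviour concrete, here is a simplified, torch-free sketch of the idea: recording indices are split across processes, and each process draws only from its own share. The strided split mirrors how torch's DistributedSampler assigns indices when shuffling is disabled; shuffling and padding are deliberately left out, and `shard_indices` and `sample_local_recording` are hypothetical names.

```python
import random


def shard_indices(n_recordings, num_replicas, rank):
    """Hypothetical helper: strided split of recording indices for one process."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    return list(range(n_recordings))[rank::num_replicas]


def sample_local_recording(local_indices, rng):
    """Pick one recording index uniformly from this process's share."""
    return rng.choice(local_indices)


# Six recordings shared by two processes.
shards = [shard_indices(6, num_replicas=2, rank=r) for r in range(2)]
print(shards)  # [[0, 2, 4], [1, 3, 5]]

rng = random.Random(0)
rec_ind = sample_local_recording(shards[1], rng)
print(rec_ind in shards[1])  # True
```

Because the shares are disjoint, two processes never draw the same recording in this simplified setting.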

sample_window(rec_ind=None)

Return a specific window.
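
The following standalone sketch shows one plausible reading of this method. It assumes that a window is drawn uniformly from the chosen recording and that a recording is sampled first when rec_ind is None; neither detail is stated above, and the `windows_per_recording` mapping and the returned tuple are hypothetical.

```python
import random

# Hypothetical mapping from recording index to the dataset indices of its windows.
windows_per_recording = {0: [0, 1, 2], 1: [3, 4]}


def sample_window(windows_per_recording, rng, rec_ind=None):
    """Return (window index, recording index); both choices are uniform."""
    if rec_ind is None:
        # Assumed fallback: pick a recording first when none is given.
        rec_ind = rng.choice(sorted(windows_per_recording))
    win_ind = rng.choice(windows_per_recording[rec_ind])
    return win_ind, rec_ind


rng = random.Random(42)
win_ind, rec_ind = sample_window(windows_per_recording, rng, rec_ind=1)
print(win_ind in (3, 4), rec_ind)  # True 1
```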