.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_tuh_discrete_multitarget.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_tuh_discrete_multitarget.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_tuh_discrete_multitarget.py:


Multiple discrete targets with the TUH EEG Corpus
=================================================

In this example, we showcase usage of multiple discrete targets per recording
with the TUH EEG Corpus.

.. GENERATED FROM PYTHON SOURCE LINES 8-22

.. code-block:: default


    # Author: Lukas Gemein <l.gemein@gmail.com>
    #
    # License: BSD (3-clause)

    import mne
    from torch.utils.data import DataLoader

    from braindecode.datasets import TUH
    from braindecode.preprocessing import create_fixed_length_windows

    mne.set_log_level('ERROR')  # avoid messages every time a window is extracted









.. GENERATED FROM PYTHON SOURCE LINES 23-26

If you want to run this code with the actual data, please delete the next
section. We have to mock some of the dataset functionality, since the data
is not available at the time this example is created.

.. GENERATED FROM PYTHON SOURCE LINES 26-29

.. code-block:: default

    from braindecode.datasets.tuh import _TUHMock as TUH  # noqa F811









.. GENERATED FROM PYTHON SOURCE LINES 30-33

We start by creating a TUH dataset. Instead of a single string, we pass a
tuple of strings as target names. Each of these strings has to exist as a
column in the description DataFrame.

.. GENERATED FROM PYTHON SOURCE LINES 33-45

.. code-block:: default


    TUH_PATH = 'please insert actual path to data here'
    tuh = TUH(
        path=TUH_PATH,
        recording_ids=None,
        target_name=('age', 'gender'),  # use both age and gender as decoding target
        preload=False,
        add_physician_reports=False,
    )
    tuh.description







.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <div>
    <style scoped>
        .dataframe tbody tr th:only-of-type {
            vertical-align: middle;
        }

        .dataframe tbody tr th {
            vertical-align: top;
        }

        .dataframe thead th {
            text-align: right;
        }
    </style>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;">
          <th></th>
          <th>path</th>
          <th>version</th>
          <th>year</th>
          <th>month</th>
          <th>day</th>
          <th>subject</th>
          <th>session</th>
          <th>segment</th>
          <th>age</th>
          <th>gender</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th>0</th>
          <td>tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001...</td>
          <td>v1.1.0</td>
          <td>2003</td>
          <td>2</td>
          <td>5</td>
          <td>58</td>
          <td>1</td>
          <td>0</td>
          <td>0</td>
          <td>M</td>
        </tr>
        <tr>
          <th>1</th>
          <td>tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004...</td>
          <td>v1.1.0</td>
          <td>2014</td>
          <td>9</td>
          <td>30</td>
          <td>9932</td>
          <td>4</td>
          <td>13</td>
          <td>53</td>
          <td>F</td>
        </tr>
        <tr>
          <th>2</th>
          <td>tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s0...</td>
          <td>v1.1.0</td>
          <td>2014</td>
          <td>12</td>
          <td>14</td>
          <td>12331</td>
          <td>3</td>
          <td>2</td>
          <td>39</td>
          <td>M</td>
        </tr>
        <tr>
          <th>3</th>
          <td>tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001...</td>
          <td>v1.1.0</td>
          <td>2015</td>
          <td>12</td>
          <td>30</td>
          <td>0</td>
          <td>1</td>
          <td>0</td>
          <td>37</td>
          <td>M</td>
        </tr>
        <tr>
          <th>4</th>
          <td>tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s0...</td>
          <td>v1.2.0</td>
          <td>2016</td>
          <td>1</td>
          <td>15</td>
          <td>14928</td>
          <td>4</td>
          <td>7</td>
          <td>83</td>
          <td>F</td>
        </tr>
      </tbody>
    </table>
    </div>
    </div>
    <br />
    <br />
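
Any column of this description DataFrame can serve as a target name. As a
hypothetical check (not part of the original example), one could list the
available columns before deciding which ones to decode:

.. code-block:: python

    # List the columns of the description DataFrame; each of these names
    # could be passed, alone or in a tuple, as target_name.
    print(tuh.description.columns.tolist())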

.. GENERATED FROM PYTHON SOURCE LINES 46-49

Iterating through the dataset gives x as an ndarray of shape
(n_channels x 1) as well as the target y as [age of the subject, gender of
the subject]. Let's look at the last example, as it has more interesting
age/gender labels (compare to the last row of the dataframe above).

.. GENERATED FROM PYTHON SOURCE LINES 49-54

.. code-block:: default

    x, y = tuh[-1]
    print('x:', x)
    print('y:', y)






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    x: [[-0.48388163]
     [-1.1033349 ]
     [-0.00548946]
     [-0.69145748]
     [-0.72950636]
     [-0.6732013 ]
     [-0.02884033]
     [-0.09684461]
     [ 0.66150905]
     [ 1.35850294]
     [-1.54706468]
     [ 0.81112458]
     [ 0.48616393]
     [ 0.26901556]
     [ 1.02706921]
     [-0.46342266]
     [-0.43525863]
     [-1.02658337]
     [-0.4584042 ]
     [ 0.45492769]
     [ 1.21383652]]
    y: [83, 'F']
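
A single index returns just one time sample together with both targets. As a
hypothetical sketch (not part of the original example), the underlying
recording and its labels can also be accessed directly:

.. code-block:: python

    # Each entry of tuh.datasets is one recording; its targets can be read
    # from the per-recording description.
    last_recording = tuh.datasets[-1]
    print(last_recording.description[['age', 'gender']])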




.. GENERATED FROM PYTHON SOURCE LINES 55-59

We will skip preprocessing steps for now, since they are not the aim of this
example. Instead, we will directly create compute windows. We specify a
mapping from the genders 'M' and 'F' to integers, since integer targets are
required for decoding.

.. GENERATED FROM PYTHON SOURCE LINES 59-74

.. code-block:: default


    tuh_windows = create_fixed_length_windows(
        tuh,
        start_offset_samples=0,
        stop_offset_samples=None,
        window_size_samples=1000,
        window_stride_samples=1000,
        drop_last_window=False,
        mapping={'M': 0, 'F': 1},  # map non-digit targets
    )
    # store the number of windows required for loading later on
    tuh_windows.set_description({
        "n_windows": [len(d) for d in tuh_windows.datasets]})









.. GENERATED FROM PYTHON SOURCE LINES 75-77

Iterating through the windowed dataset gives x as an ndarray of shape
(n_channels x 1000), y as [age, gender], and ind, which holds the index of
the window within its recording together with the start and stop sample of
the window. Let's look at the last example again.

.. GENERATED FROM PYTHON SOURCE LINES 77-83

.. code-block:: default

    x, y, ind = tuh_windows[-1]
    print('x:', x)
    print('y:', y)
    print('ind:', ind)






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    x: [[ 5.6389427e-01 -2.1618271e+00 -9.9437243e-01 ... -6.4533629e-02
       3.9639103e-01 -4.8388162e-01]
     [ 1.1334016e-04 -2.4711089e-01  2.3326023e-01 ... -5.3718823e-01
       1.1165446e+00 -1.1033349e+00]
     [ 2.5976139e-01 -1.6312467e+00 -5.4536062e-01 ... -6.4550507e-01
      -3.1091759e-01 -5.4894560e-03]
     ...
     [ 2.1103388e-01  2.1207649e-01  1.0596663e+00 ...  1.1248783e+00
       2.2101052e+00 -4.5840421e-01]
     [ 2.6553613e-01 -1.0722766e+00 -1.8160485e+00 ... -4.7655761e-01
      -2.3370227e-02  4.5492768e-01]
     [ 6.8648207e-01  1.2309586e-01  3.9327252e-01 ...  9.7762001e-01
      -4.7603920e-01  1.2138366e+00]]
    y: [83, 1]
    ind: [3, 2600, 3600]




.. GENERATED FROM PYTHON SOURCE LINES 84-86

We give the dataset to a PyTorch DataLoader, so that it can be used for
model training.

.. GENERATED FROM PYTHON SOURCE LINES 86-92

.. code-block:: default

    dl = DataLoader(
        dataset=tuh_windows,
        batch_size=4,
    )
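
For actual model training one would typically also shuffle the windows and
possibly load them with several worker processes. A hypothetical variant
(not used for the output below, which relies on the unshuffled loader above):

.. code-block:: python

    # Shuffling and parallel loading are standard DataLoader options; they
    # do not change how the multiple targets are returned per batch.
    train_dl = DataLoader(
        dataset=tuh_windows,
        batch_size=4,
        shuffle=True,
        num_workers=2,
    )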









.. GENERATED FROM PYTHON SOURCE LINES 93-97

Iterating through the DataLoader gives batch_X as a tensor of shape
(4 x n_channels x 1000), batch_y as [tensor(4 x age of subject),
tensor(4 x gender of subject)], and batch_ind. We iterate to the end to look
again at the batch containing the last example.

.. GENERATED FROM PYTHON SOURCE LINES 97-102

.. code-block:: default

    for batch_X, batch_y, batch_ind in dl:
        pass
    print('batch_X:', batch_X)
    print('batch_y:', batch_y)
    print('batch_ind:', batch_ind)




.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    batch_X: tensor([[[ 1.9264e-01, -2.8769e-01, -4.0477e-02,  ...,  4.3451e-01,
               2.3285e-01, -3.0400e-01],
             [-5.6241e-01, -2.4511e+00, -1.5853e+00,  ..., -1.4923e+00,
               1.1025e+00,  4.7152e-01],
             [ 4.5288e-01,  2.9770e-01, -7.7068e-03,  ...,  1.6793e-01,
              -5.4024e-01,  2.3311e+00],
             ...,
             [-1.4093e-01,  2.1644e-01, -7.2651e-02,  ..., -2.2531e+00,
              -2.3257e+00, -1.0198e-01],
             [ 1.7482e+00,  6.3536e-01, -1.3564e+00,  ..., -1.0846e-01,
               7.7717e-02,  5.7999e-01],
             [-5.4359e-01, -1.0553e+00,  2.1270e-01,  ...,  8.6473e-01,
              -1.0241e+00, -5.6435e-01]],

            [[ 2.7934e-01, -5.5462e-01, -2.3934e+00,  ..., -6.4195e-01,
               1.2517e+00,  1.4091e+00],
             [ 1.1977e+00,  7.7382e-01, -1.2499e+00,  ..., -5.1294e-01,
               1.3692e+00, -1.0125e+00],
             [-2.1263e+00, -5.8350e-02, -2.3486e-01,  ..., -6.6659e-01,
              -3.5822e-02,  8.5182e-01],
             ...,
             [-1.8836e+00, -5.2328e-01, -1.7144e+00,  ...,  1.9581e+00,
              -3.3173e-01,  5.9458e-01],
             [ 5.3573e-01,  4.7540e-01,  1.8706e+00,  ...,  1.1629e+00,
               7.8696e-01, -1.5714e+00],
             [ 5.6450e-01,  8.2211e-01,  3.2242e-01,  ..., -2.3119e+00,
              -7.1520e-01,  7.7749e-02]],

            [[ 1.4595e+00,  7.5736e-01,  4.0588e-02,  ...,  1.4255e+00,
               6.8046e-01,  5.0423e-01],
             [-8.8447e-01, -1.5425e-01,  7.6564e-01,  ...,  5.5104e-01,
              -8.6491e-01,  7.1067e-01],
             [ 3.9101e-01, -6.7435e-01,  3.1399e-01,  ..., -2.6413e-01,
               6.7261e-01, -4.9560e-01],
             ...,
             [ 1.2032e+00,  3.0923e-01,  4.1398e-01,  ..., -5.7762e-01,
              -4.7420e-02,  4.0071e-01],
             [ 3.6943e-01, -8.9819e-01,  1.0731e+00,  ...,  2.2911e-01,
               2.1890e-01,  2.2932e+00],
             [ 1.0741e+00,  1.6643e+00,  5.2559e-01,  ...,  1.2460e-01,
              -1.6045e+00,  2.4247e+00]],

            [[ 5.6389e-01, -2.1618e+00, -9.9437e-01,  ..., -6.4534e-02,
               3.9639e-01, -4.8388e-01],
             [ 1.1334e-04, -2.4711e-01,  2.3326e-01,  ..., -5.3719e-01,
               1.1165e+00, -1.1033e+00],
             [ 2.5976e-01, -1.6312e+00, -5.4536e-01,  ..., -6.4551e-01,
              -3.1092e-01, -5.4895e-03],
             ...,
             [ 2.1103e-01,  2.1208e-01,  1.0597e+00,  ...,  1.1249e+00,
               2.2101e+00, -4.5840e-01],
             [ 2.6554e-01, -1.0723e+00, -1.8160e+00,  ..., -4.7656e-01,
              -2.3370e-02,  4.5493e-01],
             [ 6.8648e-01,  1.2310e-01,  3.9327e-01,  ...,  9.7762e-01,
              -4.7604e-01,  1.2138e+00]]])
    batch_y: [tensor([83, 83, 83, 83]), tensor([1, 1, 1, 1])]
    batch_ind: [tensor([0, 1, 2, 3]), tensor([   0, 1000, 2000, 2600]), tensor([1000, 2000, 3000, 3600])]
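
Because the two targets arrive as a list of tensors, they can be unpacked
and handled separately, e.g. with a regression loss for the age and a
classification loss for the gender. The following is a hypothetical sketch
(model and losses are deliberately left out):

.. code-block:: python

    # Take one batch and split the multi-target into its two components.
    batch_X, batch_y, batch_ind = next(iter(dl))
    age_target, gender_target = batch_y
    age_target = age_target.float()       # continuous target, e.g. for a regression loss
    gender_target = gender_target.long()  # class indices, e.g. for a classification loss
    print(age_target.shape, gender_target.shape)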





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.349 seconds)

**Estimated memory usage:**  19 MB


.. _sphx_glr_download_auto_examples_plot_tuh_discrete_multitarget.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_tuh_discrete_multitarget.py <plot_tuh_discrete_multitarget.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_tuh_discrete_multitarget.ipynb <plot_tuh_discrete_multitarget.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_