Eyes Open vs. Closed Classification#

Estimated reading time: 6 minutes

EEGDash example for eyes open vs. closed classification.

This example uses the eegdash library in combination with PyTorch to develop a deep learning model for eyes-open vs. eyes-closed classification on EEG data from a single subject.

  1. Data Retrieval Using EEGDash: An instance of eegdash.api.EEGDashDataset is created to search and retrieve an EEG dataset. At this step, only the metadata is transferred.

  2. Data Preprocessing Using BrainDecode: This process preprocesses EEG data using Braindecode by reannotating events, selecting specific channels, resampling, filtering, and extracting 2-second epochs, ensuring balanced eyes-open and eyes-closed data for analysis.

  3. Creating training and test sets: The dataset is split into training (80%) and test (20%) sets with balanced labels, converted into PyTorch tensors, and wrapped in DataLoader objects for efficient mini-batch training.

  4. Model Definition: The model is a shallow convolutional neural network (ShallowFBCSPNet) with 24 input channels (EEG channels) and 2 output classes (eyes-open and eyes-closed).

  5. Model Training and Evaluation Process: This section trains the neural network, normalizes input data, computes cross-entropy loss, updates model parameters, and evaluates classification accuracy over six epochs.

Data Retrieval Using EEGDash#

This section instantiates eegdash.api.EEGDashDataset to fetch the metadata for the experiment before requesting any recordings.

First, we find one resting-state dataset. This dataset contains both eyes-open and eyes-closed data.

from pathlib import Path

cache_folder = Path.home() / "eegdash"
from eegdash import EEGDashDataset

ds_eoec = EEGDashDataset(
    query={"dataset": "ds005514", "task": "RestingState", "subject": "NDARDB033FW5"},
    cache_dir=cache_folder,
)
[10/13/25 00:28:51] WARNING  Cache directory does not exist, creating it: /home/runner/eegdash   api.py:726
╭────────────────────── EEG 2025 Competition Data Notice ──────────────────────╮
│ This notice is only for users who are participating in the EEG 2025          │
│ Competition.                                                                 │
│                                                                              │
│ EEG 2025 Competition Data Notice!                                            │
│ You are loading one of the datasets that is used in competition, but via     │
│ `EEGDashDataset`.                                                            │
│                                                                              │
│ IMPORTANT:                                                                   │
│ If you download data from `EEGDashDataset`, it is NOT identical to the       │
│ official                                                                     │
│ competition data, which is accessed via `EEGChallengeDataset`. The           │
│ competition data has been downsampled and filtered.                          │
│                                                                              │
│ If you are participating in the competition,                                 │
│ you must use the `EEGChallengeDataset` object to ensure consistency.         │
│                                                                              │
│ If you are not participating in the competition, you can ignore this         │
│ message.                                                                     │
╰─────────────────────────── Source: EEGDashDataset ───────────────────────────╯
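Before any preprocessing, it can be useful to confirm what the query matched. The check below is optional and assumes the standard braindecode concat-dataset attributes (datasets and description) are available on EEGDashDataset; it inspects the matched recordings without downloading any signal data.

# Optional sanity check (assumes braindecode-style attributes on EEGDashDataset)
print(len(ds_eoec.datasets))  # number of recordings matched by the query
print(ds_eoec.description)  # per-recording metadata (subject, task, ...)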

Data Preprocessing Using Braindecode#

Braindecode is a library for deep learning with EEG and MEG data that also provides preprocessing utilities. In this dataset, there are two key events in the continuous data: instructed_toCloseEyes, marking the start of a 40-second eyes-closed period, and instructed_toOpenEyes, indicating the start of a 20-second eyes-open period.

For the eyes-closed event, we extract 14 seconds of data from 15 to 29 seconds after the event onset. Similarly, for the eyes-open event, we extract data from 5 to 19 seconds after the event onset. This ensures an equal amount of data for both conditions. The event extraction is handled by the custom function eegdash.hbn.preprocessing.hbn_ec_ec_reannotation().
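The snippet below is a minimal sketch of that reannotation idea, written directly with MNE annotations. It is illustrative only and is not the implementation of eegdash.hbn.preprocessing.hbn_ec_ec_reannotation(); the helper name reannotate_eo_ec is hypothetical.

import mne


def reannotate_eo_ec(raw):
    # Illustrative sketch: derive eyes_closed / eyes_open annotations
    # from the instruction events, matching the offsets described above.
    onsets, descriptions = [], []
    for ann in raw.annotations:
        if ann["description"] == "instructed_toCloseEyes":
            onsets.append(ann["onset"] + 15.0)  # 15-29 s after the instruction
            descriptions.append("eyes_closed")
        elif ann["description"] == "instructed_toOpenEyes":
            onsets.append(ann["onset"] + 5.0)  # 5-19 s after the instruction
            descriptions.append("eyes_open")
    new_annotations = mne.Annotations(
        onset=onsets, duration=[14.0] * len(onsets), description=descriptions
    )
    raw.set_annotations(new_annotations)
    return raw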

Next, we apply four preprocessing steps in Braindecode:

  1. Reannotation of event markers using eegdash.hbn.preprocessing.hbn_ec_ec_reannotation().

  2. Selection of 24 specific EEG channels from the original 128.

  3. Resampling the EEG data to 128 Hz.

  4. Band-pass filtering the EEG signals between 1 Hz and 55 Hz.

The recordings themselves are downloaded from the remote repository only when the preprocess function is called.

Finally, we use create_windows_from_events to extract 2-second epochs from the data. These epochs serve as the dataset samples, and each one is automatically labeled with the corresponding event type. windows_ds is a PyTorch-compatible dataset; when indexed, it returns each window together with its label (eyes-closed and eyes-open, encoded as 0 and 1 from their event markers).

from braindecode.preprocessing import (
    preprocess,
    Preprocessor,
    create_windows_from_events,
)
import numpy as np
from eegdash.hbn.preprocessing import hbn_ec_ec_reannotation
import warnings

warnings.simplefilter("ignore", category=RuntimeWarning)


# BrainDecode preprocessors
preprocessors = [
    hbn_ec_ec_reannotation(),
    Preprocessor(
        "pick_channels",
        ch_names=[
            "E22",
            "E9",
            "E33",
            "E24",
            "E11",
            "E124",
            "E122",
            "E29",
            "E6",
            "E111",
            "E45",
            "E36",
            "E104",
            "E108",
            "E42",
            "E55",
            "E93",
            "E58",
            "E52",
            "E62",
            "E92",
            "E96",
            "E70",
            "Cz",
        ],
    ),
    Preprocessor("resample", sfreq=128),
    Preprocessor("filter", l_freq=1, h_freq=55),
]
preprocess(ds_eoec, preprocessors)

# Extract 2-second segments
windows_ds = create_windows_from_events(
    ds_eoec,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=256,
    preload=True,
)
Downloading BIDS metadata files (dataset_description.json, participants.tsv, per-task *_events.tsv, *_eeg.json, and *_channels.tsv) ...
Downloading sub-NDARDB033FW5_task-RestingState_eeg.set ...
Used Annotations descriptions: [np.str_('boundary'), np.str_('break cnt'), np.str_('instructed_toCloseEyes'), np.str_('instructed_toOpenEyes'), np.str_('resting_start')]
[10/13/25 00:29:12] INFO     Original events found with ids: {'boundary': 1, 'break cnt': 2, 'instructed_toCloseEyes': 3, 'instructed_toOpenEyes': 4, 'resting_start': 5}   preprocessing.py:66
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 55 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 55.00 Hz
- Upper transition bandwidth: 9.00 Hz (-6 dB cutoff frequency: 59.50 Hz)
- Filter length: 423 samples (3.305 s)

Used Annotations descriptions: [np.str_('eyes_closed'), np.str_('eyes_open')]
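Before plotting, a quick look at the windowed dataset shows how many 2-second windows were produced and what one sample contains. This check is optional; it assumes the usual braindecode convention that indexing a windows dataset returns the signal, its label, and the window indices.

# Optional: inspect the windowed dataset
X, y, window_inds = windows_ds[0]
print(len(windows_ds), X.shape, y)  # e.g. 70 windows of shape (24 channels, 256 samples)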

Plotting a Single Channel for One Sample#

It’s always a good practice to verify that the data has been properly loaded and processed. Here, we plot a single channel from one sample to ensure the signal is present and looks as expected.

import matplotlib.pyplot as plt

plt.figure()
plt.plot(windows_ds[2][0][0, :])  # first channel of the third window (index 2)
plt.show()
[Figure: one EEG channel from a single 2-second window]

Creating training and test sets#

The code below creates the training and test sets. We first split the data using the train_test_split function from the scikit-learn library, then wrap the training and test samples in TensorDataset objects.

  1. Set Random Seed – The random seed is fixed using torch.manual_seed(random_state) to ensure reproducibility in dataset splitting and model training.

  2. Extract Labels from the Dataset – Labels (eyes-open or eyes-closed) are extracted from windows_ds, stored as a NumPy array, and printed for verification.

  3. Split Dataset into Train and Test Sets – The dataset is split into training (80%) and testing (20%) subsets using train_test_split(), ensuring balanced stratification based on the extracted labels.

  4. Convert Data to PyTorch Tensors – The selected training and testing samples are converted into FloatTensor for input features and LongTensor for labels, making them compatible with PyTorch models.

  5. Create DataLoaders – The datasets are wrapped in PyTorch DataLoader objects with a batch size of 10, enabling efficient mini-batch training and shuffling.

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

# Set random seed for reproducibility
random_state = 42
torch.manual_seed(random_state)
np.random.seed(random_state)

# Extract labels from the dataset
eo_ec = np.array([ds[1] for ds in windows_ds])  # label of each window
print("labels: ", eo_ec)

# Stratified split into balanced training and test indices
train_indices, test_indices = train_test_split(
    range(len(windows_ds)), test_size=0.2, stratify=eo_ec, random_state=random_state
)

# Convert the data to tensors
X_train = torch.FloatTensor(
    np.array([windows_ds[i][0] for i in train_indices])
)  # Convert list of arrays to single tensor
X_test = torch.FloatTensor(
    np.array([windows_ds[i][0] for i in test_indices])
)  # Convert list of arrays to single tensor
y_train = torch.LongTensor(eo_ec[train_indices])  # Convert targets to tensor
y_test = torch.LongTensor(eo_ec[test_indices])  # Convert targets to tensor
dataset_train = TensorDataset(X_train, y_train)
dataset_test = TensorDataset(X_test, y_test)

# Create data loaders for training and testing (batch size 10)
train_loader = DataLoader(dataset_train, batch_size=10, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=10, shuffle=True)

# Print shapes and sizes to verify split
print(
    f"Shape of data {X_train.shape} number of batches - Train: {len(train_loader)}, Test: {len(test_loader)}"
)
print(
    f"Eyes-Open/Eyes-Closed balance, train: {np.mean(eo_ec[train_indices]):.2f}, test: {np.mean(eo_ec[test_indices]):.2f}"
)
labels:  [1 1 1 1 1 1 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0
 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0]
Shape of data torch.Size([56, 24, 256]) number of batches - Train: 6, Test: 2
Eyes-Open/Eyes-Closed balance, train: 0.50, test: 0.50

Check labels#

It is good practice to verify the labels and ensure the random seed is functioning correctly. If all labels are 0s (eyes closed) or 1s (eyes open), it could indicate an issue with data loading or stratification, requiring further investigation.

Visualize a batch of target labels

dataiter = iter(train_loader)
first_item, label = next(dataiter)
label
tensor([0, 1, 1, 1, 1, 0, 1, 1, 0, 0])
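Beyond eyeballing a single batch, the class counts of each split can be verified directly. The check below is optional and not part of the original script.

# Optional: count windows per class in each split (the two classes should be balanced)
print("train class counts:", np.bincount(y_train.numpy()))
print("test class counts: ", np.bincount(y_test.numpy()))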

Create model#

The model is a shallow convolutional neural network (ShallowFBCSPNet) with 24 input channels (EEG channels), 2 output classes (eyes-open and eyes-closed), and an input window size of 256 samples (2 seconds of EEG data).

import torch
import numpy as np
from torch.nn import functional as F
from braindecode.models import ShallowFBCSPNet
from torchinfo import summary

torch.manual_seed(random_state)
model = ShallowFBCSPNet(24, 2, n_times=256, final_conv_length="auto")
summary(model, input_size=(1, 24, 256))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ShallowFBCSPNet                          [1, 2]                    --
├─Ensure4d: 1-1                          [1, 24, 256, 1]           --
├─Rearrange: 1-2                         [1, 1, 256, 24]           --
├─CombinedConv: 1-3                      [1, 40, 232, 1]           39,440
├─BatchNorm2d: 1-4                       [1, 40, 232, 1]           80
├─Expression: 1-5                        [1, 40, 232, 1]           --
├─AvgPool2d: 1-6                         [1, 40, 11, 1]            --
├─SafeLog: 1-7                           [1, 40, 11, 1]            --
├─Dropout: 1-8                           [1, 40, 11, 1]            --
├─Sequential: 1-9                        [1, 2]                    --
│    └─Conv2d: 2-1                       [1, 2, 1, 1]              882
│    └─SqueezeFinalOutput: 2-2           [1, 2]                    --
│    │    └─Rearrange: 3-1               [1, 2, 1]                 --
==========================================================================================
Total params: 40,402
Trainable params: 40,402
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.00
==========================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.07
Params size (MB): 0.00
Estimated Total Size (MB): 0.10
==========================================================================================
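As an optional shape check (not part of the original script), passing a dummy batch through the untrained network should yield one score per class for each sample.

# Dummy forward pass: batch of 4 windows, 24 channels, 256 time points
with torch.no_grad():
    dummy = torch.randn(4, 24, 256)
    print(model(dummy).shape)  # expected: torch.Size([4, 2])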


Model Training and Evaluation Process#

This section trains the neural network using the Adamax optimizer, normalizes input data, computes cross-entropy loss, updates model parameters, and tracks accuracy across six epochs.

  1. Set Up Optimizer and Learning Rate Scheduler – The Adamax optimizer is initialized with a learning rate of 0.002 and a weight decay of 0.001 for regularization. An ExponentialLR scheduler with a decay factor (gamma) of 1 keeps the learning rate constant.

  2. Allocate Model to Device – The model is moved to the selected device (CUDA GPU, MPS on Apple silicon, or CPU) before training.

  3. Normalize Input Data – The normalize_data function standardizes input data by subtracting the mean and dividing by the standard deviation along the time dimension before transferring it to the appropriate device.

  4. Train Over Six Epochs – The training loop iterates through mini-batches with the model in training mode. It normalizes inputs, computes predictions, calculates the cross-entropy loss, backpropagates, updates the model parameters, and steps the learning rate scheduler, tracking correct predictions to report training accuracy.

  5. Evaluate on Test Data – After each epoch, the model is run in evaluation mode on the test set: predictions on normalized inputs are compared with the true labels to compute test accuracy.

optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
model = model.to(device=device)  # move the model parameters to CPU/GPU
epochs = 6


def normalize_data(x):
    mean = x.mean(dim=2, keepdim=True)
    std = x.std(dim=2, keepdim=True) + 1e-7  # add small epsilon for numerical stability
    x = (x - mean) / std
    x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
    return x


for e in range(epochs):
    # training
    correct_train = 0
    for t, (x, y) in enumerate(train_loader):
        model.train()  # put model to training mode
        scores = model(normalize_data(x))
        y = y.to(device=device, dtype=torch.long)
        _, preds = scores.max(1)
        correct_train += (preds == y).sum() / len(dataset_train)

        loss = F.cross_entropy(scores, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # Validation
    correct_test = 0
    for t, (x, y) in enumerate(test_loader):
        model.eval()  # put model to testing mode
        scores = model(normalize_data(x))
        y = y.to(device=device, dtype=torch.long)
        _, preds = scores.max(1)
        correct_test += (preds == y).sum() / len(dataset_test)

    # Reporting
    print(
        f"Epoch {e}, Train accuracy: {correct_train:.2f}, Test accuracy: {correct_test:.2f}"
    )
Epoch 0, Train accuracy: 0.66, Test accuracy: 0.50
Epoch 1, Train accuracy: 0.79, Test accuracy: 0.50
Epoch 2, Train accuracy: 0.91, Test accuracy: 0.50
Epoch 3, Train accuracy: 0.88, Test accuracy: 0.57
Epoch 4, Train accuracy: 0.91, Test accuracy: 0.57
Epoch 5, Train accuracy: 0.88, Test accuracy: 0.50
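With only 14 test windows, accuracy alone is a noisy measure. As an optional follow-up (not part of the original script), a confusion matrix over the test set shows how the errors are distributed between the two classes.

from sklearn.metrics import confusion_matrix

# Optional: confusion matrix on the held-out windows
model.eval()
all_true, all_pred = [], []
with torch.no_grad():
    for x, y in test_loader:
        scores = model(normalize_data(x))
        all_pred.extend(scores.argmax(dim=1).cpu().numpy())
        all_true.extend(y.numpy())
print(confusion_matrix(all_true, all_pred))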

Total running time of the script: (0 minutes 23.357 seconds)

Estimated memory usage: 1316 MB
