Challenge 1: Cross-Task Transfer Learning!
Estimated reading time: 2 minutes
# .. image:: https://colab.research.google.com/assets/colab-badge.svg
# :target: https://colab.research.google.com/github/eeg2025/startkit/blob/main/challenge_1.ipynb
# :alt: Open In Colab
# Preliminary notes
# -----------------
# Before we begin, I just want to make a deal with you, ok?
# This is a community competition with a strong open-source foundation.
# When I say open-source, I mean volunteer work.
#
# So, if you see something that does not work or could be improved, first, **please be kind**, and
# we will fix it together on GitHub, okay?
#
# The entire decoding community will only go further when we stop
# solving the same problems over and over again and start working together.
# How can we use knowledge from one EEG decoding task in another?
# ----------------------------------------------------------------
# Transfer learning is a widespread technique used in deep learning. It
# uses knowledge learned from one source task/domain in another target
# task/domain. It has been studied in depth in computer vision, natural
# language processing, and speech, but what about EEG brain decoding?
#
# The cross-task transfer learning scenario in EEG decoding is remarkably
# underexplored compared to the development of new models,
# `Aristimunha et al. (2023) <https://arxiv.org/abs/2308.02408>`__, even
# though it can be much more useful for real applications, see
# `Wimpff et al. (2025) <https://arxiv.org/abs/2502.06828>`__,
# `Wu et al. (2025) <https://arxiv.org/abs/2507.09882>`__.
#
# Our Challenge 1 addresses a key goal in neurotechnology: decoding
# cognitive function from EEG using knowledge pre-trained on another task.
# In other words, developing models that can effectively
# transfer/adapt/adjust/fine-tune knowledge from passive EEG tasks to
# active tasks.
#
# We believe this ability to generalize and transfer is critical and
# deserves focused attention. It goes beyond simply comparing metric
# numbers that are often not comparable, given the specificities of EEG,
# such as pre-processing, inter-subject variability, and many other unique
# properties of this type of data.
#
# This means your submitted model might be trained on a subset of tasks
# and fine-tuned on data from another condition, evaluating its capacity to
# generalize with task-specific fine-tuning.
# __________
#
# Note: For simplicity, we will only show how to decode directly on the
# target task; it is up to the teams to decide how to use the passive tasks
# for pre-training. A minimal sketch of one possible route is shown below.
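#
# Purely as an illustration (the checkpoint name ``weights_passive_pretrain.pt``
# is a hypothetical file you would have produced yourself by pre-training a
# model with the same backbone on a passive task):
#
# .. code-block:: python
#
#    pretrained = torch.load("weights_passive_pretrain.pt", map_location="cpu")
#    # Load the matching backbone weights; strict=False ignores mismatched heads.
#    missing, unexpected = model.load_state_dict(pretrained, strict=False)
#    # Then fine-tune on the active (CCD) task exactly as shown below.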
# Install dependencies
# --------------------
# For the challenge, we will need two main dependencies:
# `braindecode` and `eegdash`. These libraries will install PyTorch,
# PyTorch Audio, scikit-learn, MNE, MNE-BIDS, and many other packages
# required by the functions we use.
#
# Install the dependencies on Colab or on your local machine. As eegdash
# has braindecode as a dependency,
# you can just run ``pip install eegdash``.
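#
# For example, as a shell command (on Colab, prefix it with ``!``):
#
#   pip install eegdash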
# Imports and setup
# -----------------
from pathlib import Path
import torch
from braindecode.datasets import BaseConcatDataset
from braindecode.preprocessing import (
    preprocess,
    Preprocessor,
    create_windows_from_events,
)
from braindecode.models import EEGNeX
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from typing import Optional
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
import copy
from joblib import Parallel, delayed
# Check GPU availability
# ----------------------
#
# Identify whether a CUDA-enabled GPU is available
# and set the device accordingly.
# If using Google Colab, ensure that the runtime is set to use a GPU.
# This can be done by navigating to `Runtime` > `Change runtime type` and selecting
# `GPU` as the hardware accelerator.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg = "CUDA-enabled GPU found. Training should be faster."
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting 'T4 GPU'\nunder 'Hardware accelerator'."
    )
print(msg)
No GPU found. Training will be carried out on CPU, which might be slower.
If running on Google Colab, you can request a GPU runtime by clicking
`Runtime/Change runtime type` in the top bar menu, then selecting 'T4 GPU'
under 'Hardware accelerator'.
# What are we decoding?
# ---------------------
#
# Before we talk about what we want to analyse, the important thing
# is to understand some basic concepts.
#
# The brain decoding problem
# --------------------------
#
# Broadly speaking, here *brain decoding* is the following problem:
# given brain time-series signals :math:`X \in \mathbb{R}^{C \times T}` with
# labels :math:`y \in \mathcal{Y}`, we implement a neural network :math:`f` that
# **decodes/translates** brain activity into the target label.
#
# We aim to translate recorded brain activity into its originating
# stimulus, behavior, or mental state, `King, J-R. et al. (2020) <https://lauragwilliams.github.io/d/m/CognitionAlgorithm.pdf>`__.
#
# The neural network :math:`f` applies a series of transformation layers
# (e.g., ``torch.nn.Conv2d``, ``torch.nn.Linear``, ``torch.nn.ELU``, ``torch.nn.BatchNorm2d``)
# to the data to filter, extract features, and learn embeddings
# relevant to the optimization objective; in other words:
#
# .. math::
#
# f_{\theta}: X \to y,
#
# where :math:`C` (``n_chans``) is the number of channels/electrodes and :math:`T` (``n_times``)
# is the temporal window length/epoch size over the interval of interest.
# Here, :math:`\theta` denotes the parameters learned by the neural network.
#
# Input/Output definition
# -----------------------
# For the competition, the HBN-EEG (Healthy Brain Network EEG Datasets)
# dataset has ``n_chans = 129``, with the last channel being a `reference channel <https://mne.tools/stable/auto_tutorials/preprocessing/55_setting_eeg_reference.html>`_,
# and we define the window length as ``n_times = 200``, corresponding to 2-second windows.
#
# Your model should follow this definition exactly; any specific selection of channels,
# filtering, or domain-adaptation technique must be performed **within the layers of the neural network model**.
#
# In this tutorial, we will use the ``EEGNeX`` model from ``braindecode`` as an example.
# You can use any model you want, as long as it follows the input/output
# definitions above.
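#
# As an optional sanity check of this I/O contract (a minimal sketch; any
# module with the same signature would do), we can run a dummy forward pass:

_check_model = EEGNeX(n_chans=129, n_outputs=1, n_times=200, sfreq=100)
_dummy_input = torch.randn(2, 129, 200)  # (batch, n_chans, n_times)
assert _check_model(_dummy_input).shape == (2, 1)  # one regression output per window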
# Understand the task: Contrast Change Detection (CCD)
# ----------------------------------------------------
# If you are interested in more neuroscience insight, we recommend these two references: `HBN-EEG <https://www.biorxiv.org/content/10.1101/2024.10.03.615261v2.full.pdf>`__ and `Langer, N et al. (2017) <https://www.nature.com/articles/sdata201740#Sec2>`__.
# Your task (**label**) is to predict the subject's response time during these windows.
#
# In the video under *Stimulus demonstration* below, we have an example of the task during recording.
#
# The Contrast Change Detection (CCD) task relates to
# `Steady-State Visual Evoked Potentials (SSVEP) <https://en.wikipedia.org/wiki/Steady-state_visually_evoked_potential>`__
# and `Event-Related Potentials (ERP) <https://en.wikipedia.org/wiki/Event-related_potential>`__.
#
# Concretely, what the subject sees during the recording is:
#
# * Two flickering striped discs: one tilted left, one tilted right.
# * After a variable delay, **one disc's contrast gradually increases** **while the other decreases**.
# * They **press left or right** to indicate which disc got stronger.
# * They receive **feedback** (👍 correct / 👎 incorrect).
#
# **The task parallels SSVEP and ERP:**
#
# * The continuous flicker **tags the EEG at fixed frequencies (and harmonics)**, yielding SSVEP-like signals.
# * The **ramp onset**, the **button press**, and the **feedback** are **time-locked events** that yield ERP-like components.
#
# In the figure below, we have the timeline representation of the cognitive task:
#
# .. image:: https://eeg2025.github.io/assets/img/image-2.jpg
# Stimulus demonstration
# ----------------------
# .. raw:: html
#
# <div class="video-wrapper">
# <iframe src="https://www.youtube.com/embed/tOW2Vu2zHoU?start=1630"
# title="Contrast Change Detection (CCD) task demo"
# allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
# allowfullscreen></iframe>
# </div>
#
# PyTorch Dataset for the competition
# -----------------------------------
# Now, we have a PyTorch Dataset object that contains the set of recordings for the task
# `contrastChangeDetection`.
#
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import (
    annotate_trials_with_target,
    add_aux_anchors,
    keep_only_recordings_with,
    add_extras_columns,
)
# Match tests' cache layout under ~/eegdash_cache/eeg_challenge_cache
DATA_DIR = (Path.home() / "eegdash_cache" / "eeg_challenge_cache").resolve()
DATA_DIR.mkdir(parents=True, exist_ok=True)
dataset_ccd = EEGChallengeDataset(
    task="contrastChangeDetection", release="R5", cache_dir=DATA_DIR, mini=True
)
# The dataset contains 20 subjects in the minirelease, and each subject has multiple recordings
# (sessions). Each recording is represented as a dataset object within the `dataset_ccd.datasets` list.
print(f"Number of recordings in the dataset: {len(dataset_ccd.datasets)}")
print(
    f"Number of unique subjects in the dataset: {dataset_ccd.description['subject'].nunique()}"
)
#
# This dataset object has very rich Raw object details that can help you
# understand the data better. The framework behind this is braindecode,
# and if you want to understand in depth what is happening, we recommend
# the braindecode GitHub repository itself.
#
# We can also access the Raw object for visualization purposes; here we look at just the first recording.
raw = dataset_ccd.datasets[0].raw # get the Raw object of the first recording
# And to download all the data directly, you can do:
raws = Parallel(n_jobs=-1)(delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets)
EEG 2025 Competition Data Notice (source: EEGChallengeDataset)

This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  * Downsampled from 500 Hz to 100 Hz
  * Bandpass filtered (0.5-50 Hz)

For the full preprocessing applied for the competition, see:
https://github.com/eeg2025/downsample-datasets

The HBN dataset has some preprocessing applied by the HBN team:
  * Re-reference (Cz channel)

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to
what you get from EEGDashDataset directly. If you are participating in the
competition, always use `EEGChallengeDataset` to ensure consistency with the
challenge data.
Number of recordings in the dataset: 60
Number of unique subjects in the dataset: 20
Downloading dataset metadata (dataset description, participants, events,
channels, JSON sidecars) and the EEG .bdf files for each recording
(download progress bars truncated).
# Alternatives for Downloading the data
# -------------------------------------
#
# You can also perform this operation with wget or the AWS CLI.
# These options will probably be faster!
# Please check more details on the `HBN-EEG data webpage <https://neuromechanist.github.io/data/hbn/>`__.
# You need to download the 100 Hz preprocessed data in BDF format.
#
# Example of wget for release R1
# wget https://sccn.ucsd.edu/download/eeg2025/R1_L100_bdf.zip -O R1_L100_bdf.zip
#
# Example of AWS CLI for release R1
#
# aws s3 sync s3://nmdatasets/NeurIPS25/R1_L100_bdf data/R1_L100_bdf --no-sign-request
# Create windows of interest
# -----------------------------
# We epoch relative to the stimulus onset, shifting the window start by 500 ms after the stimulus.
EPOCH_LEN_S = 2.0
SFREQ = 100 # by definition here
transformation_offline = [
    Preprocessor(
        annotate_trials_with_target,
        target_field="rt_from_stimulus",
        epoch_length=EPOCH_LEN_S,
        require_stimulus=True,
        require_response=True,
        apply_on_array=False,
    ),
    Preprocessor(add_aux_anchors, apply_on_array=False),
]
preprocess(dataset_ccd, transformation_offline, n_jobs=1)
ANCHOR = "stimulus_anchor"
SHIFT_AFTER_STIM = 0.5
WINDOW_LEN = 2.0
# Keep only recordings that actually contain stimulus anchors
dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)
# Create single-interval windows (stim-locked, long enough to include the response)
single_windows = create_windows_from_events(
    dataset,
    mapping={ANCHOR: 0},
    trial_start_offset_samples=int(SHIFT_AFTER_STIM * SFREQ),  # +0.5 s
    trial_stop_offset_samples=int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),  # +2.5 s
    window_size_samples=int(EPOCH_LEN_S * SFREQ),
    window_stride_samples=SFREQ,
    preload=True,
)
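# Quick arithmetic check (optional): the start and stop offsets (50 and 250
# samples) differ by exactly one window length, i.e., 200 samples = 2 s.
assert int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ) - int(
    SHIFT_AFTER_STIM * SFREQ
) == int(EPOCH_LEN_S * SFREQ)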
# Inject the trial metadata into the windows' metadata (extra MNE annotation columns).
single_windows = add_extras_columns(
    single_windows,
    dataset,
    desc=ANCHOR,
    keys=(
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
        "correct",
        "response_type",
    ),
)
Used Annotations descriptions: [np.str_('stimulus_anchor')]
(the line above is repeated once per recording; duplicates truncated)
# Inspect the label distribution
# -------------------------------
import numpy as np
from skorch.helper import SliceDataset
y_label = np.array(list(SliceDataset(single_windows, 1)))
# Plot histogram of the response times with matplotlib
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(y_label, bins=30)
ax.set_title("Response Time Distribution")
ax.set_xlabel("Response Time (s)")
ax.set_ylabel("Count")
plt.tight_layout()
plt.show()
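# A few summary statistics of the labels (optional; ``y_label`` is a 1D array):
print(
    f"RT mean: {y_label.mean():.3f} s, std: {y_label.std():.3f} s, "
    f"min: {y_label.min():.3f} s, max: {y_label.max():.3f} s"
)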

# Split the data
# ---------------
# Extract meta information
meta_information = single_windows.get_metadata()
valid_frac = 0.1
test_frac = 0.1
seed = 2025
subjects = meta_information["subject"].unique()
train_subj, valid_test_subject = train_test_split(
    subjects,
    test_size=(valid_frac + test_frac),
    random_state=check_random_state(seed),
    shuffle=True,
)
valid_subj, test_subj = train_test_split(
    valid_test_subject,
    test_size=test_frac,
    random_state=check_random_state(seed + 1),
    shuffle=True,
)
# Sanity check
assert (set(valid_subj) | set(test_subj) | set(train_subj)) == set(subjects)
# Create train/valid/test splits for the windows
subject_split = single_windows.split("subject")
train_set = []
valid_set = []
test_set = []
for s in subject_split:
    if s in train_subj:
        train_set.append(subject_split[s])
    elif s in valid_subj:
        valid_set.append(subject_split[s])
    elif s in test_subj:
        test_set.append(subject_split[s])
train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)
test_set = BaseConcatDataset(test_set)
print("Number of examples in each split in the minirelease")
print(f"Train:\t{len(train_set)}")
print(f"Valid:\t{len(valid_set)}")
print(f"Test:\t{len(test_set)}")
Number of examples in each split in the minirelease
Train: 981
Valid: 183
Test: 50
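# As an extra sanity check, the splits should also be pairwise disjoint at the
# subject level (the assertion above only checks that their union is complete):
assert not set(train_subj) & set(valid_subj)
assert not set(train_subj) & set(test_subj)
assert not set(valid_subj) & set(test_subj)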
# Create dataloaders
# -------------------
batch_size = 128
# Set num_workers to 0 to avoid multiprocessing issues in notebooks/tutorials
num_workers = 0
train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
valid_loader = DataLoader(
    valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
test_loader = DataLoader(
    test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
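# Optionally, peek at one batch to confirm the expected shapes. Braindecode
# windows datasets may return extra items (e.g., crop indices), hence the
# positional indexing:
first_batch = next(iter(train_loader))
X_batch, y_batch = first_batch[0], first_batch[1]
print(X_batch.shape, y_batch.shape)  # expected: (batch, 129, 200) and (batch,)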
# Build the model
# -----------------
# For neural network models, **to start**, we suggest using the `braindecode models <https://braindecode.org/1.2/models/models_table.html>`__ zoo.
# We have implemented several different models for decoding brain time series.
# Your team's responsibility is to develop a PyTorch module that receives the three-dimensional (`batch`, `n_chans`, `n_times`)
# input and outputs the predicted response time.
# **You can use any model you want**, as long as it follows the input/output
# definitions above.
model = EEGNeX(
    n_chans=129,  # 129 channels
    n_outputs=1,  # 1 output for regression
    n_times=200,  # 2 seconds at 100 Hz
    sfreq=100,  # sampling frequency, 100 Hz
)
print(model)
model.to(device)
/home/runner/work/EEGDash/EEGDash/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:543: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /pytorch/aten/src/ATen/native/Convolution.cpp:1031.)
return F.conv2d(
================================================================================================================================================================
Layer (type (var_name):depth-idx)                          Input Shape          Output Shape         Param #    Kernel Shape
================================================================================================================================================================
EEGNeX (EEGNeX)                                            [1, 129, 200]        [1, 1]               --         --
├─Sequential (block_1): 1-1                                [1, 129, 200]        [1, 8, 129, 200]     --         --
│    └─Rearrange (0): 2-1                                  [1, 129, 200]        [1, 1, 129, 200]     --         --
│    └─Conv2d (1): 2-2                                     [1, 1, 129, 200]     [1, 8, 129, 200]     512        [1, 64]
│    └─BatchNorm2d (2): 2-3                                [1, 8, 129, 200]     [1, 8, 129, 200]     16         --
├─Sequential (block_2): 1-2                                [1, 8, 129, 200]     [1, 32, 129, 200]    --         --
│    └─Conv2d (0): 2-4                                     [1, 8, 129, 200]     [1, 32, 129, 200]    16,384     [1, 64]
│    └─BatchNorm2d (1): 2-5                                [1, 32, 129, 200]    [1, 32, 129, 200]    64         --
├─Sequential (block_3): 1-3                                [1, 32, 129, 200]    [1, 64, 1, 50]       --         --
│    └─ParametrizedConv2dWithConstraint (0): 2-6           [1, 32, 129, 200]    [1, 64, 1, 200]      --         [129, 1]
│    │    └─ModuleDict (parametrizations): 3-1             --                   --                   8,256      --
│    └─BatchNorm2d (1): 2-7                                [1, 64, 1, 200]      [1, 64, 1, 200]      128        --
│    └─ELU (2): 2-8                                        [1, 64, 1, 200]      [1, 64, 1, 200]      --         --
│    └─AvgPool2d (3): 2-9                                  [1, 64, 1, 200]      [1, 64, 1, 50]       --         [1, 4]
│    └─Dropout (4): 2-10                                   [1, 64, 1, 50]       [1, 64, 1, 50]       --         --
├─Sequential (block_4): 1-4                                [1, 64, 1, 50]       [1, 32, 1, 50]       --         --
│    └─Conv2d (0): 2-11                                    [1, 64, 1, 50]       [1, 32, 1, 50]       32,768     [1, 16]
│    └─BatchNorm2d (1): 2-12                               [1, 32, 1, 50]       [1, 32, 1, 50]       64         --
├─Sequential (block_5): 1-5                                [1, 32, 1, 50]       [1, 48]              --         --
│    └─Conv2d (0): 2-13                                    [1, 32, 1, 50]       [1, 8, 1, 50]        4,096      [1, 16]
│    └─BatchNorm2d (1): 2-14                               [1, 8, 1, 50]        [1, 8, 1, 50]        16         --
│    └─ELU (2): 2-15                                       [1, 8, 1, 50]        [1, 8, 1, 50]        --         --
│    └─AvgPool2d (3): 2-16                                 [1, 8, 1, 50]        [1, 8, 1, 6]         --         [1, 8]
│    └─Dropout (4): 2-17                                   [1, 8, 1, 6]         [1, 8, 1, 6]         --         --
│    └─Flatten (5): 2-18                                   [1, 8, 1, 6]         [1, 48]              --         --
├─ParametrizedLinearWithConstraint (final_layer): 1-6      [1, 48]              [1, 1]               1          --
│    └─ModuleDict (parametrizations): 2-19                 --                   --                   --         --
│    │    └─ParametrizationList (weight): 3-2              --                   [1, 48]              48         --
================================================================================================================================================================
Total params: 62,353
Trainable params: 62,353
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 437.76
================================================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 16.65
Params size (MB): 0.22
Estimated Total Size (MB): 16.97
================================================================================================================================================================
EEGNeX(
  (block_1): Sequential(
    (0): Rearrange('batch ch time -> batch 1 ch time')
    (1): Conv2d(1, 8, kernel_size=(1, 64), stride=(1, 1), padding=same, bias=False)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_2): Sequential(
    (0): Conv2d(8, 32, kernel_size=(1, 64), stride=(1, 1), padding=same, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_3): Sequential(
    (0): ParametrizedConv2dWithConstraint(
      32, 64, kernel_size=(129, 1), stride=(1, 1), groups=32, bias=False
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): MaxNormParametrize()
        )
      )
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
    (4): Dropout(p=0.5, inplace=False)
  )
  (block_4): Sequential(
    (0): Conv2d(64, 32, kernel_size=(1, 16), stride=(1, 1), padding=same, dilation=(1, 2), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_5): Sequential(
    (0): Conv2d(32, 8, kernel_size=(1, 16), stride=(1, 1), padding=same, dilation=(1, 4), bias=False)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 1))
    (4): Dropout(p=0.5, inplace=False)
    (5): Flatten(start_dim=1, end_dim=-1)
  )
  (final_layer): ParametrizedLinearWithConstraint(
    in_features=48, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): MaxNormParametrize()
      )
    )
  )
)
# Define training and validation functions
# -------------------------------------------
# The rest is a classic PyTorch / PyTorch Lightning / skorch training pipeline;
# you can use any training framework you want.
# We provide a simple training and validation loop below.
#
def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,
):
    model.train()
    total_loss = 0.0
    sum_sq_err = 0.0
    n_samples = 0
    progress_bar = tqdm(
        enumerate(dataloader), total=len(dataloader), disable=not print_batch_stats
    )
    for batch_idx, batch in progress_bar:
        # Support datasets that may return (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()
        optimizer.zero_grad(set_to_none=True)
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Flatten to 1D for regression metrics and accumulate squared error
        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()
        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            progress_bar.set_description(
                f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
            )
    if scheduler is not None:
        scheduler.step()
    avg_loss = total_loss / len(dataloader)
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
    return avg_loss, rmse
@torch.no_grad()
def valid_model(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    device,
    print_batch_stats: bool = True,
):
    model.eval()
    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0
    iterator = tqdm(
        enumerate(dataloader), total=n_batches, disable=not print_batch_stats
    )
    for batch_idx, batch in iterator:
        # Supports (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()
        preds = model(X)
        batch_loss = loss_fn(preds, y).item()
        total_loss += batch_loss
        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()
        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            iterator.set_description(
                f"Val Batch {batch_idx + 1}/{n_batches}, "
                f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
            )
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n")
    return avg_loss, rmse
# Train the model
# ------------------
lr = 1e-3
weight_decay = 1e-5
n_epochs = 5  # For demonstration purposes, we use just 5 epochs here. You can increase this.
patience = 5  # early-stopping patience, in epochs
min_delta = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - 1)
loss_fn = torch.nn.MSELoss()
best_rmse = float("inf")
epochs_no_improve = 0
best_state, best_epoch = None, None
for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")
    train_loss, train_rmse = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device
    )
    # Early stopping should monitor the validation set; the test set stays
    # untouched until the final evaluation.
    val_loss, val_rmse = valid_model(valid_loader, model, loss_fn, device)
    print(
        f"Train RMSE: {train_rmse:.6f}, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Val RMSE: {val_rmse:.6f}, "
        f"Average Val Loss: {val_loss:.6f}"
    )
    if val_rmse < best_rmse - min_delta:
        best_rmse = val_rmse
        best_state = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(
                f"Early stopping at epoch {epoch}. "
                f"Best Val RMSE: {best_rmse:.6f} (epoch {best_epoch})"
            )
            break
if best_state is not None:
    model.load_state_dict(best_state)
Epoch 1/5: Train RMSE: 1.642583, Average Train Loss: 2.685009, Val RMSE: 1.663223, Average Val Loss: 2.766312
Epoch 2/5: Train RMSE: 1.465068, Average Train Loss: 2.127855, Val RMSE: 1.547158, Average Val Loss: 2.393699
Epoch 3/5: Train RMSE: 1.025577, Average Train Loss: 1.036386, Val RMSE: 1.260193, Average Val Loss: 1.588087
Epoch 4/5: Train RMSE: 0.756153, Average Train Loss: 0.566559, Val RMSE: 1.009211, Average Val Loss: 1.018507
Epoch 5/5: Train RMSE: 0.705231, Average Train Loss: 0.499949, Val RMSE: 0.914198, Average Val Loss: 0.835757
(per-batch progress bars truncated)
# Save the model
# -----------------
torch.save(model.state_dict(), "weights_challenge_1.pt")
print("Model saved as 'weights_challenge_1.pt'")
Model saved as 'weights_challenge_1.pt'
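# To reuse the trained weights later (e.g., for inference), a minimal sketch:
model_restored = EEGNeX(n_chans=129, n_outputs=1, n_times=200, sfreq=100)
model_restored.load_state_dict(torch.load("weights_challenge_1.pt", map_location="cpu"))
model_restored.eval()  # disable dropout / batch-norm updates before predicting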
Total running time of the script: (4 minutes 58.147 seconds)
Estimated memory usage: 2920 MB