Challenge 1: Cross-Task Transfer Learning!#

Estimated reading time: 8 minutes
Open In Colab

Preliminary notes#

Before we begin, I just want to make a deal with you, ok? This is a community competition with a strong open-source foundation. When I say open-source, I mean volunteer work.

So, if you see something that does not work or could be improved, first, please be kind, and we will fix it together on GitHub, okay?

The entire decoding community will only go further when we stop solving the same problems over and over again and start working together.

How can we transfer knowledge from one EEG decoding task to another?#

Transfer learning is a widespread technique used in deep learning. It uses knowledge learned from one source task/domain in another target task/domain. It has been studied in depth in computer vision, natural language processing, and speech, but what about EEG brain decoding?

The cross-task transfer learning scenario in EEG decoding is remarkably underexplored compared to the development of new models (Aristimunha et al., 2023), even though it can be much more useful for real applications (Wimpff et al., 2025; Wu et al., 2025).

Our Challenge 1 addresses a key goal in neurotechnology: decoding cognitive function from EEG using knowledge pre-trained on another task. In other words, developing models that can effectively transfer/adapt/fine-tune knowledge from passive EEG tasks to active tasks.

The ability to generalize and transfer is critical and, we believe, deserves more attention than simply comparing metric numbers that are often not comparable, given the specificities of EEG such as pre-processing, inter-subject variability, and many other unique properties of this type of data.

This means your submitted model might be trained on a subset of tasks and fine-tuned on data from another condition, evaluating its capacity to generalize with task-specific fine-tuning.


Note: For simplicity, we will only show how to decode directly on our target task; it is up to the teams to think about how to use the passive tasks to perform the pre-training.

Install dependencies#

For the challenge, we will need two main dependencies: braindecode and eegdash. These libraries will install PyTorch, PyTorch Audio, scikit-learn, MNE, MNE-BIDS, and many other packages needed for the functions used below.

Install the dependencies on Colab or on your local machine. Since eegdash has braindecode as a dependency, you can just run pip install eegdash.
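
For example, in a notebook cell (prefix the command with ! on Colab) or in a terminal:

pip install eegdash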

Imports and setup#

from pathlib import Path
import torch
from braindecode.datasets import BaseConcatDataset
from braindecode.preprocessing import (
    preprocess,
    Preprocessor,
    create_windows_from_events,
)
from braindecode.models import EEGNeX
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from typing import Optional
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
import copy
from joblib import Parallel, delayed

Check GPU availability#

Identify whether a CUDA-enabled GPU is available and set the device accordingly. If using Google Colab, ensure that the runtime is set to use a GPU. This can be done by navigating to Runtime > Change runtime type and selecting GPU as the hardware accelerator.

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg = "CUDA-enabled GPU found. Training should be faster."
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting 'T4 GPU'\nunder 'Hardware accelerator'."
    )
print(msg)
No GPU found. Training will be carried out on CPU, which might be slower.

If running on Google Colab, you can request a GPU runtime by clicking
`Runtime/Change runtime type` in the top bar menu, then selecting 'T4 GPU'
under 'Hardware accelerator'.
What are we decoding?#

Before discussing what we want to analyse, it is important to understand some basic concepts.

Broadly speaking, here brain decoding is the following problem: given brain time-series signals \(X \in \mathbb{R}^{C \times T}\) with labels \(y \in \mathcal{Y}\), we implement a neural network \(f\) that decodes/translates brain activity into the target label.

We aim to translate recorded brain activity into its originating stimulus, behavior, or mental state, King, J-R. et al. (2020).

The neural network \(f\) applies a series of transformation layers (e.g., torch.nn.Conv2d, torch.nn.Linear, torch.nn.ELU, torch.nn.BatchNorm2d) to the data to filter, extract features, and learn embeddings relevant to the optimization objective. In other words:

\[f_{\theta}: X \to y,\]

where \(C\) (n_chans) is the number of channels/electrodes and \(T\) (n_times) is the temporal window length/epoch size over the interval of interest. Here, \(\theta\) denotes the parameters learned by the neural network.

For the competition, the HBN-EEG (Healthy Brain Network EEG Datasets) dataset has n_chans = 129, with the last channel being the reference channel, and we define the window length as n_times = 200, corresponding to 2-second windows at 100 Hz.

Your model should follow this definition exactly; any specific selection of channels, filtering, or domain-adaptation technique must be performed within the layers of the neural network model.

In this tutorial, we will use the EEGNeX model from braindecode as an example. You can use any model you want, as long as it follows the input/output definitions above.
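
As a quick sanity check of this contract, here is a minimal sketch (using the EEGNeX model and the imports from above; the variable names are ours) that passes a random batch shaped (batch, n_chans, n_times) through the network and verifies it returns one value per window:

# Sanity-check the (batch, n_chans, n_times) -> (batch, 1) contract with random data
model_check = EEGNeX(n_chans=129, n_outputs=1, n_times=200, sfreq=100)
x_fake = torch.randn(4, 129, 200)  # 4 windows, 129 channels, 2 s at 100 Hz
with torch.no_grad():
    out = model_check(x_fake)
print(out.shape)  # expected: torch.Size([4, 1]) -> one response-time estimate per window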

Understand the task: Contrast Change Detection (CCD)#

If you are interested in more neuroscience background, we recommend two references: HBN-EEG and Langer, N. et al. (2017).

In the video, we have an example of the cognitive activity being recorded:

The Contrast Change Detection (CCD) task relates to Steady-State Visual Evoked Potentials (SSVEP) and Event-Related Potentials (ERP).

Concretely, what the subject sees during the recording is:

  • Two flickering striped discs: one tilted left, one tilted right.

  • After a variable delay, one disc's contrast gradually increases while the other decreases.

  • They press left or right to indicate which disc got stronger.

  • They receive feedback (🙂 correct / 🙁 incorrect).

The task parallels SSVEP and ERP:

  • The continuous flicker tags the EEG at fixed frequencies (and harmonics) → SSVEP-like signals.

  • The ramp onset, the button press, and the feedback are time-locked events that yield ERP-like components.

Your task (label) is to predict the subject's response time within each of these windows.

https://eeg2025.github.io/assets/img/image-2.jpg
Stimulus demonstration

Below, we build a PyTorch Dataset object that contains the set of recordings for the contrastChangeDetection task.

from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import (
    annotate_trials_with_target,
    add_aux_anchors,
    keep_only_recordings_with,
    add_extras_columns,
)

# Cache the downloaded challenge data under ~/eegdash_cache/eeg_challenge_cache
DATA_DIR = (Path.home() / "eegdash_cache" / "eeg_challenge_cache").resolve()
DATA_DIR.mkdir(parents=True, exist_ok=True)

dataset_ccd = EEGChallengeDataset(
    task="contrastChangeDetection", release="R5", cache_dir=DATA_DIR, mini=True
)
# The dataset contains 20 subjects in the minirelease, and each subject has multiple recordings
# (sessions). Each recording is represented as a dataset object within the `dataset_ccd.datasets` list.
print(f"Number of recordings in the dataset: {len(dataset_ccd.datasets)}")
print(
    f"Number of unique subjects in the dataset: {dataset_ccd.description['subject'].nunique()}"
)
#
# Each dataset object exposes a rich MNE Raw object with details that can help
# you understand the data better. The framework behind this is braindecode;
# if you want to understand in depth what is happening, we recommend the
# braindecode GitHub repository itself.
#
# We can also access the Raw object for visualization purposes; here we look at just one recording.
raw = dataset_ccd.datasets[0].raw  # get the Raw object of the first recording
# To trigger the download of all recordings in parallel, you can do:
raws = Parallel(n_jobs=-1)(delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets)
──────────────────────── EEG 2025 Competition Data Notice ────────────────────────
This object loads the HBN dataset that has been preprocessed for the EEG
Challenge:
  * Downsampled from 500Hz to 100Hz
  * Bandpass filtered (0.5-50 Hz)

For full preprocessing applied for competition details, see:
  https://github.com/eeg2025/downsample-datasets

The HBN dataset have some preprocessing applied by the HBN team:
  * Re-reference (Cz Channel)

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to
what you get from EEGDashDataset directly.
If you are participating in the competition, always use
`EEGChallengeDataset` to ensure consistency with the challenge data.
──────────────────────────── Source: EEGChallengeDataset ─────────────────────────
Number of recordings in the dataset: 60
Number of unique subjects in the dataset: 20
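
Before windowing, it can be useful to peek at a single recording; the small sketch below (using the `raw` object loaded above) prints the sampling rate, the channel count, and the first event annotations that we rely on later:

print(raw.info["sfreq"], len(raw.ch_names))  # expected: 100.0 and 129
print(raw.annotations.description[:5])  # event markers used for windowing below
print(f"{raw.times[-1]:.1f} seconds of signal in this recording")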

Alternatives for Downloading the data#

You can also perform this operation with wget or the AWS CLI. These options will probably be faster! Please check the HBN-EEG data webpage for more details. You need to download the 100 Hz preprocessed data in BDF format.

Example of wget for release R1

wget https://sccn.ucsd.edu/download/eeg2025/R1_L100_bdf.zip -O R1_L100_bdf.zip

Example of AWS CLI for release R1

aws s3 sync s3://nmdatasets/NeurIPS25/R1_L100_bdf data/R1_L100_bdf --no-sign-request

Create windows of interest#

We epoch the data relative to each stimulus onset, starting 0.5 s (500 ms) after the stimulus and keeping a 2-second window.

EPOCH_LEN_S = 2.0
SFREQ = 100  # by definition here

transformation_offline = [
    Preprocessor(
        annotate_trials_with_target,
        target_field="rt_from_stimulus",
        epoch_length=EPOCH_LEN_S,
        require_stimulus=True,
        require_response=True,
        apply_on_array=False,
    ),
    Preprocessor(add_aux_anchors, apply_on_array=False),
]
preprocess(dataset_ccd, transformation_offline, n_jobs=1)

ANCHOR = "stimulus_anchor"
SHIFT_AFTER_STIM = 0.5
WINDOW_LEN = 2.0

# Keep only recordings that actually contain stimulus anchors
dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

# Create single-interval windows (stim-locked, long enough to include the response)
single_windows = create_windows_from_events(
    dataset,
    mapping={ANCHOR: 0},
    trial_start_offset_samples=int(SHIFT_AFTER_STIM * SFREQ),  # +0.5 s
    trial_stop_offset_samples=int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),  # +2.5 s
    window_size_samples=int(EPOCH_LEN_S * SFREQ),
    window_stride_samples=SFREQ,
    preload=True,
)

# Inject the trial metadata (targets, response times, etc.) as extra columns of the windows.
single_windows = add_extras_columns(
    single_windows,
    dataset,
    desc=ANCHOR,
    keys=(
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
        "correct",
        "response_type",
    ),
)
Used Annotations descriptions: [np.str_('stimulus_anchor')]
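
Each element of `single_windows` is a `(X, y, ind)` triple, where `X` is one stimulus-locked window and `y` is the response time annotated above; a quick peek (sketch):

X, y, ind = single_windows[0]
print(X.shape)  # expected: (129, 200) -> (n_chans, n_times)
print(y)  # response time (in seconds) for this trial
print(ind)  # window index within the trial and start/stop sample indices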

Inspect the label distribution#

import numpy as np
from skorch.helper import SliceDataset

y_label = np.array(list(SliceDataset(single_windows, 1)))

# Plot histogram of the response times with matplotlib
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(y_label, bins=30)
ax.set_title("Response Time Distribution")
ax.set_xlabel("Response Time (s)")
ax.set_ylabel("Count")
plt.tight_layout()
plt.show()
Response Time Distribution
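
Beyond the histogram, a few summary statistics of the target are useful when choosing a loss or deciding whether to normalize the labels; a quick sketch reusing `y_label`:

print(
    f"n={y_label.size}, mean={y_label.mean():.3f} s, std={y_label.std():.3f} s, "
    f"min={y_label.min():.3f} s, max={y_label.max():.3f} s"
)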

Split the data#

Extract meta information

meta_information = single_windows.get_metadata()

valid_frac = 0.1
test_frac = 0.1
seed = 2025

subjects = meta_information["subject"].unique()

train_subj, valid_test_subject = train_test_split(
    subjects,
    test_size=(valid_frac + test_frac),
    random_state=check_random_state(seed),
    shuffle=True,
)

# Note: `test_size` here is a fraction of the remaining (valid + test) subjects,
# so with these settings the test split ends up smaller than the valid split.
valid_subj, test_subj = train_test_split(
    valid_test_subject,
    test_size=test_frac,
    random_state=check_random_state(seed + 1),
    shuffle=True,
)

# Sanity check
assert (set(valid_subj) | set(test_subj) | set(train_subj)) == set(subjects)

# Create train/valid/test splits for the windows
subject_split = single_windows.split("subject")
train_set = []
valid_set = []
test_set = []

for s in subject_split:
    if s in train_subj:
        train_set.append(subject_split[s])
    elif s in valid_subj:
        valid_set.append(subject_split[s])
    elif s in test_subj:
        test_set.append(subject_split[s])

train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)
test_set = BaseConcatDataset(test_set)

print("Number of examples in each split in the minirelease")
print(f"Train:\t{len(train_set)}")
print(f"Valid:\t{len(valid_set)}")
print(f"Test:\t{len(test_set)}")
Number of examples in each split in the minirelease
Train:  981
Valid:  183
Test:   50
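
Because the split is done at the subject level, it is worth double-checking that no subject leaks across splits; a short sanity check on the variables defined above:

# No subject should appear in more than one split (avoids subject-identity leakage)
assert set(train_subj).isdisjoint(valid_subj)
assert set(train_subj).isdisjoint(test_subj)
assert set(valid_subj).isdisjoint(test_subj)
print(f"Subjects -> train: {len(train_subj)}, valid: {len(valid_subj)}, test: {len(test_subj)}")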

Create dataloaders#

batch_size = 128
# Set num_workers to 0 to avoid multiprocessing issues in notebooks/tutorials
num_workers = 0

train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
valid_loader = DataLoader(
    valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
test_loader = DataLoader(
    test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
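
To confirm the loaders produce what the model will expect, you can pull one batch and inspect its shape (sketch; batches may also carry extra elements such as window indices, and the last batch can be smaller):

batch = next(iter(train_loader))
X_batch, y_batch = batch[0], batch[1]
print(X_batch.shape, y_batch.shape)  # expected: [128, 129, 200] and [128]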

Build the model#

For neural network models, we suggest starting with the braindecode model zoo, where we have implemented several different models for decoding brain time series. Your team's responsibility is to develop a PyTorch module that receives the three-dimensional (batch, n_chans, n_times) input and outputs the predicted response time (one continuous value per window). You can use any model you want, as long as it follows the input/output definitions above.

model = EEGNeX(
    n_chans=129,  # 129 channels
    n_outputs=1,  # 1 output for regression
    n_times=200,  # 2 seconds
    sfreq=100,  # sample frequency 100 Hz
)

print(model)
model.to(device)
/home/runner/work/EEGDash/EEGDash/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:543: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at /pytorch/aten/src/ATen/native/Convolution.cpp:1027.)
  return F.conv2d(
================================================================================================================================================================
Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
================================================================================================================================================================
EEGNeX (EEGNeX)                                              [1, 129, 200]             [1, 1]                    --                        --
โ”œโ”€Sequential (block_1): 1-1                                  [1, 129, 200]             [1, 8, 129, 200]          --                        --
โ”‚    โ””โ”€Rearrange (0): 2-1                                    [1, 129, 200]             [1, 1, 129, 200]          --                        --
โ”‚    โ””โ”€Conv2d (1): 2-2                                       [1, 1, 129, 200]          [1, 8, 129, 200]          512                       [1, 64]
โ”‚    โ””โ”€BatchNorm2d (2): 2-3                                  [1, 8, 129, 200]          [1, 8, 129, 200]          16                        --
โ”œโ”€Sequential (block_2): 1-2                                  [1, 8, 129, 200]          [1, 32, 129, 200]         --                        --
โ”‚    โ””โ”€Conv2d (0): 2-4                                       [1, 8, 129, 200]          [1, 32, 129, 200]         16,384                    [1, 64]
โ”‚    โ””โ”€BatchNorm2d (1): 2-5                                  [1, 32, 129, 200]         [1, 32, 129, 200]         64                        --
โ”œโ”€Sequential (block_3): 1-3                                  [1, 32, 129, 200]         [1, 64, 1, 50]            --                        --
โ”‚    โ””โ”€ParametrizedConv2dWithConstraint (0): 2-6             [1, 32, 129, 200]         [1, 64, 1, 200]           --                        [129, 1]
โ”‚    โ”‚    โ””โ”€ModuleDict (parametrizations): 3-1               --                        --                        8,256                     --
โ”‚    โ””โ”€BatchNorm2d (1): 2-7                                  [1, 64, 1, 200]           [1, 64, 1, 200]           128                       --
โ”‚    โ””โ”€ELU (2): 2-8                                          [1, 64, 1, 200]           [1, 64, 1, 200]           --                        --
โ”‚    โ””โ”€AvgPool2d (3): 2-9                                    [1, 64, 1, 200]           [1, 64, 1, 50]            --                        [1, 4]
โ”‚    โ””โ”€Dropout (4): 2-10                                     [1, 64, 1, 50]            [1, 64, 1, 50]            --                        --
โ”œโ”€Sequential (block_4): 1-4                                  [1, 64, 1, 50]            [1, 32, 1, 50]            --                        --
โ”‚    โ””โ”€Conv2d (0): 2-11                                      [1, 64, 1, 50]            [1, 32, 1, 50]            32,768                    [1, 16]
โ”‚    โ””โ”€BatchNorm2d (1): 2-12                                 [1, 32, 1, 50]            [1, 32, 1, 50]            64                        --
โ”œโ”€Sequential (block_5): 1-5                                  [1, 32, 1, 50]            [1, 48]                   --                        --
โ”‚    โ””โ”€Conv2d (0): 2-13                                      [1, 32, 1, 50]            [1, 8, 1, 50]             4,096                     [1, 16]
โ”‚    โ””โ”€BatchNorm2d (1): 2-14                                 [1, 8, 1, 50]             [1, 8, 1, 50]             16                        --
โ”‚    โ””โ”€ELU (2): 2-15                                         [1, 8, 1, 50]             [1, 8, 1, 50]             --                        --
โ”‚    โ””โ”€AvgPool2d (3): 2-16                                   [1, 8, 1, 50]             [1, 8, 1, 6]              --                        [1, 8]
โ”‚    โ””โ”€Dropout (4): 2-17                                     [1, 8, 1, 6]              [1, 8, 1, 6]              --                        --
โ”‚    โ””โ”€Flatten (5): 2-18                                     [1, 8, 1, 6]              [1, 48]                   --                        --
โ”œโ”€ParametrizedLinearWithConstraint (final_layer): 1-6        [1, 48]                   [1, 1]                    1                         --
โ”‚    โ””โ”€ModuleDict (parametrizations): 2-19                   --                        --                        --                        --
โ”‚    โ”‚    โ””โ”€ParametrizationList (weight): 3-2                --                        [1, 48]                   48                        --
================================================================================================================================================================
Total params: 62,353
Trainable params: 62,353
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 437.76
================================================================================================================================================================
Input size (MB): 0.10
Forward/backward pass size (MB): 16.65
Params size (MB): 0.22
Estimated Total Size (MB): 16.97
================================================================================================================================================================

EEGNeX(
  (block_1): Sequential(
    (0): Rearrange('batch ch time -> batch 1 ch time')
    (1): Conv2d(1, 8, kernel_size=(1, 64), stride=(1, 1), padding=same, bias=False)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_2): Sequential(
    (0): Conv2d(8, 32, kernel_size=(1, 64), stride=(1, 1), padding=same, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_3): Sequential(
    (0): ParametrizedConv2dWithConstraint(
      32, 64, kernel_size=(129, 1), stride=(1, 1), groups=32, bias=False
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): MaxNormParametrize()
        )
      )
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=(0, 1))
    (4): Dropout(p=0.5, inplace=False)
  )
  (block_4): Sequential(
    (0): Conv2d(64, 32, kernel_size=(1, 16), stride=(1, 1), padding=same, dilation=(1, 2), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block_5): Sequential(
    (0): Conv2d(32, 8, kernel_size=(1, 16), stride=(1, 1), padding=same, dilation=(1, 4), bias=False)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=(0, 1))
    (4): Dropout(p=0.5, inplace=False)
    (5): Flatten(start_dim=1, end_dim=-1)
  )
  (final_layer): ParametrizedLinearWithConstraint(
    in_features=48, out_features=1, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): MaxNormParametrize()
      )
    )
  )
)

Define training and validation functions#

The rest is a classic PyTorch/PyTorch Lightning/skorch training pipeline; you can use any training framework you want. We provide a simple training and validation loop below.

def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,
):
    model.train()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_samples = 0

    progress_bar = tqdm(
        enumerate(dataloader), total=len(dataloader), disable=not print_batch_stats
    )

    for batch_idx, batch in progress_bar:
        # Support datasets that may return (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()

        optimizer.zero_grad(set_to_none=True)
        preds = model(X)
        # Flatten predictions and targets to matching shapes to avoid broadcasting in MSE
        loss = loss_fn(preds.view(-1), y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Flatten to 1D for regression metrics and accumulate squared error
        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()

        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            progress_bar.set_description(
                f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
            )

    if scheduler is not None:
        scheduler.step()

    avg_loss = total_loss / len(dataloader)
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
    return avg_loss, rmse


@torch.no_grad()
def valid_model(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    device,
    print_batch_stats: bool = True,
):
    model.eval()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0

    iterator = tqdm(
        enumerate(dataloader), total=n_batches, disable=not print_batch_stats
    )

    for batch_idx, batch in iterator:
        # Supports (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()

        preds = model(X)
        # Flatten to matching shapes, as in training
        batch_loss = loss_fn(preds.view(-1), y.view(-1)).item()
        total_loss += batch_loss

        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()

        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            iterator.set_description(
                f"Val Batch {batch_idx + 1}/{n_batches}, "
                f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
            )

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5

    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n")
    return avg_loss, rmse

Train the model#

lr = 1e-3
weight_decay = 1e-5
n_epochs = (
    5  # For demonstration purposes, we use just 5 epochs here. You can increase this.
)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - 1)
loss_fn = torch.nn.MSELoss()

patience = 5
min_delta = 1e-4
best_rmse = float("inf")
epochs_no_improve = 0
best_state, best_epoch = None, None

for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_rmse = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device
    )
    # Note: for simplicity this demo monitors the test split; in practice you should
    # early-stop on valid_loader and keep test_loader for the final evaluation only.
    val_loss, val_rmse = valid_model(test_loader, model, loss_fn, device)

    print(
        f"Train RMSE: {train_rmse:.6f}, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Val RMSE: {val_rmse:.6f}, "
        f"Average Val Loss: {val_loss:.6f}"
    )

    if val_rmse < best_rmse - min_delta:
        best_rmse = val_rmse
        best_state = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(
                f"Early stopping at epoch {epoch}. Best Val RMSE: {best_rmse:.6f} (epoch {best_epoch})"
            )
            break

if best_state is not None:
    model.load_state_dict(best_state)
Epoch 1/5: Train RMSE: 1.642976, Average Train Loss: 2.688216, Val RMSE: 1.665573, Average Val Loss: 2.774135
Epoch 2/5: Train RMSE: 1.440299, Average Train Loss: 2.050660, Val RMSE: 1.538669, Average Val Loss: 2.367503
Epoch 3/5: Train RMSE: 0.956762, Average Train Loss: 0.900369, Val RMSE: 1.202924, Average Val Loss: 1.447026
Epoch 4/5: Train RMSE: 0.715608, Average Train Loss: 0.508693, Val RMSE: 0.971389, Average Val Loss: 0.943596
Epoch 5/5: Train RMSE: 0.678796, Average Train Loss: 0.464520, Val RMSE: 0.890842, Average Val Loss: 0.793600

Save the model#

torch.save(model.state_dict(), "weights_challenge_1.pt")
print("Model saved as 'weights_challenge_1.pt'")
Model saved as 'weights_challenge_1.pt'
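
To reuse these weights later, e.g. as a starting point when transferring between tasks, you can rebuild the same architecture and load the saved state dict. The snippet below is only an illustrative sketch: freezing everything except the output head (named `final_layer` in EEGNeX, as printed above) is just one possible fine-tuning strategy.

# Rebuild the architecture and load the saved weights
model_ft = EEGNeX(n_chans=129, n_outputs=1, n_times=200, sfreq=100)
state = torch.load("weights_challenge_1.pt", map_location="cpu")
model_ft.load_state_dict(state)

# One simple fine-tuning strategy: freeze the backbone and train only the head
for name, param in model_ft.named_parameters():
    param.requires_grad = name.startswith("final_layer")
optimizer_ft = torch.optim.AdamW(
    (p for p in model_ft.parameters() if p.requires_grad), lr=1e-4
)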

Total running time of the script: (3 minutes 55.740 seconds)

Estimated memory usage: 2886 MB
