Note
Go to the end to download the full example code or to run this example in your browser via Binder.
P3 Visual Oddball Classification#
This tutorial demonstrates using the EEGDash library with PyTorch to classify EEG responses from a visual P3 oddball paradigm.
Data Description: Dataset contains EEG recordings during a visual oddball task where:
Letters A, B, C, D, and E were presented randomly (p = .2 for each)
One letter was designated as target (oddball) for each block
Other letters served as non-targets (standard)
Participants responded whether each letter was target or non-target
Data Preprocessing:
Applies bandpass filtering (1-55 Hz)
Selects first 30 EEG channels
Downsamples to 256Hz
Creates event-based windows (0.1s to 0.6s post-stimulus)
Dataset Preparation:
Maps events into two classes (target vs. standard) using annotation names
Splits into training (80%) and test (20%) sets
Creates PyTorch DataLoaders
Model:
ShallowFBCSPNet architecture
30 input channels, 2 output classes
128-sample input windows (0.5s at 256Hz)
Training:
Adamax optimizer with learning rate decay
A few training epochs (configurable)
Reports accuracy on train and test sets
## Data Retrieval Using EEGDash
The P3 oddball dataset is fetched from the EEGDash API.
from pathlib import Path
import os
from eegdash import EEGDash, EEGDashDataset
from eegdash.paths import get_default_cache_dir
CACHE_DIR = Path(get_default_cache_dir()).resolve()
CACHE_DIR.mkdir(parents=True, exist_ok=True)
DATASET_ID = "ds005863"
TASK = "visualoddball"
RECORD_LIMIT = 20
eegdash = EEGDash()
records = eegdash.find({"dataset": DATASET_ID, "task": TASK}, limit=RECORD_LIMIT)
if not records:
records = eegdash.find(
{"task": {"$regex": "oddball", "$options": "i"}}, limit=RECORD_LIMIT
)
if records:
dataset_id = records[0].get("dataset")
if dataset_id:
records = [rec for rec in records if rec.get("dataset") == dataset_id]
if not records:
raise RuntimeError("No oddball task records found from the API.")
dataset_concat = EEGDashDataset(cache_dir=CACHE_DIR, records=records)
dataset_concat.download_all()
## Data Preprocessing Using Braindecode
[Braindecode](https://braindecode.org/) provides powerful tools for EEG data preprocessing and analysis. Our implementation processes EEG data with these key steps:
Channel Selection & Signal Processing: - Selecting first 30 EEG channels - Bandpass filtering between 1-55 Hz - Downsampling from 1024Hz to 256Hz
Event Processing: - Map target vs. standard events based on annotation labels (e.g., Target/NonTarget). - Response-only events are ignored by the mapping.
Window Creation:
Window duration: 1s
Efficient memory usage with on-demand loading
import logging
import warnings
import mne
import numpy as np
from braindecode.preprocessing import (
Preprocessor,
create_windows_from_events,
preprocess,
)
mne.set_log_level("ERROR")
logging.getLogger("joblib").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
# BrainDecode preprocessors
preprocessors = [
Preprocessor(
"pick_channels", ch_names=dataset_concat.datasets[0].raw.ch_names[:30]
),
Preprocessor("resample", sfreq=256),
Preprocessor("filter", l_freq=1, h_freq=55),
]
preprocess(dataset_concat, preprocessors)
# Extract windows
event_mapping = {
"Target": 1,
"NonTarget": 0,
"target": 1,
"standard": 0,
"oddball": 1,
"3": 1,
"4": 1,
"6": 0,
"7": 0,
}
windows_ds = create_windows_from_events(
dataset_concat,
trial_start_offset_samples=26,
trial_stop_offset_samples=154,
preload=False,
window_size_samples=None,
window_stride_samples=None,
drop_bad_windows=True,
mapping=event_mapping,
)
print(f"\nAll files processed, total number of windows: {len(windows_ds)}")
print(f"Window shape: {windows_ds[0][0].shape}")
## Creating Training and Test Sets
The data preparation pipeline consists of these key steps:
Data Extraction - Windows are automatically labeled (0=standard, 1=oddball) by the P3OddballPreprocessor.
Train-Test Split - Using sklearn’s train_test_split with: - 80-20 split ratio - Stratified sampling to maintain class proportions - Fixed random seed for reproducibility
PyTorch Data Preparation - Converting to tensors and creating DataLoader objects for mini-batch training.
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
# Set random seed for reproducibility
random_state = 42
torch.manual_seed(random_state)
np.random.seed(random_state)
# Extract data and labels using array operations
data = np.stack([windows_ds[i][0] for i in range(len(windows_ds))])
labels = np.array([windows_ds[i][1] for i in range(len(windows_ds))])
# Print dataset information
print(f"Dataset size: {len(data)}")
print(f"Data shape: {data.shape}")
print("Distribution of labels:", np.unique(labels, return_counts=True))
print("Label meanings: 0=standard, 1=oddball")
# Split into train and test sets
train_indices, test_indices = train_test_split(
range(len(data)), test_size=0.2, stratify=labels, random_state=random_state
)
# Convert to PyTorch tensors
X_train = torch.FloatTensor(data[train_indices])
X_test = torch.FloatTensor(data[test_indices])
y_train = torch.LongTensor(labels[train_indices])
y_test = torch.LongTensor(labels[test_indices])
# Create data loaders
dataset_train = TensorDataset(X_train, y_train)
dataset_test = TensorDataset(X_test, y_test)
train_loader = DataLoader(dataset_train, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=8, shuffle=True)
# Print dataset information
print("\nDataset size:")
print(f"Training set: {X_train.shape}, labels: {y_train.shape}")
print(f"Test set: {X_test.shape}, labels: {y_test.shape}")
print("\nProportion of samples of each class in training set:")
for label in np.unique(labels):
ratio = np.mean(y_train.numpy() == label)
print(f"Category {label}: {ratio:.3f}")
## Create Model
The model is a shallow convolutional neural network (ShallowFBCSPNet) with: - 30 input channels (EEG channels) - 2 output classes (oddball, standard) - 128-sample input windows (0.5s at 256Hz)
This architecture is particularly effective for EEG classification tasks, incorporating frequency-band specific spatial patterns.
from torchinfo import summary
from braindecode.models import ShallowFBCSPNet
model = ShallowFBCSPNet(
in_chans=30,
n_classes=2,
input_window_samples=128, # 0.5s at 256Hz
final_conv_length="auto",
)
summary(model, input_size=(1, 30, 128))
## Model Training and Evaluation
The training pipeline consists of:
Optimization Setup: - Adamax optimizer with learning rate 0.002 - Weight decay for regularization - Learning rate scheduler
Training Process: - 5 epochs of training - Mini-batch processing - Cross-entropy loss function
Evaluation: - Accuracy tracking for both training and test sets - Batch normalization applied to input data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adamax(model.parameters(), lr=0.002, weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)
def normalize_data(x):
mean = x.mean(dim=2, keepdim=True)
std = x.std(dim=2, keepdim=True) + 1e-7
x = (x - mean) / std
x = x.to(device=device, dtype=torch.float32)
return x
print("\nStart training...")
epochs = int(os.getenv("EEGDASH_EPOCHS", "2"))
for e in range(epochs):
model.train()
correct_train = 0
for t, (x, y) in enumerate(train_loader):
scores = model(normalize_data(x))
y = y.to(device=device, dtype=torch.long)
_, preds = scores.max(1)
correct_train += (preds == y).sum() / len(dataset_train)
loss = F.cross_entropy(scores, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
model.eval()
correct_test = 0
with torch.no_grad():
for t, (x, y) in enumerate(test_loader):
scores = model(normalize_data(x))
y = y.to(device=device, dtype=torch.long)
_, preds = scores.max(1)
correct_test += (preds == y).sum() / len(dataset_test)
print(
f"epoch {e + 1}, training accuracy: {correct_train:.3f}, test accuracy: {correct_test:.3f}"
)