Parallelise EEGDash feature extraction#
Goal: scale eegdash.features.extract_features() across multiple cores
on one node by tuning n_jobs and batch_size, then persist the
result so re-runs are free.
How-to recipe. Uses synthetic data only, so it runs locally in under 90 s on a 4-core CPU.
Goal#
Cut wall-clock time for feature extraction on a single node, keep memory below the cgroup limit, and avoid recomputing the Welch PSD across job restarts. Tied to Cisotto and Chicco (2024, doi:10.7717/peerj-cs.2256) Tip 6 (reuse cached spectra) and the joblib parallelism documented in scikit-learn (Pedregosa et al., 2011).
Learning objectives#
- Choose n_jobs from $SLURM_CPUS_PER_TASK instead of -1.
- Pick a batch_size that keeps each worker busy without OOMing.
- Persist the feature table once and reload it across jobs.
- Read a small scaling table and stop adding workers when it pays nothing.
Prerequisites#
- Completed /auto_examples/tutorials/40_features/plot_40_first_features.
- Read /auto_examples/how_to/how_to_use_hpc_cache so you know where EEGDASH_CACHE and EEGDASH_FEATURES_CACHE should point.
- Local Python with braindecode, mne, joblib, pyarrow.
import os
os.environ.setdefault("PYTHONWARNINGS", "ignore") # quiet joblib workers
import time
import warnings
from functools import partial
from pathlib import Path
import joblib
import mne
import numpy as np
import pandas as pd
from braindecode.datasets import BaseConcatDataset, BaseDataset
from braindecode.preprocessing import create_fixed_length_windows
from eegdash.features import (
    FeatureExtractor,
    complexity_multiscale_entropy,
    extract_features,
    signal_variance,
    spectral_bands_power,
    spectral_preprocessor,
)
warnings.filterwarnings("ignore")
mne.set_log_level("ERROR")
np.random.seed(0)
CACHE = Path(os.environ.get("EEGDASH_FEATURES_CACHE", Path.cwd() / "feat_cache"))
CACHE.mkdir(parents=True, exist_ok=True)
N_CORES = os.cpu_count() or 1
Synthetic dataset (mimics plot_10 windows)#
16 short 6-channel resting-state recordings at 128 Hz; half get a 10 Hz alpha bump so the feature table is non-degenerate.
def _make_raw(seed: int, eyes_closed: bool, secs: int = 240) -> mne.io.Raw:
    n = 128 * secs
    data = np.random.default_rng(seed).standard_normal((6, n)) * 1e-6
    if eyes_closed:
        data += 4e-6 * np.sin(2 * np.pi * 10.0 * np.arange(n) / 128)
    info = mne.create_info(["O1", "Oz", "O2", "Cz", "Pz", "POz"], 128, "eeg")
    raw = mne.io.RawArray(data, info)
    raw.filter(1.0, 40.0)
    return raw
N_REC = 16
ds = BaseConcatDataset(
    [
        BaseDataset(
            _make_raw(i, bool(i % 2)),
            target_name="target",
            description={"subject": f"sub-{i:02d}", "target": int(i % 2)},
        )
        for i in range(N_REC)
    ]
)
windows = create_fixed_length_windows(
    ds,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=256,
    window_stride_samples=256,
    drop_last_window=True,
    preload=True,
)
N_WIN = sum(len(d) for d in windows.datasets)
print(f"recordings={N_REC} windows={N_WIN} cores_available={N_CORES}")
Feature mix: multiscale sample entropy dominates and is what makes
parallelism pay; spectral_preprocessor is shared so the Welch PSD runs
once per window, not once per band.
Step 1 – profile the single-threaded baseline#
Run once with n_jobs=1 and assert no rows are silently dropped.
def _run(nj: int, bs: int = 64) -> float:
    t = time.perf_counter()
    out = extract_features(windows, features, batch_size=bs, n_jobs=nj)
    df = out.to_dataframe(include_target=True)
    assert df.shape[0] == N_WIN, "rows dropped silently"
    return time.perf_counter() - t
t1 = _run(nj=1)
print(f"baseline n_jobs=1: {t1:.2f}s")
Step 2 – scale with n_jobs#
Read n_jobs from $SLURM_CPUS_PER_TASK (or your scheduler’s
equivalent). Never hard-code -1 on a shared node.
sched = int(os.environ.get("SLURM_CPUS_PER_TASK", min(4, N_CORES)))
sweep = sorted({1, 2, sched})
scaling = [{"n_jobs": 1, "wall_s": round(t1, 2), "speedup": 1.0}]
for nj in sweep:
    if nj == 1:
        continue
    sec = _run(nj=nj)
    scaling.append(
        {"n_jobs": nj, "wall_s": round(sec, 2), "speedup": round(t1 / sec, 2)}
    )
print(pd.DataFrame(scaling).to_string(index=False))
# On a 14-core macOS dev box: n_jobs=1 ~15.3 s, n_jobs=2 ~11.4 s (1.35x),
# n_jobs=4 ~8.2 s (1.86x). Joblib's process-spawn cost (~3 s on macOS,
# <1 s on Linux SLURM) eats part of the n_jobs=2 win on a laptop; on a
# dedicated SLURM node you typically reach >=1.5x at n_jobs=2 and the
# curve flattens once n_jobs equals the number of recordings.
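The "stop adding workers when it pays nothing" rule from the objectives can be made concrete: divide speedup by n_jobs to get parallel efficiency, and stop at the last rung that stays above a threshold. A sketch over hypothetical numbers shaped like the table above:

```python
import pandas as pd

# Hypothetical wall-clock times, shaped like the scaling table this recipe prints.
scaling = pd.DataFrame({"n_jobs": [1, 2, 4], "wall_s": [15.3, 11.4, 8.2]})
scaling["speedup"] = scaling["wall_s"].iloc[0] / scaling["wall_s"]
scaling["efficiency"] = scaling["speedup"] / scaling["n_jobs"]

# Keep adding workers only while each core still returns >= 50 % of itself.
best = int(scaling.loc[scaling["efficiency"] >= 0.5, "n_jobs"].iloc[-1])
print(scaling.round(2).to_string(index=False))
print(f"pick n_jobs={best}")
```

With these numbers, n_jobs=4 gives a 1.87x speedup but only 47 % efficiency, so the rule settles on n_jobs=2.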
Step 3 – tune batch_size#
batch_size controls how many windows each worker holds in memory.
Too small and Python per-batch overhead dominates; too large and you OOM.
batch_report = [
    {"batch_size": bs, "wall_s": round(_run(sched, bs), 2)} for bs in (16, 64, 256)
]
print(pd.DataFrame(batch_report).to_string(index=False))
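As a rough pre-submission sanity check, you can lower-bound the per-worker payload from the quantities named in this recipe (this counts only the float64 window data; spectra and Python overhead come on top):

```python
N_CHANNELS, WINDOW_SAMPLES = 6, 256  # matches the synthetic windows above

def batch_bytes(batch_size: int,
                n_channels: int = N_CHANNELS,
                window_samples: int = WINDOW_SAMPLES) -> int:
    """Lower bound on bytes one worker holds: float64 window data only."""
    return batch_size * n_channels * window_samples * 8

for bs in (16, 64, 256):
    print(f"batch_size={bs:>3}: >= {batch_bytes(bs) / 2**20:.2f} MiB per worker")
```

Even batch_size=256 is only ~3 MiB of raw windows here; on real high-density recordings with long windows the same arithmetic reaches hundreds of MiB, which is when the 4 GB cgroup limit starts to matter.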
Step 4 – persist intermediate results#
Write the feature table to parquet once; reload on every subsequent call.
parquet = CACHE / "features.parquet"
if parquet.exists():
    t = time.perf_counter()
    df = pd.read_parquet(parquet)
    reload_s = time.perf_counter() - t
else:
    out = extract_features(windows, features, batch_size=64, n_jobs=sched)
    df = out.to_dataframe(include_target=True)
    df.to_parquet(parquet)
    reload_s = float("nan")
assert parquet.exists() and len(df) == N_WIN
print(f"persisted={parquet.name} rows={len(df)} reload_s={reload_s:.3f}")
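One refinement worth considering before multi-job use: derive the parquet name from the extraction parameters, so a changed feature set or batch_size cannot silently reload a stale table. A sketch (feature_cache_path is a hypothetical helper, not part of eegdash):

```python
import hashlib
from pathlib import Path

def feature_cache_path(cache_dir: Path, feature_names, batch_size: int) -> Path:
    """Hypothetical helper: parquet path keyed on the extraction inputs."""
    key = repr((sorted(feature_names), batch_size)).encode()
    return cache_dir / f"features-{hashlib.sha256(key).hexdigest()[:12]}.parquet"

p = feature_cache_path(Path("feat_cache"), ["variance", "bandpower", "mse"], 64)
print(p.name)  # digest changes whenever the feature list or batch_size does
```

Sorting the feature names makes the key order-insensitive, so two jobs that build the same mapping in a different order share one cache entry.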
Step 5 (optional) – joblib.dump the extractor#
When the pipeline contains a fitted CSP or trainable feature, pickling the extractor lets you reapply it to held-out data without retraining.
extractor_path = CACHE / "extractor.joblib"
joblib.dump(features, extractor_path)
reused = joblib.load(extractor_path)
print(f"checkpoint -> {extractor_path.name} (keys={list(reused)})")
Common pitfalls#
- Oversubscribed cores. n_jobs=-1 on a shared SLURM node steals cores from other jobs. Read $SLURM_CPUS_PER_TASK.
- Large batches OOMing. Each worker holds batch_size * n_channels * window_samples floats plus spectra. On a 4 GB cgroup, keep batch_size <= 256.
- Joblib caching gotchas. joblib.Memory keys on argument hashes; non-picklable lambdas in features break caching. Use functools.partial() (as above) or named module-level functions.
- Process spawn cost. macOS loky spawns fresh Pythons (2-3 s fixed); Linux fork is sub-second.
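The lambda pitfall is easy to demonstrate: pickle, which joblib's process workers rely on to ship callables, serialises plain functions by reference, so a functools.partial over a named module-level function round-trips while an equivalent lambda does not (band_power_stub is a hypothetical stand-in for a real feature function):

```python
import pickle
from functools import partial

import numpy as np

def band_power_stub(x, low, high):
    """Named module-level stand-in for a real spectral feature."""
    return float(np.var(x)) * (high - low)

good = partial(band_power_stub, low=8.0, high=12.0)  # picklable by reference
bad = lambda x: float(np.var(x)) * 4.0               # same maths, no importable name

roundtrip = pickle.loads(pickle.dumps(good))         # survives the worker boundary
lambda_picklable = True
try:
    pickle.dumps(bad)
except Exception:                                    # PicklingError on lambdas
    lambda_picklable = False

print(f"partial roundtrips: {roundtrip(np.ones(4)) == good(np.ones(4))}, "
      f"lambda picklable: {lambda_picklable}")
```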
See also#
- /auto_examples/tutorials/40_features/plot_41_feature_trees – shared-preprocessor pipelines that amplify the speedup here.
- /auto_examples/how_to/how_to_use_hpc_cache – placing EEGDASH_CACHE on local NVMe.
References#
Cisotto, G. and Chicco, D. (2024). Ten quick tips for clinical electroencephalographic (EEG) data acquisition and signal processing. PeerJ Computer Science, 10, e2256. doi:10.7717/peerj-cs.2256
Pedregosa, F. et al. (2011). Scikit-learn: Machine Learning in Python. JMLR, 12, 2825-2830; https://scikit-learn.org/stable/computing/parallelism.html