12
Optimizer Race: SGD vs Momentum vs Adam
click to drop a race · drag Y for learning rate
idle
251 lines · vanilla
view source
// Rosenbrock's banana valley: f(x,y) = (a - x)^2 + b*(y - x^2)^2
// with a = 1, b = 100. Global min at (1, 1), f = 0.
// The valley is a narrow curved trough — classic optimizer torture test.
//
// Three balls descend from a shared starting point:
// 0 SGD (vanilla)
// 1 Momentum (Polyak, beta=0.9)
// 2 Adam (b1=0.9, b2=0.999, eps=1e-8)
//
// Mouse Y scrubs the (log) learning rate. Click drops a fresh race
// at the click position. Auto-reset on convergence or after 1000 steps.
// Rosenbrock parameters a and b.
const A = 1;
const B = 100;
// Visible world window; worldToScreen/screenToWorld map it onto the canvas.
const XMIN = -2.2;
const XMAX = 2.2;
const YMIN = -1.2;
const YMAX = 3.2;
// Per-runner step budget, and the cap on trail entries (flat [x, y, x, y, ...]).
const MAX_STEPS = 1000;
const TRAIL_MAX = 900;
// Per-runner display metadata, indexed by optimizer kind (0 SGD, 1 Momentum, 2 Adam).
const NAMES = ["SGD", "Momentum", "Adam"];
const COLORS = ["#ff5577", "#55ddff", "#aaff66"];
const HUES = [345, 195, 95];
// Top of the heatmap's log-compressed loss range: losses of 2500 or more saturate.
const LMAX = Math.log(1 + 2500);
// Canvas size as last seen by init/tick.
let W = 0;
let H = 0;
// Offscreen heatmap buffer, its 2D context, and whether it needs repainting.
let surfaceBuf = null;
let surfaceCtx = null;
let surfaceDirty = true;
// Current race; null until init creates one.
let race = null;
// Base learning rate; tick replaces it from the mouse-Y mapping.
let lr = 0.002;
// Mouse Y that last set lr (-1 = no sample yet).
let lastMouseY = -1;
function init({ width, height }) {
  // Cache the canvas size, allocate a matching offscreen buffer for the
  // heatmap (marked dirty so the first tick paints it), and start a race.
  [W, H] = [width, height];
  surfaceBuf = new OffscreenCanvas(W, H);
  surfaceCtx = surfaceBuf.getContext("2d");
  surfaceDirty = true;
  // Every fresh canvas starts from the textbook Rosenbrock launch point.
  race = makeRace(-1.6, 2.2);
}
// ---------- math ----------
function f(x, y) {
  // Rosenbrock loss: squared offset from x = A, plus B times the squared
  // vertical offset from the parabola y = x^2.
  const offX = A - x;
  const offCurve = y - x * x;
  return offX * offX + B * offCurve * offCurve;
}
function grad(x, y) {
  // Analytic gradient of f, returned as [df/dx, df/dy]:
  //   df/dx = -2(A - x) - 4Bx(y - x^2)
  //   df/dy = 2B(y - x^2)
  const curve = y - x * x;
  return [-2 * (A - x) - 4 * B * x * curve, 2 * B * curve];
}
// ---------- coords ----------
function worldToScreen(x, y) {
  // Normalise into [0, 1] across the world window (y flipped so larger world
  // y sits higher on screen), then scale to canvas pixels.
  const u = (x - XMIN) / (XMAX - XMIN);
  const v = (YMAX - y) / (YMAX - YMIN);
  return [u * W, v * H];
}
function screenToWorld(px, py) {
  // Inverse of worldToScreen: canvas pixels back to Rosenbrock coordinates,
  // with screen y growing downward while world y grows upward.
  const spanX = XMAX - XMIN;
  const spanY = YMAX - YMIN;
  return [XMIN + (px / W) * spanX, YMAX - (py / H) * spanY];
}
// ---------- heatmap ----------
function renderSurface() {
  // Paint the loss heatmap (log-compressed colormap plus contour bands) into
  // the offscreen surface buffer at canvas resolution, then clear the dirty
  // flag. f is cheap, so a per-pixel loop is fine.
  //
  // createImageData throws on a zero-sized area (e.g. a collapsed canvas).
  // In that case leave the surface dirty so a later frame retries.
  if (!(W >= 1 && H >= 1)) return;
  const img = surfaceCtx.createImageData(W, H);
  const d = img.data;
  const spanX = XMAX - XMIN;
  const spanY = YMAX - YMIN;
  for (let j = 0; j < H; j++) {
    const y = YMAX - (j / H) * spanY;
    for (let i = 0; i < W; i++) {
      const x = XMIN + (i / W) * spanX;
      const v = f(x, y);
      // Log-compressed loss in [0, 1]; losses past the range clip to 1.
      const t = Math.min(1, Math.log(1 + v) / LMAX);
      // Contour bands every 1/14 of the log range. A clipped t of exactly 1
      // gives (t * 14) % 1 === 0, which would paint every saturated pixel as
      // a band, so bands are drawn only inside the unclipped range.
      const m = (t * 14) % 1;
      const band = t < 1 && (m < 0.06 || m > 0.94);
      const o = (j * W + i) * 4;
      if (band) {
        d[o] = 245;
        d[o + 1] = 235;
        d[o + 2] = 200;
      } else {
        // Cool deep-blue valley → warm orange ridge.
        d[o] = (28 + 220 * t) | 0;
        d[o + 1] = (32 + 70 * (1 - Math.abs(t - 0.5) * 2)) | 0;
        d[o + 2] = Math.max(0, (95 - 70 * t) | 0);
      }
      d[o + 3] = 255;
    }
  }
  surfaceCtx.putImageData(img, 0, 0);
  surfaceDirty = false;
}
// ---------- race state ----------
function makeRunner(kind, x, y) {
  // Fresh runner state for optimizer `kind` (0 SGD, 1 Momentum, 2 Adam)
  // placed at (x, y). Momentum and Adam share the first-moment slots (vx, vy);
  // Adam alone uses the second-moment slots (sx, sy) and time step t.
  const startLoss = f(x, y);
  return {
    kind,
    x,
    y,
    vx: 0,
    vy: 0,
    sx: 0,
    sy: 0,
    t: 0,
    steps: 0,
    loss: startLoss,
    trail: [x, y], // flat [x0, y0, x1, y1, ...]
    done: false,
    stalled: 0, // consecutive low-progress steps
    prevLoss: startLoss,
  };
}
function makeRace(x0, y0) {
  // A race is three runners (SGD, Momentum, Adam) sharing one start point,
  // plus the bookkeeping tick uses to pause between races.
  const runners = [0, 1, 2].map((kind) => makeRunner(kind, x0, y0));
  return { x0, y0, runners, finished: false, cooldown: 0 };
}
// ---------- optimizer steps ----------
function clampStep(dx, dy, max) {
  // Limit the step vector (dx, dy) to Euclidean length `max`, keeping its
  // direction. Steps already within the limit come back unchanged.
  const len = Math.hypot(dx, dy);
  if (!(len > max)) return [dx, dy];
  const scale = max / len;
  return [dx * scale, dy * scale];
}
// Raw (unclamped) step [dx, dy] for runner r's optimizer, given gradient
// (gx, gy) and base learning rate eta. Updates r's optimizer buffers in place.
function optimizerDelta(r, gx, gy, eta) {
  switch (r.kind) {
    case 0:
      // Vanilla SGD: straight down the gradient.
      return [-eta * gx, -eta * gy];
    case 1:
      // Polyak heavy-ball momentum (beta = 0.9); the velocity is the step.
      r.vx = 0.9 * r.vx - eta * gx;
      r.vy = 0.9 * r.vy - eta * gy;
      return [r.vx, r.vy];
    default: {
      // Adam (b1 = 0.9, b2 = 0.999, eps = 1e-8). Adam's per-coordinate step
      // is roughly eta, so eta is scaled by 10 to sit on SGD's lr scale.
      const b1 = 0.9;
      const b2 = 0.999;
      const eps = 1e-8;
      r.t += 1;
      r.vx = b1 * r.vx + (1 - b1) * gx;
      r.vy = b1 * r.vy + (1 - b1) * gy;
      r.sx = b2 * r.sx + (1 - b2) * gx * gx;
      r.sy = b2 * r.sy + (1 - b2) * gy * gy;
      const bc1 = 1 - Math.pow(b1, r.t);
      const bc2 = 1 - Math.pow(b2, r.t);
      const mhx = r.vx / bc1;
      const mhy = r.vy / bc1;
      const vhx = r.sx / bc2;
      const vhy = r.sy / bc2;
      const etaA = eta * 10;
      return [
        -etaA * mhx / (Math.sqrt(vhx) + eps),
        -etaA * mhy / (Math.sqrt(vhy) + eps),
      ];
    }
  }
}

// Advance runner r by one optimizer iteration at base learning rate eta.
// Mutates r in place: position, optimizer buffers, step count, loss, trail,
// and the done/stalled bookkeeping. Runners already marked done are skipped.
function stepRunner(r, eta) {
  if (r.done) return;
  if (r.steps >= MAX_STEPS) {
    r.done = true;
    return;
  }
  const [gx, gy] = grad(r.x, r.y);
  // A non-finite gradient means a numerical blow-up; retire the runner.
  if (!Number.isFinite(gx) || !Number.isFinite(gy)) {
    r.done = true;
    return;
  }
  const [rawX, rawY] = optimizerDelta(r, gx, gy, eta);
  // Cap each step so large learning rates overshoot visibly without
  // flinging the runner off the map in a single frame.
  const [dx, dy] = clampStep(rawX, rawY, 0.6);
  r.x += dx;
  r.y += dy;
  // Leaving the world box parks the runner on the edge and ends its race.
  const escaped = r.x < XMIN || r.x > XMAX || r.y < YMIN || r.y > YMAX;
  if (escaped) {
    r.x = Math.max(XMIN, Math.min(XMAX, r.x));
    r.y = Math.max(YMIN, Math.min(YMAX, r.y));
    r.done = true;
  }
  r.steps += 1;
  const newLoss = f(r.x, r.y);
  // Stall detection: more than 30 consecutive steps with a negligible loss
  // change and a tiny step.
  const lossDelta = Math.abs(newLoss - r.prevLoss);
  const stepLen = Math.hypot(dx, dy);
  if (lossDelta < 1e-6 && stepLen < 5e-4) {
    r.stalled += 1;
    if (r.stalled > 30) r.done = true;
  } else {
    r.stalled = 0;
  }
  r.loss = newLoss;
  r.prevLoss = newLoss;
  // Record every second position, dropping the oldest entries past TRAIL_MAX.
  if (r.steps % 2 === 0) {
    r.trail.push(r.x, r.y);
    const overflow = r.trail.length - TRAIL_MAX;
    if (overflow > 0) r.trail.splice(0, overflow);
  }
  // Converged: close to the global minimum at (1, 1) with a near-zero loss.
  if (Math.hypot(r.x - 1, r.y - 1) < 0.01 && Math.abs(newLoss) < 1e-3) {
    r.done = true;
  }
}
// ---------- drawing ----------
function drawTarget(ctx) {
  // Mark the global minimum at world (1, 1) with a ring and a small crosshair.
  const [cx, cy] = worldToScreen(1, 1);
  const arm = 4;
  const tau = Math.PI * 2;
  ctx.strokeStyle = "rgba(255,255,255,0.9)";
  ctx.lineWidth = 1.5;
  ctx.beginPath();
  ctx.arc(cx, cy, 8, 0, tau);
  ctx.stroke();
  // Horizontal arm, then vertical arm, stroked as one path.
  ctx.beginPath();
  for (const [x0, y0, x1, y1] of [
    [cx - arm, cy, cx + arm, cy],
    [cx, cy - arm, cx, cy + arm],
  ]) {
    ctx.moveTo(x0, y0);
    ctx.lineTo(x1, y1);
  }
  ctx.stroke();
}
function drawStart(ctx) {
  // Faint ring where the current race was dropped; nothing before a race exists.
  if (race) {
    const [px, py] = worldToScreen(race.x0, race.y0);
    ctx.strokeStyle = "rgba(255,255,255,0.5)";
    ctx.lineWidth = 1;
    ctx.beginPath();
    ctx.arc(px, py, 4, 0, Math.PI * 2);
    ctx.stroke();
  }
}
function drawRunners(ctx) {
  // For each runner, stroke its trail (ending at the current position) in the
  // runner's hue, then draw its head as a filled, outlined dot.
  if (!race) return;
  race.runners.forEach((runner, idx) => {
    const { trail } = runner;
    ctx.strokeStyle = `hsla(${HUES[idx]},95%,70%,0.95)`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    // trail is flat [x0, y0, x1, y1, ...]; the first pair starts the path.
    for (let p = 0; p < trail.length; p += 2) {
      const [px, py] = worldToScreen(trail[p], trail[p + 1]);
      if (p > 0) {
        ctx.lineTo(px, py);
      } else {
        ctx.moveTo(px, py);
      }
    }
    const [headX, headY] = worldToScreen(runner.x, runner.y);
    ctx.lineTo(headX, headY);
    ctx.stroke();
    ctx.fillStyle = COLORS[idx];
    ctx.beginPath();
    ctx.arc(headX, headY, 5.5, 0, Math.PI * 2);
    ctx.fill();
    ctx.strokeStyle = "rgba(0,0,0,0.75)";
    ctx.lineWidth = 1;
    ctx.stroke();
  });
}
function drawHUD(ctx) {
  // Stats panel in the top-left corner: current learning rate, one row per
  // runner (name, step count, loss, finish status), and a controls hint.
  if (!race) return;
  const pad = 10;
  const rowH = 20;
  // Size the panel from the actual runner count instead of assuming three.
  const rows = race.runners.length;
  // At most 260px wide, shrinking to fit narrow canvases (e.g. phones).
  const wHud = Math.min(W - pad * 2, 260);
  // Title strip (28px) + one row per runner + footer strip (22px).
  const hHud = 28 + rowH * rows + 22;
  ctx.fillStyle = "rgba(0,0,0,0.62)";
  ctx.fillRect(pad, pad, wHud, hHud);
  ctx.fillStyle = "#fff";
  ctx.font = "12px monospace";
  ctx.textAlign = "left";
  ctx.textBaseline = "alphabetic";
  ctx.fillText(`Rosenbrock lr = ${lr.toFixed(4)}`, pad + 8, pad + 16);
  for (let k = 0; k < rows; k++) {
    const r = race.runners[k];
    const y = pad + 28 + k * rowH + 14;
    // Color swatch.
    ctx.fillStyle = COLORS[k];
    ctx.fillRect(pad + 8, y - 10, 10, 10);
    ctx.fillStyle = "#fff";
    // Tiny and huge losses switch to exponential notation to keep the column narrow.
    const lossStr = r.loss < 1e-3 ? r.loss.toExponential(1)
      : r.loss < 1000 ? r.loss.toFixed(3)
        : r.loss.toExponential(1);
    // Finished runners read WIN within 0.05 of the minimum (1, 1), else OFF.
    const stat = r.done ? (Math.hypot(r.x - 1, r.y - 1) < 0.05 ? "WIN" : "OFF") : " ";
    const name = NAMES[k].padEnd(8);
    const stepStr = String(r.steps).padStart(4);
    ctx.fillText(`${name} ${stepStr} ${lossStr.padStart(9)} ${stat}`, pad + 22, y);
  }
  ctx.fillStyle = "#bcd";
  ctx.font = "11px monospace";
  ctx.fillText("drag Y: lr click: new race", pad + 8, pad + 28 + rowH * rows + 14);
}
// ---------- frame ----------
function tick({ ctx, width, height, input }) {
  // Per-frame driver: handle resizes, map mouse Y to the learning rate, start
  // races on click, advance the current race one step (or count down to an
  // auto-reset), then draw the scene.
  //
  // On resize the offscreen buffer must be reallocated as well. renderSurface
  // builds W×H image data, and putting it into a buffer of the old size would
  // clip the heatmap or leave stale pixels.
  if (width !== W || height !== H) {
    W = width;
    H = height;
    surfaceBuf = new OffscreenCanvas(W, H);
    surfaceCtx = surfaceBuf.getContext("2d");
    surfaceDirty = true;
  }
  if (surfaceDirty) renderSurface();
  // mouseY → log-scale learning rate. Top of canvas = high, bottom = low.
  // Only update when the mouse has actually been moved. Before any sample has
  // been taken, a mouseY of exactly 0 is the boot default, not a real
  // position, so it must not snap lr to the maximum.
  const my = input.mouseY;
  const isBootDefault = lastMouseY < 0 && my === 0;
  if (!isBootDefault && my >= 0 && my <= H && my !== lastMouseY) {
    lastMouseY = my;
    const t = 1 - my / H; // 0 at bottom, 1 at top
    // Range: 5e-5 .. 0.02 (log-spaced)
    const lo = Math.log(5e-5);
    const hi = Math.log(0.02);
    lr = Math.exp(lo + t * (hi - lo));
  }
  // Click → drop new race at click position.
  for (const c of input.consumeClicks()) {
    const [wx, wy] = screenToWorld(c.x, c.y);
    race = makeRace(wx, wy);
  }
  // Step the race.
  if (race && !race.finished) {
    for (const r of race.runners) stepRunner(r, lr);
    if (race.runners.every((r) => r.done)) {
      race.finished = true;
      race.cooldown = 80; // ~1.3s at 60fps
    }
  } else if (race && race.finished) {
    race.cooldown--;
    if (race.cooldown <= 0) {
      // Auto-reset to the classic start.
      race = makeRace(-1.6, 2.2);
    }
  }
  ctx.drawImage(surfaceBuf, 0, 0);
  drawStart(ctx);
  drawTarget(ctx);
  drawRunners(ctx);
  drawHUD(ctx);
}
Comments (0)
Log in to comment.