12

Optimizer Race: SGD vs Momentum vs Adam

click to drop a race · drag Y for learning rate

Three optimizers race down Rosenbrock's banana valley, $f(x, y) = (1 - x)^2 + 100\,(y - x^2)^2$ — a narrow curved trough with its global minimum at $(1, 1)$. Vanilla **SGD** ($\theta \leftarrow \theta - \eta\,\nabla f(\theta)$) crawls along the floor and stalls early. **Momentum** (Polyak heavy-ball, $\beta = 0.9$) builds up velocity and ploughs through the trough, but overshoots when the learning rate is high. **Adam** ($\beta_1 = 0.9$, $\beta_2 = 0.999$) maintains a per-coordinate running variance and adapts its step size: it tends to escape the badly scaled valley fastest and stays stable across a much wider range of learning rates. Drag the mouse vertically to scrub the learning rate on a log scale — the top of the canvas is aggressive (watch SGD and Momentum fly off), the bottom is timid (all three crawl). Click anywhere to drop a new race from that point. The race auto-resets once all three have finished: converged, stalled, escaped the box, or used up their 1000 steps.

idle
251 lines · vanilla
view source
// Optimizer race on Rosenbrock's "banana" function
//
//   f(x, y) = (a - x)^2 + b * (y - x^2)^2,   a = 1, b = 100
//
// Its global minimum f = 0 sits at (1, 1), at the bottom of a long, narrow,
// curved valley. This makes it a standard stress test for optimizers.
//
// All three runners start from the same point:
//   kind 0  SGD       plain gradient descent
//   kind 1  Momentum  Polyak heavy-ball, beta = 0.9
//   kind 2  Adam      b1 = 0.9, b2 = 0.999, eps = 1e-8
//
// The vertical mouse position picks the learning rate on a log scale, and a
// click starts a new race from the clicked point. A race ends once every
// runner is done (converged, stalled, escaped the box, or out of steps). It
// then restarts from the classic start after a short pause.

// Rosenbrock coefficients.
const A = 1;
const B = 100;

// Visible world window (x grows to the right, y grows upward on screen).
const XMIN = -2.2;
const XMAX = 2.2;
const YMIN = -1.2;
const YMAX = 3.2;

// Per-runner step budget, and the cap on stored trail coordinates
// (a flat list of x, y pairs, so TRAIL_MAX / 2 points).
const MAX_STEPS = 1000;
const TRAIL_MAX = 900;

// Per-runner label, head colour and trail hue, indexed by runner kind.
const NAMES = ["SGD", "Momentum", "Adam"];
const COLORS = ["#ff5577", "#55ddff", "#aaff66"];
const HUES = [345, 195, 95];

// Heatmap normaliser: log(1 + loss) is divided by this, so losses above
// 2500 saturate the colormap.
const LMAX = Math.log(1 + 2500);

// Mutable module state.
let W = 0;               // canvas width in px
let H = 0;               // canvas height in px
let surfaceBuf = null;   // offscreen canvas caching the heatmap
let surfaceCtx = null;   // 2D context of surfaceBuf
let surfaceDirty = true; // heatmap must be re-rendered before drawing
let race = null;         // current race (see makeRace)
let lr = 0.002;          // learning rate; replaced once the mouse drives it
let lastMouseY = -1;     // last mouseY used to set lr (-1 = none yet)

function init({ width, height }) {
  // Set up module state for a canvas of the given size:
  //  - remember the size;
  //  - allocate the offscreen buffer that caches the heatmap (flagged dirty,
  //    so renderSurface paints it on the next tick);
  //  - line up the default race at the textbook Rosenbrock start (-1.6, 2.2).
  [W, H] = [width, height];
  surfaceBuf = new OffscreenCanvas(W, H);
  surfaceCtx = surfaceBuf.getContext("2d");
  surfaceDirty = true;
  race = makeRace(-1.6, 2.2);
}

// ---------- math ----------

function f(x, y) {
  // Rosenbrock loss: (A - x)^2 + B * (y - x^2)^2.
  const dx = A - x;         // horizontal distance from the minimum
  const ridge = y - x * x;  // vertical offset from the parabola y = x^2
  return dx * dx + B * ridge * ridge;
}

function grad(x, y) {
  // Analytic gradient of f, returned as [df/dx, df/dy]:
  //   df/dx = -2(A - x) - 4B * x * (y - x^2)
  //   df/dy =  2B * (y - x^2)
  const ridge = y - x * x;
  const dfdx = -2 * (A - x) - 4 * B * x * ridge;
  const dfdy = 2 * B * ridge;
  return [dfdx, dfdy];
}

// ---------- coords ----------

/** Map world coordinates (y up) to canvas pixels (y down) for a W x H canvas. */
function worldToScreen(x, y) {
  const u = (x - XMIN) / (XMAX - XMIN);  // 0 at the left edge, 1 at the right
  const v = (YMAX - y) / (YMAX - YMIN);  // 0 at the top edge, 1 at the bottom
  return [u * W, v * H];
}
/** Inverse of worldToScreen: canvas pixels back to world coordinates. */
function screenToWorld(px, py) {
  const wx = XMIN + (px / W) * (XMAX - XMIN);
  const wy = YMAX - (py / H) * (YMAX - YMIN);
  return [wx, wy];
}

// ---------- heatmap ----------

function renderSurface() {
  // Rasterise log(1 + f) over the world window into the offscreen buffer at
  // full canvas resolution (f is cheap, so per-pixel evaluation is fine).
  // Clears surfaceDirty once the buffer has been painted.

  // createImageData throws on a zero dimension. Leave surfaceDirty set so the
  // surface is rebuilt once the canvas has a real size again.
  if (W <= 0 || H <= 0) return;

  // tick() updates W/H on resize, but surfaceBuf is allocated only in init().
  // Resize it to match; otherwise the heatmap is clipped to the old size and
  // no longer covers the canvas. Resizing clears the bitmap, and the whole
  // bitmap is repainted below.
  if (surfaceBuf.width !== W || surfaceBuf.height !== H) {
    surfaceBuf.width = W;
    surfaceBuf.height = H;
  }

  const spanX = XMAX - XMIN;
  const spanY = YMAX - YMIN;
  const img = surfaceCtx.createImageData(W, H);
  const d = img.data;
  for (let j = 0; j < H; j++) {
    const y = YMAX - (j / H) * spanY;
    for (let i = 0; i < W; i++) {
      const x = XMIN + (i / W) * spanX;
      // Log-compress the loss into t in [0, 1]; large losses saturate at 1.
      const t = Math.min(1, Math.log(1 + f(x, y)) / LMAX);
      const o = (j * W + i) * 4;
      // Thin contour bands every ~1/14 of the log range.
      const m = (t * 14) % 1;
      if (m < 0.06 || m > 0.94) {
        d[o]     = 245;
        d[o + 1] = 235;
        d[o + 2] = 200;
      } else {
        // Cool deep-blue valley → warm orange ridge.
        d[o]     = (28 + 220 * t) | 0;
        d[o + 1] = (32 + 70 * (1 - Math.abs(t - 0.5) * 2)) | 0;
        const b = (95 - 70 * t) | 0;
        d[o + 2] = b < 0 ? 0 : b;
      }
      d[o + 3] = 255;
    }
  }
  surfaceCtx.putImageData(img, 0, 0);
  surfaceDirty = false;
}

// ---------- race state ----------

function makeRunner(kind, x, y) {
  // Fresh optimizer state for one runner at (x, y). `kind` selects the update
  // rule in stepRunner: 0 = SGD, 1 = Momentum, 2 = Adam.
  const startLoss = f(x, y);
  return {
    kind,
    x,
    y,
    vx: 0,               // Momentum velocity, or Adam first moment (m)
    vy: 0,
    sx: 0,               // Adam second moment (v)
    sy: 0,
    t: 0,                // Adam step counter used for bias correction
    steps: 0,            // optimizer steps taken so far
    loss: startLoss,     // loss at the current position
    trail: [x, y],       // flat [x0, y0, x1, y1, ...] path history
    done: false,
    stalled: 0,          // consecutive near-zero-progress steps
    prevLoss: startLoss, // loss before the most recent step
  };
}

function makeRace(x0, y0) {
  // A race is three runners (kinds 0, 1, 2: SGD, Momentum, Adam) sharing the
  // start point (x0, y0), plus the finish / cool-down flags tick() maintains.
  const runners = [0, 1, 2].map((kind) => makeRunner(kind, x0, y0));
  return { x0, y0, runners, finished: false, cooldown: 0 };
}

// ---------- optimizer steps ----------

function clampStep(dx, dy, max) {
  // Shrink the vector (dx, dy) to length `max` when it is longer. Shorter
  // vectors, and vectors whose length is NaN, pass through unchanged.
  const len = Math.hypot(dx, dy);
  const scale = len > max ? max / len : 1;
  return [dx * scale, dy * scale];
}

function stepRunner(r, eta) {
  // Advance one runner by a single optimizer step with learning rate `eta`,
  // then update its loss, trail and stop bookkeeping. The runner is retired
  // (r.done = true) when any of these happens:
  //  - it runs out of steps;
  //  - it hits a non-finite gradient;
  //  - it leaves the world box;
  //  - it makes no progress for more than 30 consecutive steps;
  //  - it reaches (1, 1).
  if (r.done) return;
  if (r.steps >= MAX_STEPS) { r.done = true; return; }
  const [gx, gy] = grad(r.x, r.y);
  // A non-finite gradient means the position has already blown up (e.g.
  // NaN). Retire the runner rather than propagate garbage.
  if (!Number.isFinite(gx) || !Number.isFinite(gy)) { r.done = true; return; }

  let dx = 0;
  let dy = 0;
  if (r.kind === 0) {
    // Vanilla SGD.
    dx = -eta * gx;
    dy = -eta * gy;
  } else if (r.kind === 1) {
    // Polyak heavy-ball momentum, beta = 0.9; r.vx/r.vy hold the velocity.
    r.vx = 0.9 * r.vx - eta * gx;
    r.vy = 0.9 * r.vy - eta * gy;
    dx = r.vx;
    dy = r.vy;
  } else {
    // Adam (b1 = 0.9, b2 = 0.999, eps = 1e-8). r.vx/r.vy hold the first
    // moment and r.sx/r.sy the second. Adam's effective per-coordinate step
    // is ~eta, so it is scaled by 10 to keep its lr comparable with SGD's.
    const b1 = 0.9, b2 = 0.999, eps = 1e-8;
    r.t++;
    r.vx = b1 * r.vx + (1 - b1) * gx;
    r.vy = b1 * r.vy + (1 - b1) * gy;
    r.sx = b2 * r.sx + (1 - b2) * gx * gx;
    r.sy = b2 * r.sy + (1 - b2) * gy * gy;
    const bc1 = 1 - Math.pow(b1, r.t);
    const bc2 = 1 - Math.pow(b2, r.t);
    const mhx = r.vx / bc1, mhy = r.vy / bc1;
    const vhx = r.sx / bc2, vhy = r.sy / bc2;
    const etaA = eta * 10;
    dx = -etaA * mhx / (Math.sqrt(vhx) + eps);
    dy = -etaA * mhy / (Math.sqrt(vhy) + eps);
  }

  // Per-step cap so big learning rates produce wild overshoots you can see
  // without flinging the runner off the map in a single step.
  [dx, dy] = clampStep(dx, dy, 0.6);
  if (r.kind === 1) {
    // Keep the heavy-ball velocity equal to the displacement actually taken.
    // Otherwise a clamped step leaves a large "phantom" velocity in the
    // buffer, which keeps pushing the runner at the cap for many frames and
    // defeats the clamp. When no clamping happened this is a no-op.
    r.vx = dx;
    r.vy = dy;
  }
  r.x += dx;
  r.y += dy;

  // A runner that escapes the box is parked on the edge and retired.
  if (r.x < XMIN || r.x > XMAX || r.y < YMIN || r.y > YMAX) {
    r.x = Math.max(XMIN, Math.min(XMAX, r.x));
    r.y = Math.max(YMIN, Math.min(YMAX, r.y));
    r.done = true;
  }
  r.steps++;
  const newLoss = f(r.x, r.y);

  // Stall: tiny loss change AND tiny step for more than 30 steps in a row.
  if (Math.abs(newLoss - r.prevLoss) < 1e-6 && Math.hypot(dx, dy) < 5e-4) {
    r.stalled++;
    if (r.stalled > 30) r.done = true;
  } else {
    r.stalled = 0;
  }
  r.prevLoss = newLoss;
  r.loss = newLoss;

  // Record every other position. The trail is a flat [x, y, x, y, ...]
  // list, so trimming always drops an even count to keep pairs aligned.
  if ((r.steps & 1) === 0) {
    r.trail.push(r.x, r.y);
    const excess = r.trail.length - TRAIL_MAX;
    if (excess > 0) r.trail.splice(0, excess + (excess & 1));
  }

  // Hard convergence to the minimum at (1, 1).
  if (Math.hypot(r.x - 1, r.y - 1) < 0.01 && Math.abs(newLoss) < 1e-3) {
    r.done = true;
  }
}

// ---------- drawing ----------

function drawTarget(ctx) {
  // Crosshair inside a ring, marking the global minimum at world (1, 1).
  const [cx, cy] = worldToScreen(1, 1);
  const ring = 8;  // ring radius in px
  const arm = 4;   // half-length of each crosshair bar in px

  ctx.strokeStyle = "rgba(255,255,255,0.9)";
  ctx.lineWidth = 1.5;

  ctx.beginPath();
  ctx.arc(cx, cy, ring, 0, Math.PI * 2);
  ctx.stroke();

  const bars = [
    [cx - arm, cy, cx + arm, cy],  // horizontal
    [cx, cy - arm, cx, cy + arm],  // vertical
  ];
  ctx.beginPath();
  for (const [x0, y0, x1, y1] of bars) {
    ctx.moveTo(x0, y0);
    ctx.lineTo(x1, y1);
  }
  ctx.stroke();
}

function drawStart(ctx) {
  // Faint ring on the point the current race was launched from.
  if (race) {
    const [cx, cy] = worldToScreen(race.x0, race.y0);
    ctx.strokeStyle = "rgba(255,255,255,0.5)";
    ctx.lineWidth = 1;
    ctx.beginPath();
    ctx.arc(cx, cy, 4, 0, Math.PI * 2);
    ctx.stroke();
  }
}

function drawRunners(ctx) {
  // For each runner:
  //  - stroke its trail as one polyline (oldest point first, ending at the
  //    current position);
  //  - then paint a filled, outlined head dot on top.
  if (!race) return;
  race.runners.forEach((runner, k) => {
    const { trail } = runner;
    const [headX, headY] = worldToScreen(runner.x, runner.y);

    ctx.strokeStyle = `hsla(${HUES[k]},95%,70%,0.95)`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    for (let i = 0; i < trail.length; i += 2) {
      const [px, py] = worldToScreen(trail[i], trail[i + 1]);
      ctx[i === 0 ? "moveTo" : "lineTo"](px, py);
    }
    ctx.lineTo(headX, headY);
    ctx.stroke();

    ctx.fillStyle = COLORS[k];
    ctx.beginPath();
    ctx.arc(headX, headY, 5.5, 0, Math.PI * 2);
    ctx.fill();
    ctx.strokeStyle = "rgba(0,0,0,0.75)";
    ctx.lineWidth = 1;
    ctx.stroke();
  });
}

function drawHUD(ctx) {
  // Translucent panel in the top-left corner. It shows:
  //  - the current learning rate;
  //  - one row per runner: name, step count, loss and a status tag;
  //  - a short controls hint at the bottom.
  // Status tags for finished runners: WIN if within 0.05 of (1, 1), OFF
  // otherwise. Running runners show a blank tag.
  if (!race) return;
  const pad = 10;
  const rowH = 20;
  const rows = race.runners.length;
  // Width follows the canvas so the HUD shrinks on phones; never negative.
  const wHud = Math.max(0, Math.min(W - pad * 2, 260));
  // Title band (28px) + one row per runner + footer band (22px).
  const hHud = 28 + rowH * rows + 22;
  ctx.fillStyle = "rgba(0,0,0,0.62)";
  ctx.fillRect(pad, pad, wHud, hHud);

  ctx.fillStyle = "#fff";
  ctx.font = "12px monospace";
  ctx.textAlign = "left";
  ctx.textBaseline = "alphabetic";

  // The mouse-driven lr goes down to 5e-5. toFixed(4) would render the low
  // end as "0.0000"/"0.0001", so use exponent notation below 1e-3.
  const lrStr = lr >= 1e-3 ? lr.toFixed(4) : lr.toExponential(2);
  ctx.fillText(`Rosenbrock  lr = ${lrStr}`, pad + 8, pad + 16);

  for (let k = 0; k < rows; k++) {
    const r = race.runners[k];
    const y = pad + 28 + k * rowH + 14;
    // Color swatch.
    ctx.fillStyle = COLORS[k];
    ctx.fillRect(pad + 8, y - 10, 10, 10);
    ctx.fillStyle = "#fff";
    const lossStr = r.loss < 1e-3 ? r.loss.toExponential(1)
                   : r.loss < 1000 ? r.loss.toFixed(3)
                   : r.loss.toExponential(1);
    const stat = r.done ? (Math.hypot(r.x - 1, r.y - 1) < 0.05 ? "WIN" : "OFF") : "   ";
    const name = NAMES[k].padEnd(8);
    const stepStr = String(r.steps).padStart(4);
    ctx.fillText(`${name} ${stepStr}  ${lossStr.padStart(9)} ${stat}`, pad + 22, y);
  }

  ctx.fillStyle = "#bcd";
  ctx.font = "11px monospace";
  ctx.fillText("drag Y: lr   click: new race", pad + 8, pad + 28 + rowH * rows + 14);
}

// ---------- frame ----------

function tick({ ctx, width, height, input }) {
  // Per-frame entry point, in order:
  //  1. follow canvas resizes;
  //  2. map mouse Y to the learning rate;
  //  3. start a new race on click;
  //  4. advance (or cool down) the current race;
  //  5. redraw everything.
  if (width !== W || height !== H) { W = width; H = height; surfaceDirty = true; }
  if (surfaceDirty) renderSurface();

  // mouseY → log-scale learning rate. Top of canvas = high, bottom = low.
  // Only react once the mouse has actually moved. The first reading is kept
  // as a baseline, so a host that reports mouseY = 0 before any pointer
  // movement does not jump lr to the maximum on the first frame.
  const my = input.mouseY;
  if (lastMouseY === -1) {
    lastMouseY = my;
  } else if (H > 0 && my >= 0 && my <= H && my !== lastMouseY) {
    // H > 0 avoids a 0/0 NaN learning rate on a zero-height canvas.
    lastMouseY = my;
    const t = 1 - my / H;                 // 0 at bottom, 1 at top
    // Range: 5e-5 .. 0.02   (log-spaced)
    const lo = Math.log(5e-5), hi = Math.log(0.02);
    lr = Math.exp(lo + t * (hi - lo));
  }

  // Click → drop a new race at the click position.
  for (const c of input.consumeClicks()) {
    const [wx, wy] = screenToWorld(c.x, c.y);
    race = makeRace(wx, wy);
  }

  // Step the race, or count down the pause after it has finished.
  if (race && !race.finished) {
    for (const r of race.runners) stepRunner(r, lr);
    if (race.runners.every(r => r.done)) {
      race.finished = true;
      race.cooldown = 80; // ~1.3s at 60fps
    }
  } else if (race && race.finished) {
    race.cooldown--;
    if (race.cooldown <= 0) {
      // Auto-reset to the classic start.
      race = makeRace(-1.6, 2.2);
    }
  }

  ctx.drawImage(surfaceBuf, 0, 0);
  drawStart(ctx);
  drawTarget(ctx);
  drawRunners(ctx);
  drawHUD(ctx);
}

Comments (0)

Log in to comment.