17

Metropolis-Hastings: 2D Random Walk

drag Y to scrub step size · click to teleport

A Markov chain wanders across a 2D banana-shaped target density $\pi(x, y)$ (a Rosenbrock-style ridge). Each step proposes $(x', y') = (x, y) + \sigma\,(\varepsilon_1, \varepsilon_2)$ with $\varepsilon_i \sim \mathcal{N}(0, 1)$ and accepts with probability $\alpha = \min\!\left(1, \pi(x', y') / \pi(x, y)\right)$ — the classic **Metropolis-Hastings** rule. The faint heatmap behind the chain is the target's log-density; the bright trail is the last 600 samples, fading with age; the white dot is the current state. Drag your cursor (or finger) up and down to scrub the proposal scale from $\sigma \approx 0.03$ (almost stuck: acceptance near 100%, but the chain barely moves) to $\sigma \approx 3.5$ (huge leaps, almost all rejected). Watch the live acceptance rate in the HUD: theory says random-walk MH is most efficient when its acceptance rate is near the **Roberts–Gelman–Gilks optimum of $\approx 0.234$** (strictly a high-dimensional limit, but a useful reference here) — the readout turns green when you're close. Click anywhere to teleport the chain to a fresh random starting point and watch it equilibrate onto the ridge. Idle for a moment and the step size auto-cycles so the geometry of bad vs. good proposals stays on screen.

idle
248 lines · vanilla
view source
// Metropolis-Hastings on a 2D target.
// Target is a banana-shaped (Rosenbrock-style) density. We render its
// log-density as a faint heatmap, then run an MH chain on it with a
// Gaussian random-walk proposal. mouseY scrubs the proposal step size
// log-uniformly from tiny (0.03) to huge (3.5) — see STEP_LO / STEP_HI —
// so you can watch the acceptance rate cross ~0.234, the asymptotically
// optimal rate for high-dimensional random-walk MH (used here as a
// reference even though this target is only 2D).
//
// Click to teleport the chain to a fresh random starting point.
// If mouseY stops changing for a while we auto-cycle the step size and
// periodically teleport the chain.

// Current canvas size (updated in init/tick).
let W = 0;
let H = 0;

// Plot-space (data-space) extents. The banana lives roughly in
// x in [-3, 3], y in [-1.5, 5]; the box adds a small margin.
const X0 = -3.2;
const X1 = 3.2;
const Y0 = -1.8;
const Y1 = 5.2;

// Background heatmap (cached, rebuilt whenever the canvas size changes).
let heat = null;   // ImageData wrapping heatPx
let heatW = 0;     // heatmap width in pixels
let heatH = 0;     // heatmap height in pixels
let heatPx = null; // Uint8ClampedArray RGBA buffer, heatW * heatH * 4 bytes

// Chain state.
let cx = 0;            // current sample, x
let cy = 0;            // current sample, y
let cLogp = -Infinity; // log-density at (cx, cy)

// Ring buffer of recent chain states for the trail — fixed-size typed
// arrays avoid per-frame allocations and O(n) shift().
const TRAIL_MAX = 600;
const trailX = new Float32Array(TRAIL_MAX);
const trailY = new Float32Array(TRAIL_MAX);
let trailHead = 0;  // index where the next sample will be written
let trailCount = 0; // number of valid samples, up to TRAIL_MAX

// Proposal step size (σ of the Gaussian random walk) — scrubbed via mouseY.
let stepSize = 0.6;     // sensible default before any input
let lastUserMy = -1;    // last mouseY treated as user input (-1 = none yet)
let autoCycleT = 0;     // seconds since last user input
const AUTO_DELAY = 2.5; // start auto-cycling after this many idle seconds

// Acceptance stats since the last teleport.
let proposals = 0;
let accepts = 0;
// EWMA of the per-proposal accept indicator for the "live" readout; it
// reacts to step-size changes much faster than the cumulative ratio.
let liveAcc = 0.234;
const EWMA_ALPHA = 0.04;

// MH steps attempted per frame. Higher = livelier trail.
const STEPS_PER_FRAME = 12;

// Teleport bookkeeping.
let teleportFlash = 0; // flash-ring intensity: 1 right after a teleport, fades to 0
let lastTeleportT = 0; // timeSec of the most recent teleport

let timeSec = 0; // accumulated time in seconds (sum of frame dt)

// ---- target density ----------------------------------------------------
// Rosenbrock-style "banana": a Gaussian bent along the parabola y = x^2,
//   logp(x, y) = -[ x^2 / (2*sx^2) + (y - x^2)^2 / (2*sy^2) ]
// with sx = SX = 1 and sy = SY = 0.5. It is unnormalised; its maximum, 0,
// sits at the mode (0, 0).
const SX = 1.0;
const SY = 0.5;

/**
 * Unnormalised log-density of the banana target.
 * @param {number} x
 * @param {number} y
 * @returns {number} log p(x, y) up to an additive constant (always <= 0).
 */
function logp(x, y) {
  const offRidge = y - x * x; // signed distance (in y) from the parabola
  const xTerm = (x * x) / (2 * SX * SX);
  const ridgeTerm = (offRidge * offRidge) / (2 * SY * SY);
  return -xTerm - ridgeTerm;
}

// ---- box-muller --------------------------------------------------------
/**
 * One standard-normal draw via the cosine branch of the Box–Muller
 * transform. Consumes exactly two Math.random() values: radius, then angle.
 * @returns {number}
 */
function randn() {
  // Math.random() may return exactly 0; keep Math.log away from -Infinity.
  const radiusU = Math.max(Math.random(), 1e-12);
  const angleU = Math.random();
  const radius = Math.sqrt(-2 * Math.log(radiusU));
  return radius * Math.cos(2 * Math.PI * angleU);
}

// ---- coord transforms --------------------------------------------------
// Map data-space coordinates to canvas pixels: [X0, X1] -> [0, W] left to
// right, and [Y0, Y1] -> [H, 0] because data y grows upward while canvas
// y grows downward.
function xToPx(x) {
  const frac = (x - X0) / (X1 - X0);
  return frac * W;
}
function yToPx(y) {
  const frac = (y - Y0) / (Y1 - Y0);
  return H - frac * H;
}

// ---- background heatmap ------------------------------------------------
/**
 * Rasterise the target's log-density into an RGBA buffer of size
 * floor(W) x floor(H) (at least 1x1) and wrap it in an ImageData so tick
 * can blit it with putImageData. Called from init and on every resize.
 *
 * Side effects: sets heatW, heatH, heatPx and heat.
 */
function buildHeatmap() {
  heatW = Math.max(1, Math.floor(W));
  heatH = Math.max(1, Math.floor(H));
  heatPx = new Uint8ClampedArray(heatW * heatH * 4);

  // logp is <= 0 with its maximum, 0, at the mode (0, 0). Clamp the tails
  // at a fixed CLIP (well into the tail) so contrast is stable regardless
  // of the view, then map [CLIP, 0] -> [0, 1].
  const CLIP = -8;
  const spanX = X1 - X0;
  const spanY = Y1 - Y0;

  let idx = 0; // RGBA byte offset of pixel (px, py), row-major
  for (let py = 0; py < heatH; py++) {
    // Sample at the pixel centre. Screen y grows downward while data y
    // grows upward, hence the (1 - ...) flip, mirroring yToPx.
    const y = Y0 + (1 - (py + 0.5) / heatH) * spanY;
    for (let px = 0; px < heatW; px++) {
      const x = X0 + ((px + 0.5) / heatW) * spanX;
      const lp = Math.max(CLIP, logp(x, y));
      const t = (lp - CLIP) / -CLIP; // 0 at CLIP, 1 at the mode
      // Squaring (gamma) concentrates brightness on the ridge; the faint,
      // subdued blue/violet palette keeps the chain readable on top.
      const tt = t * t;
      heatPx[idx] = Math.floor(20 + tt * 70);
      heatPx[idx + 1] = Math.floor(15 + tt * 35);
      heatPx[idx + 2] = Math.floor(35 + tt * 130);
      heatPx[idx + 3] = 255;
      idx += 4;
    }
  }
  heat = new ImageData(heatPx, heatW, heatH);
}

// ---- chain ops ---------------------------------------------------------
/**
 * Restart the chain at a uniformly random point in the plot box. Most such
 * points lie in the tails, so the chain visibly has to find the ridge.
 * Also restarts the trail with just the new start point, resets the
 * acceptance statistics, and triggers the flash ring.
 */
function teleport() {
  const spanX = X1 - X0;
  const spanY = Y1 - Y0;
  cx = X0 + Math.random() * spanX;
  cy = Y0 + Math.random() * spanY;
  cLogp = logp(cx, cy);

  // Single-entry trail: slot 0 holds the start, the next write goes to 1.
  trailX[0] = cx;
  trailY[0] = cy;
  trailHead = 1;
  trailCount = 1;

  proposals = 0;
  accepts = 0;
  liveAcc = 0.234; // seed the EWMA at the reference optimum
  teleportFlash = 1;
  lastTeleportT = timeSec;
}

/**
 * One Metropolis-Hastings transition using a symmetric Gaussian random-walk
 * proposal of scale stepSize. Because the proposal is symmetric, the
 * Hastings correction cancels and the acceptance probability is
 * min(1, p'/p) = min(1, exp(logp' - logp)).
 *
 * The post-step state is always appended to the trail ring buffer; on a
 * rejection the current state is repeated, exactly as the chain emits it.
 * Also updates proposals/accepts and the liveAcc EWMA.
 */
function mhStep() {
  const candX = cx + stepSize * randn();
  const candY = cy + stepSize * randn();
  const candLogp = logp(candX, candY);
  const logRatio = candLogp - cLogp;

  // Uphill (or level) moves are always taken; a uniform is drawn only for
  // downhill moves.
  let accepted;
  if (logRatio >= 0) {
    accepted = true;
  } else {
    accepted = Math.random() < Math.exp(logRatio);
  }

  proposals++;
  if (accepted) {
    accepts++;
    cx = candX;
    cy = candY;
    cLogp = candLogp;
  }

  trailX[trailHead] = cx;
  trailY[trailHead] = cy;
  trailHead = (trailHead + 1) % TRAIL_MAX;
  if (trailCount < TRAIL_MAX) {
    trailCount += 1;
  }

  // EWMA of the 0/1 acceptance indicator.
  const hit = accepted ? 1 : 0;
  liveAcc = liveAcc + EWMA_ALPHA * (hit - liveAcc);
}

// ---- step-size mapping -------------------------------------------------
// A normalised slider position t in [0, 1] (top..bottom of the canvas)
// maps log-linearly onto step sizes STEP_LO..STEP_HI, so equal slider
// distances multiply σ by equal factors.
const STEP_LO = 0.03;
const STEP_HI = 3.5;
function stepFromT(t) {
  const logLo = Math.log(STEP_LO);
  const logSpan = Math.log(STEP_HI) - logLo;
  return Math.exp(logLo + t * logSpan);
}

// ---- lifecycle ---------------------------------------------------------
/**
 * Lifecycle hook (presumably called once by the host before the first
 * tick): record the canvas size, build the background heatmap, start the
 * chain at a random point, and run 80 burn-in steps so the very first
 * frame already shows a trail.
 * @param {{width: number, height: number}} size
 */
function init({ width, height }) {
  W = width;
  H = height;
  buildHeatmap();
  teleport();
  const BURN_IN_STEPS = 80;
  for (let step = 0; step < BURN_IN_STEPS; step += 1) {
    mhStep();
  }
}

// Reference acceptance rate (the Roberts–Gelman–Gilks random-walk optimum)
// and how close liveAcc must be to it for the readouts to turn green.
const OPTIMAL_ACC = 0.234;
const NEAR_TOL = 0.04;

/**
 * Per-frame hook: sync the canvas size, apply input (step-size scrubbing,
 * idle auto-cycling, click-to-teleport), advance the chain by
 * STEPS_PER_FRAME MH steps, then redraw the whole scene.
 *
 * @param {object} frame
 * @param {CanvasRenderingContext2D} frame.ctx - 2D drawing context.
 * @param {number} frame.dt - Seconds elapsed since the previous frame.
 * @param {number} frame.width - Current canvas width.
 * @param {number} frame.height - Current canvas height.
 * @param {object} frame.input - Provides `mouseY` and optionally
 *   `consumeClicks()` (array of click events, or a plain count).
 */
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) {
    W = width;
    H = height;
    buildHeatmap();
  }
  timeSec += dt;

  // ---- input ----
  updateStepSize(input.mouseY, dt);
  if (readClickCount(input) > 0) {
    teleport();
    autoCycleT = 0;
  }

  // ---- MH steps ----
  for (let i = 0; i < STEPS_PER_FRAME; i++) mhStep();

  // ---- draw ----
  drawBackground(ctx);
  drawTrail(ctx);
  const hx = xToPx(cx);
  const hy = yToPx(cy);
  drawHead(ctx, hx, hy);
  drawTeleportFlash(ctx, hx, hy, dt);
  drawStepSlider(ctx);
  const near = Math.abs(liveAcc - OPTIMAL_ACC) < NEAR_TOL;
  drawHud(ctx, near);
  drawAcceptanceDial(ctx, near);
}

/**
 * Drive stepSize from the pointer, or auto-cycle it when idle.
 * A mouseY counts as fresh input only when it is a finite on-canvas value
 * that differs from the last one seen, so a resting cursor still lets the
 * demo go idle. After AUTO_DELAY idle seconds the step size follows a slow
 * sine sweep (~18 s period) and the chain is re-teleported whenever more
 * than 9 s have passed since the last teleport.
 */
function updateStepSize(my, dt) {
  // H > 0 avoids my / H = 0 / 0 = NaN (and a NaN step) on a zero-height canvas.
  const userTouching =
    Number.isFinite(my) && H > 0 && my >= 0 && my <= H && my !== lastUserMy;
  if (userTouching) {
    autoCycleT = 0;
    lastUserMy = my;
    stepSize = stepFromT(Math.max(0, Math.min(1, my / H)));
    return;
  }
  autoCycleT += dt;
  if (autoCycleT > AUTO_DELAY) {
    const u = autoCycleT - AUTO_DELAY;
    stepSize = stepFromT(0.5 + 0.48 * Math.sin(u * 0.35));
    if (timeSec - lastTeleportT > 9) teleport();
  }
}

// Number of clicks since the last frame. consumeClicks() is expected to
// return an array of click events; a plain number is also accepted as a
// count so a count-returning implementation is not silently ignored.
function readClickCount(input) {
  if (typeof input.consumeClicks !== "function") return 0;
  const clicks = input.consumeClicks();
  if (typeof clicks === "number") return Number.isFinite(clicks) ? clicks : 0;
  return clicks && typeof clicks.length === "number" ? clicks.length : 0;
}

// Cached heatmap (or a flat dark fill before one exists), then a subtle
// dark overlay so the chain pops more.
function drawBackground(ctx) {
  if (heat) {
    ctx.putImageData(heat, 0, 0);
  } else {
    ctx.fillStyle = "#0a0a18";
    ctx.fillRect(0, 0, W, H);
  }
  ctx.fillStyle = "rgba(0,0,0,0.18)";
  ctx.fillRect(0, 0, W, H);
}

// Trail as individual segments, walked oldest -> newest through the ring
// buffer; alpha and line width grow with recency.
function drawTrail(ctx) {
  const n = trailCount;
  if (n < 2) return;
  const oldestIdx = (trailHead - n + TRAIL_MAX) % TRAIL_MAX;
  let prevI = oldestIdx;
  for (let k = 1; k < n; k++) {
    const curI = (oldestIdx + k) % TRAIL_MAX;
    const age = k / n; // near 0 = oldest, approaching 1 = newest
    const alpha = 0.05 + 0.55 * age;
    ctx.strokeStyle = `rgba(120,220,255,${alpha.toFixed(3)})`;
    ctx.lineWidth = 1 + age * 0.8;
    ctx.beginPath();
    ctx.moveTo(xToPx(trailX[prevI]), yToPx(trailY[prevI]));
    ctx.lineTo(xToPx(trailX[curI]), yToPx(trailY[curI]));
    ctx.stroke();
    prevI = curI;
  }
  ctx.lineWidth = 1;
}

// Current state: a soft radial glow plus a solid white dot at (hx, hy).
function drawHead(ctx, hx, hy) {
  const glowR = 14;
  const grad = ctx.createRadialGradient(hx, hy, 0, hx, hy, glowR);
  grad.addColorStop(0, "rgba(255,255,255,0.9)");
  grad.addColorStop(0.4, "rgba(180,230,255,0.4)");
  grad.addColorStop(1, "rgba(120,200,255,0)");
  ctx.fillStyle = grad;
  ctx.beginPath();
  ctx.arc(hx, hy, glowR, 0, Math.PI * 2);
  ctx.fill();
  ctx.fillStyle = "#ffffff";
  ctx.beginPath();
  ctx.arc(hx, hy, 3.5, 0, Math.PI * 2);
  ctx.fill();
}

// Expanding, fading ring after a teleport. It is centred on the current
// head, which may already have moved away from the teleport point; the
// intensity decays by 1.6 per second.
function drawTeleportFlash(ctx, hx, hy, dt) {
  if (teleportFlash > 0) {
    const r = (1 - teleportFlash) * 40 + 6;
    ctx.strokeStyle = `rgba(255,220,120,${teleportFlash.toFixed(3)})`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.arc(hx, hy, r, 0, Math.PI * 2);
    ctx.stroke();
    ctx.lineWidth = 1;
    teleportFlash = Math.max(0, teleportFlash - dt * 1.6);
  }
}

// Step-size slider along the right edge: a faint track, a knob at the
// log-scale position of stepSize (the inverse of stepFromT), and
// "tiny"/"huge" end labels.
function drawStepSlider(ctx) {
  const sliderX = W - 14;
  ctx.strokeStyle = "rgba(255,255,255,0.18)";
  ctx.beginPath();
  ctx.moveTo(sliderX, 8);
  ctx.lineTo(sliderX, H - 8);
  ctx.stroke();
  const logLo = Math.log(STEP_LO);
  const tCur = (Math.log(stepSize) - logLo) / (Math.log(STEP_HI) - logLo);
  const sliderY = 8 + tCur * (H - 16);
  ctx.fillStyle = "rgba(120,220,255,0.95)";
  ctx.beginPath();
  ctx.arc(sliderX, sliderY, 5, 0, Math.PI * 2);
  ctx.fill();
  ctx.fillStyle = "rgba(255,255,255,0.5)";
  ctx.font = "10px monospace";
  ctx.fillText("tiny", sliderX - 28, 14);
  ctx.fillText("huge", sliderX - 30, H - 4);
}

// Top-left HUD (title, live/cumulative acceptance, step size) plus the
// footer hint. The live line turns green when `near` is true.
function drawHud(ctx, near) {
  const cumAcc = proposals > 0 ? accepts / proposals : 0;
  const liveText = `accept (live):  ${(liveAcc * 100).toFixed(1)}%`;
  const cumText = `accept (cum):   ${(cumAcc * 100).toFixed(1)}%   (${accepts}/${proposals})`;
  const stepText = `step σ:         ${stepSize.toFixed(3)}`;

  // The accept/proposal counts grow without bound between teleports, so
  // size the panel from the widest line (8px padding on both sides) and
  // never narrower than 240px.
  ctx.font = "11px monospace";
  const panelW = Math.max(240, Math.ceil(ctx.measureText(cumText).width) + 16);

  ctx.fillStyle = "rgba(0,0,0,0.55)";
  ctx.fillRect(8, 8, panelW, 92);
  ctx.fillStyle = "#e8e8f0";
  ctx.font = "bold 14px monospace";
  ctx.fillText("Metropolis-Hastings", 16, 26);
  ctx.font = "11px monospace";
  ctx.fillStyle = "#aab";
  ctx.fillText(`target: banana (Rosenbrock)`, 16, 42);
  ctx.fillStyle = near ? "#9f9" : "#9cf";
  ctx.fillText(liveText, 16, 60);
  ctx.fillStyle = "#9cf";
  ctx.fillText(cumText, 16, 76);
  ctx.fillStyle = "#fc9";
  ctx.fillText(stepText, 16, 92);

  // Footer hint.
  ctx.fillStyle = "rgba(180,180,200,0.7)";
  ctx.font = "10px monospace";
  ctx.fillText("drag Y to scrub step · click to teleport · optimal ≈ 0.234", 16, H - 8);
}

// Vertical acceptance dial just left of the slider: top = 100%,
// bottom = 0%, a green tick at OPTIMAL_ACC and a dot for the live rate
// (green when `near`).
function drawAcceptanceDial(ctx, near) {
  const dialX = W - 32;
  const dialY0 = 8;
  const dialY1 = H - 8;
  ctx.strokeStyle = "rgba(255,255,255,0.12)";
  ctx.beginPath();
  ctx.moveTo(dialX, dialY0);
  ctx.lineTo(dialX, dialY1);
  ctx.stroke();

  const optY = dialY0 + (1 - OPTIMAL_ACC) * (dialY1 - dialY0);
  ctx.strokeStyle = "rgba(150,255,150,0.6)";
  ctx.beginPath();
  ctx.moveTo(dialX - 4, optY);
  ctx.lineTo(dialX + 4, optY);
  ctx.stroke();

  const accY = dialY0 + (1 - Math.max(0, Math.min(1, liveAcc))) * (dialY1 - dialY0);
  ctx.fillStyle = near ? "rgba(150,255,150,0.95)" : "rgba(180,200,255,0.9)";
  ctx.beginPath();
  ctx.arc(dialX, accY, 3.5, 0, Math.PI * 2);
  ctx.fill();
}

Comments (0)

Log in to comment.