17
Metropolis-Hastings: 2D Random Walk
drag Y to scrub step size · click to teleport
idle
248 lines · vanilla
view source
// Metropolis-Hastings on a 2D target.
// Target is a banana-shaped (Rosenbrock-style) density. We render its
// log-density as a faint heatmap, then run an MH chain on it with a
// Gaussian random-walk proposal. mouseY scrubs the proposal step size
// log-linearly from tiny (STEP_LO = 0.03) to huge (STEP_HI = 3.5), so you can
// watch the acceptance rate cross the optimal ~0.234 for high-dim
// random-walk MH.
//
// Click to teleport the chain to a fresh random starting point.
// If untouched for a while we auto-cycle the step size, then teleport.

// Canvas size, taken from the width/height passed to init()/tick().
let W = 0;
let H = 0;
// Plot-space (data-space) extents. Banana lives roughly in
// x in [-3, 3], y in [-1.5, 5]; the box adds a small margin around that.
const X0 = -3.2;
const X1 = 3.2;
const Y0 = -1.8;
const Y1 = 5.2;
// Background heatmap (cached, rebuilt by buildHeatmap() on resize).
let heat = null; // ImageData wrapping heatPx
let heatW = 0;
let heatH = 0;
let heatPx = null; // Uint8ClampedArray of RGBA bytes, heatW * heatH * 4
// Chain state.
let cx = 0; // current sample, x
let cy = 0; // current sample, y
let cLogp = -Infinity; // log-density at (cx, cy)
// Ring buffer for trail samples — fixed typed arrays avoid per-frame
// allocations and O(n) shift().
const TRAIL_MAX = 600;
const trailX = new Float32Array(TRAIL_MAX);
const trailY = new Float32Array(TRAIL_MAX);
let trailHead = 0; // index where next sample will be written
let trailCount = 0; // number of valid samples, up to TRAIL_MAX
// Step size (the proposal's standard deviation) — scrubbed via mouseY.
let stepSize = 0.6; // sensible default until the user or auto-cycle sets it
let lastUserMy = -1; // last mouseY we reacted to; -1 = none yet
let autoCycleT = 0; // seconds since last user input
const AUTO_DELAY = 2.5; // start auto-cycling after this many idle seconds
// Acceptance stats, reset by every teleport().
let proposals = 0;
let accepts = 0;
// EWMA acceptance rate for the "live" readout (responsive to step changes).
let liveAcc = 0.234;
const EWMA_ALPHA = 0.04;
// Per-frame: how many MH steps to attempt. Higher = livelier trail.
const STEPS_PER_FRAME = 12;
// Teleport bookkeeping.
let teleportFlash = 0; // 0..1 flash intensity; tick() fades it to 0
let lastTeleportT = 0; // timeSec at the most recent teleport()
let timeSec = 0; // sum of all dt seen by tick()
// ---- target density ----------------------------------------------------
// Banana-shaped ("Rosenbrock-style") target, unnormalized:
//   logp(x, y) = -[ x^2 / (2*SX^2) + (y - x^2)^2 / (2*SY^2) ]
// i.e. x ~ N(0, SX^2) and, given x, y ~ N(x^2, SY^2). The mode is (0, 0),
// where logp = 0. MH only needs density ratios, so no normalizer is needed.
const SX = 1;
const SY = 0.5;
function logp(x, y) {
  // Vertical distance from the parabola y = x^2 that forms the banana's spine.
  const offRidge = y - x * x;
  const termX = (x * x) / (2 * SX * SX);
  const termY = (offRidge * offRidge) / (2 * SY * SY);
  return -termX - termY;
}
// ---- box-muller --------------------------------------------------------
// One standard-normal draw via the Box–Muller transform. Only the cosine
// branch is used; its sine twin is thrown away.
function randn() {
  // Math.random() may return exactly 0, so floor the radius uniform to keep
  // log() finite. The first uniform sets the radius, the second the angle.
  const radiusU = Math.max(Math.random(), 1e-12);
  const angleU = Math.random();
  const radius = Math.sqrt(-2 * Math.log(radiusU));
  return radius * Math.cos(2 * Math.PI * angleU);
}
// ---- coord transforms --------------------------------------------------
// Data x in [X0, X1] -> canvas x in [0, W], left to right.
function xToPx(x) {
  const frac = (x - X0) / (X1 - X0);
  return frac * W;
}
// Data y in [Y0, Y1] -> canvas y in [H, 0]. It is flipped because data y
// grows upward while canvas y grows downward.
function yToPx(y) {
  const frac = (y - Y0) / (Y1 - Y0);
  return H - frac * H;
}
// ---- background heatmap ------------------------------------------------
// Rasterizes the target's log-density into `heat`: an ImageData of
// floor(W) x floor(H) pixels, at least 1 x 1. It is called from init() and
// from tick() when the canvas size changes; tick() blits it every frame.
function buildHeatmap() {
  heatW = Math.max(1, Math.floor(W));
  heatH = Math.max(1, Math.floor(H));
  heatPx = new Uint8ClampedArray(heatW * heatH * 4);
  // logp peaks at 0 (x = 0, y = 0). Everything below CLIP is flattened to the
  // darkest colour. A fixed floor gives stable contrast at any canvas size.
  const CLIP = -8;
  const spanX = X1 - X0;
  const spanY = Y1 - Y0;
  for (let py = 0; py < heatH; py++) {
    // Sample at pixel centres (+0.5) so the raster agrees with xToPx/yToPx.
    // Edge sampling shifted it half a pixel, breaking the x-symmetry.
    const y = Y0 + (1 - (py + 0.5) / heatH) * spanY;
    const rowBase = py * heatW * 4;
    for (let px = 0; px < heatW; px++) {
      const x = X0 + ((px + 0.5) / heatW) * spanX;
      const lp = Math.max(logp(x, y), CLIP);
      // Normalize: CLIP floor -> 0, mode -> 1.
      const t = (lp - CLIP) / -CLIP;
      // Squared ramp for more contrast along the ridge. The palette is a
      // deliberately dim blue/violet so the chain reads clearly on top.
      const tt = t * t;
      const idx = rowBase + px * 4;
      heatPx[idx] = Math.floor(20 + tt * 70);
      heatPx[idx + 1] = Math.floor(15 + tt * 35);
      heatPx[idx + 2] = Math.floor(35 + tt * 130);
      heatPx[idx + 3] = 255;
    }
  }
  heat = new ImageData(heatPx, heatW, heatH);
}
// ---- chain ops ---------------------------------------------------------
// Restart the chain at a uniformly random point of the plot box and wipe the
// trail and acceptance statistics. Most draws land in the tails, so the chain
// visibly has to climb to the ridge.
function teleport() {
  const spanX = X1 - X0;
  const spanY = Y1 - Y0;
  cx = X0 + spanX * Math.random();
  cy = Y0 + spanY * Math.random();
  cLogp = logp(cx, cy);
  // The trail restarts with the new position as its only sample.
  trailX[0] = cx;
  trailY[0] = cy;
  trailHead = 1;
  trailCount = 1;
  // Fresh stats; the live EWMA restarts at the 0.234 reference value.
  proposals = 0;
  accepts = 0;
  liveAcc = 0.234;
  // Start the flash ring and record when this happened (for auto-teleport).
  teleportFlash = 1;
  lastTeleportT = timeSec;
}
function mhStep() {
  // Symmetric Gaussian random-walk proposal centred on the current state.
  const propX = cx + stepSize * randn();
  const propY = cy + stepSize * randn();
  const propLogp = logp(propX, propY);
  // Metropolis rule, min(1, p'/p) = min(1, exp(lp' - lp)). Uphill moves are
  // always taken without drawing a uniform. Downhill moves are taken with
  // probability exp(logRatio).
  const logRatio = propLogp - cLogp;
  const accepted = logRatio >= 0 ? true : Math.random() < Math.exp(logRatio);
  proposals += 1;
  if (accepted) {
    accepts += 1;
    cx = propX;
    cy = propY;
    cLogp = propLogp;
  }
  // Record the state (repeated on rejection) in the ring buffer.
  trailX[trailHead] = cx;
  trailY[trailHead] = cy;
  trailHead = (trailHead + 1) % TRAIL_MAX;
  if (trailCount < TRAIL_MAX) trailCount += 1;
  // Exponentially weighted average of the 0/1 acceptance indicator.
  liveAcc += EWMA_ALPHA * (Number(accepted) - liveAcc);
}
// ---- step-size mapping -------------------------------------------------
// t in [0, 1] (mouseY top..bottom) maps log-linearly onto [STEP_LO, STEP_HI].
// Each equal slice of the slider therefore multiplies the step by the same
// factor.
const STEP_LO = 0.03;
const STEP_HI = 3.5;
function stepFromT(t) {
  const logLo = Math.log(STEP_LO);
  const logSpan = Math.log(STEP_HI) - logLo;
  return Math.exp(logLo + logSpan * t);
}
// ---- lifecycle ---------------------------------------------------------
// One-time setup:
// - adopt the canvas size and rasterize the background;
// - drop the chain at a random start;
// - pre-run 80 MH steps so the first frame already shows a trail.
function init({ width, height }) {
  W = width;
  H = height;
  buildHeatmap();
  teleport();
  for (let burnIn = 80; burnIn > 0; burnIn--) {
    mhStep();
  }
}
// ---- per-frame helpers -------------------------------------------------
// Acceptance-rate reference line and the band around it that counts as "near".
const ACC_TARGET = 0.234;
const ACC_NEAR = 0.04;
// Horizontal insets, measured from the right edge, of the step slider and
// the acceptance dial. The dial sits far enough left that it no longer
// strikes through the slider's "tiny"/"huge" labels, which start at W - 42
// and W - 44.
const SLIDER_INSET = 14;
const DIAL_INSET = 52;

// True when the live acceptance rate is within ACC_NEAR of ACC_TARGET.
function accIsNearTarget() {
  return Math.abs(liveAcc - ACC_TARGET) < ACC_NEAR;
}

// Applies this frame's pointer input:
// - a moving mouseY scrubs the step size;
// - otherwise the idle timer runs, and after AUTO_DELAY a slow sine sweep
//   takes over and teleports at most every 9 s;
// - any click teleports immediately.
function handleInput(input, dt) {
  const my = input.mouseY;
  // Only a pointer inside the canvas vertically that has moved since the last
  // value we acted on counts as input; a resting pointer is "idle".
  // H > 0 keeps my / H from becoming 0/0 = NaN on a zero-height canvas.
  const userTouching =
    H > 0 && Number.isFinite(my) && my >= 0 && my <= H && my !== lastUserMy;
  if (userTouching) {
    autoCycleT = 0;
    lastUserMy = my;
    stepSize = stepFromT(Math.max(0, Math.min(1, my / H)));
  } else {
    autoCycleT += dt;
    if (autoCycleT > AUTO_DELAY) {
      // Slow sweep over t in [0.02, 0.98], starting from the middle, plus an
      // occasional teleport so the visual stays alive.
      const u = autoCycleT - AUTO_DELAY;
      stepSize = stepFromT(0.5 + 0.48 * Math.sin(u * 0.35));
      if (timeSec - lastTeleportT > 9) teleport();
    }
  }
  // consumeClicks() is optional on the host's input object. Its result is
  // treated as an array-like of clicks and coerced to a count.
  const clicks = input.consumeClicks ? input.consumeClicks() : null;
  const clickCount = clicks && clicks.length ? clicks.length : 0;
  if (clickCount > 0) {
    teleport();
    autoCycleT = 0;
  }
}

// Blits the cached heatmap (or a flat fill if none), then dims it slightly so
// the chain pops.
function drawBackground(ctx) {
  if (heat) {
    // NOTE(review): putImageData ignores the context transform. If the host
    // scales ctx for HiDPI, the heatmap will not line up — confirm host setup.
    ctx.putImageData(heat, 0, 0);
  } else {
    ctx.fillStyle = "#0a0a18";
    ctx.fillRect(0, 0, W, H);
  }
  ctx.fillStyle = "rgba(0,0,0,0.18)";
  ctx.fillRect(0, 0, W, H);
}

// Draws the ring-buffered trail from oldest to newest. Each segment gets its
// own stroke so alpha and width can grow with recency.
function drawTrail(ctx) {
  const n = trailCount;
  if (n < 2) return;
  const oldestIdx = (trailHead - n + TRAIL_MAX) % TRAIL_MAX;
  let prevI = oldestIdx;
  for (let k = 1; k < n; k++) {
    const curI = (oldestIdx + k) % TRAIL_MAX;
    const age = k / n; // near 0 = oldest, approaching 1 = newest
    const alpha = 0.05 + 0.55 * age;
    ctx.strokeStyle = `rgba(120,220,255,${alpha.toFixed(3)})`;
    ctx.lineWidth = 1 + age * 0.8;
    ctx.beginPath();
    ctx.moveTo(xToPx(trailX[prevI]), yToPx(trailY[prevI]));
    ctx.lineTo(xToPx(trailX[curI]), yToPx(trailY[curI]));
    ctx.stroke();
    prevI = curI;
  }
  ctx.lineWidth = 1;
}

// Current sample: a soft radial glow under a small solid dot.
function drawHead(ctx) {
  const hx = xToPx(cx);
  const hy = yToPx(cy);
  const glowR = 14;
  const grad = ctx.createRadialGradient(hx, hy, 0, hx, hy, glowR);
  grad.addColorStop(0, "rgba(255,255,255,0.9)");
  grad.addColorStop(0.4, "rgba(180,230,255,0.4)");
  grad.addColorStop(1, "rgba(120,200,255,0)");
  ctx.fillStyle = grad;
  ctx.beginPath();
  ctx.arc(hx, hy, glowR, 0, Math.PI * 2);
  ctx.fill();
  ctx.fillStyle = "#ffffff";
  ctx.beginPath();
  ctx.arc(hx, hy, 3.5, 0, Math.PI * 2);
  ctx.fill();
}

// After a teleport: a ring around the current head that expands and fades
// (about 0.6 s), then advances the fade by dt.
function drawTeleportFlash(ctx, dt) {
  if (teleportFlash > 0) {
    const hx = xToPx(cx);
    const hy = yToPx(cy);
    const r = (1 - teleportFlash) * 40 + 6;
    ctx.strokeStyle = `rgba(255,220,120,${teleportFlash.toFixed(3)})`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.arc(hx, hy, r, 0, Math.PI * 2);
    ctx.stroke();
    ctx.lineWidth = 1;
    teleportFlash = Math.max(0, teleportFlash - dt * 1.6);
  }
}

// Right-edge step-size slider. The current step is mapped back to t in [0, 1]
// (the inverse of stepFromT) and drawn as a dot on the track. No "optimal
// step" marker is drawn; the 0.234 reference lives on the acceptance dial.
function drawStepSlider(ctx) {
  const sliderX = W - SLIDER_INSET;
  ctx.strokeStyle = "rgba(255,255,255,0.18)";
  ctx.beginPath();
  ctx.moveTo(sliderX, 8);
  ctx.lineTo(sliderX, H - 8);
  ctx.stroke();
  const tCur = (Math.log(stepSize) - Math.log(STEP_LO)) /
    (Math.log(STEP_HI) - Math.log(STEP_LO));
  const sliderY = 8 + tCur * (H - 16);
  ctx.fillStyle = "rgba(120,220,255,0.95)";
  ctx.beginPath();
  ctx.arc(sliderX, sliderY, 5, 0, Math.PI * 2);
  ctx.fill();
  ctx.fillStyle = "rgba(255,255,255,0.5)";
  ctx.font = "10px monospace";
  ctx.fillText("tiny", sliderX - 28, 14);
  ctx.fillText("huge", sliderX - 30, H - 4);
}

// Top-left readout (title, target, live/cumulative acceptance, step) and the
// bottom-left hint line.
function drawHud(ctx) {
  ctx.fillStyle = "rgba(0,0,0,0.55)";
  ctx.fillRect(8, 8, 240, 92);
  ctx.fillStyle = "#e8e8f0";
  ctx.font = "bold 14px monospace";
  ctx.fillText("Metropolis-Hastings", 16, 26);
  ctx.font = "11px monospace";
  ctx.fillStyle = "#aab";
  ctx.fillText(`target: banana (Rosenbrock)`, 16, 42);
  // Live (EWMA) rate is highlighted green when near the 0.234 reference.
  const cumAcc = proposals > 0 ? accepts / proposals : 0;
  ctx.fillStyle = accIsNearTarget() ? "#9f9" : "#9cf";
  ctx.fillText(`accept (live): ${(liveAcc * 100).toFixed(1)}%`, 16, 60);
  ctx.fillStyle = "#9cf";
  ctx.fillText(`accept (cum): ${(cumAcc * 100).toFixed(1)}% (${accepts}/${proposals})`, 16, 76);
  ctx.fillStyle = "#fc9";
  ctx.fillText(`step σ: ${stepSize.toFixed(3)}`, 16, 92);
  ctx.fillStyle = "rgba(180,180,200,0.7)";
  ctx.font = "10px monospace";
  ctx.fillText("drag Y to scrub step · click to teleport · optimal ≈ 0.234", 16, H - 8);
}

// Vertical acceptance dial next to the slider: top = 100 %, bottom = 0 %,
// with a green tick at ACC_TARGET and a dot at the live rate.
function drawAcceptDial(ctx) {
  const dialX = W - DIAL_INSET;
  const dialY0 = 8;
  const dialY1 = H - 8;
  const span = dialY1 - dialY0;
  ctx.strokeStyle = "rgba(255,255,255,0.12)";
  ctx.beginPath();
  ctx.moveTo(dialX, dialY0);
  ctx.lineTo(dialX, dialY1);
  ctx.stroke();
  const optY = dialY0 + (1 - ACC_TARGET) * span;
  ctx.strokeStyle = "rgba(150,255,150,0.6)";
  ctx.beginPath();
  ctx.moveTo(dialX - 4, optY);
  ctx.lineTo(dialX + 4, optY);
  ctx.stroke();
  const accY = dialY0 + (1 - Math.max(0, Math.min(1, liveAcc))) * span;
  ctx.fillStyle = accIsNearTarget() ? "rgba(150,255,150,0.95)" : "rgba(180,200,255,0.9)";
  ctx.beginPath();
  ctx.arc(dialX, accY, 3.5, 0, Math.PI * 2);
  ctx.fill();
}

// Per-frame entry point. It syncs the canvas size, handles input, advances
// the chain STEPS_PER_FRAME steps, then paints background, trail, head, flash,
// slider, HUD and acceptance dial, with later layers over earlier ones.
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) {
    W = width;
    H = height;
    buildHeatmap();
  }
  timeSec += dt;
  handleInput(input, dt);
  for (let i = 0; i < STEPS_PER_FRAME; i++) mhStep();
  drawBackground(ctx);
  drawTrail(ctx);
  drawHead(ctx);
  drawTeleportFlash(ctx, dt);
  drawStepSlider(ctx);
  drawHud(ctx);
  drawAcceptDial(ctx);
}
Comments (0)
Log in to comment.