4

Beta-Binomial Conjugate Update

click L/R to add tails/heads · drag Y for prior strength

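Bayesian inference for a coin's bias $p \in [0, 1]$, in real time. The prior is $\mathrm{Beta}(\alpha_0, \beta_0)$ with $\alpha_0 = \beta_0$ controlled by mouseY (top = strong, $\alpha_0 = \beta_0 = 30$; bottom = weak, $\alpha_0 = \beta_0 = 0.5$). Each click on the LEFT half of the canvas records a tails ($\beta \leftarrow \beta + 1$); each click on the RIGHT half records a heads ($\alpha \leftarrow \alpha + 1$). Because the Beta family is conjugate to the Bernoulli likelihood, the posterior after observing $H$ heads and $T$ tails is simply $\mathrm{Beta}(\alpha_0 + H,\ \beta_0 + T)$, with density $f(p) = \dfrac{p^{\alpha - 1}(1 - p)^{\beta - 1}}{B(\alpha, \beta)}$ where $\alpha = \alpha_0 + H$ and $\beta = \beta_0 + T$. The green curve is the posterior, the faint grey curve is the current prior (re-rendered live as you scrub $\alpha_0 = \beta_0$), and the dashed orange verticals mark the 95% credible interval $[p_{0.025},\ p_{0.975}]$. The solid green vertical marks the mode $\dfrac{\alpha - 1}{\alpha + \beta - 2}$ (defined for $\alpha, \beta > 1$). The strip along the bottom shows your H/T sequence. Try: stack 10 heads in a row with a weak prior — the posterior shoots toward $p = 1$; redo it with a strong symmetric prior and the posterior barely moves. That contrast is what "prior strength" actually means.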

idle
278 lines · vanilla
view source
// Beta-Binomial conjugate update.
// Prior: Beta(alpha0, beta0) with alpha0 = beta0 set by mouseY (0.5..30).
// Each click on the LEFT half adds a tails observation (beta += 1).
// Each click on the RIGHT half adds a heads observation (alpha += 1).
// Posterior: Beta(alpha0 + H, beta0 + T) over the bias p in [0,1].
// Renders posterior density curve, faint prior curve, mode, 95% credible
// interval (dashed verticals), a strip of H/T markers, and live counts.

// --- Density buffers -------------------------------------------------------
// Number of cells used to sample a density over [0,1].
const GRID = 401;
// Posterior density samples, one per cell (filled every tick).
const grid = new Float64Array(GRID);
// Prior density samples, one per cell (filled every tick).
const priorGrid = new Float64Array(GRID);

// --- Canvas size (copied from the host in init/tick) ------------------------
let W = 0;
let H = 0;

// --- Prior parameters -------------------------------------------------------
// Both are overwritten from mouseY on every tick; 2 is only the start value.
let alpha0 = 2;
let beta0 = 2;

// --- Observations -----------------------------------------------------------
let heads = 0;
let tails = 0;
// Click sequence as 'H' / 'T' entries, oldest first.
let history = [];

// --- Click-flash bookkeeping ------------------------------------------------
// Accumulated dt since the last click; the huge start value means "no recent click".
let lastEventAge = 1e9;
// Side of the most recent click: -1 = left, +1 = right, 0 = none yet.
let lastSide = 0;

// Set by init() so the observation state is only cleared on the first call.
let initialized = false;

// Lanczos approximation parameters (g = 7, nine coefficients). Kept at module
// level so the table is built once instead of on every logGamma call.
const LANCZOS_G = 7;
const LANCZOS_COEFFS = [
  0.99999999999980993, 676.5203681218851, -1259.1392167224028,
  771.32342877765313, -176.61502916214059, 12.507343278686905,
  -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7
];

/**
 * Natural logarithm of |Gamma(z)| for real z, via the Lanczos approximation.
 *
 * @param {number} z - Real argument.
 * @returns {number} ln|Gamma(z)|; +Infinity at the poles z = 0, -1, -2, ...;
 *   NaN if z is NaN.
 */
function logGamma(z) {
  // Gamma has poles at the non-positive integers. Report them explicitly:
  // sin(pi * z) is only approximately zero there in floating point.
  if (z <= 0 && Number.isInteger(z)) return Infinity;
  if (z < 0.5) {
    // Reflection formula: Gamma(z) * Gamma(1 - z) = pi / sin(pi * z).
    // sin(pi * z) is negative on alternating negative intervals (where
    // Gamma(z) itself is negative); taking |sin| yields ln|Gamma(z)| instead
    // of NaN. For 0 < z < 0.5 the sine is positive, so nothing changes there.
    return Math.log(Math.PI / Math.abs(Math.sin(Math.PI * z))) - logGamma(1 - z);
  }
  z -= 1;
  let x = LANCZOS_COEFFS[0];
  for (let i = 1; i < LANCZOS_COEFFS.length; i++) x += LANCZOS_COEFFS[i] / (z + i);
  const t = z + LANCZOS_G + 0.5;
  // ln( sqrt(2*pi) * t^(z + 0.5) * e^(-t) * x )
  return 0.5 * Math.log(2 * Math.PI) + (z + 0.5) * Math.log(t) - t + Math.log(x);
}

/**
 * Log-density of the Beta(a, b) distribution at p.
 *
 * @param {number} p - Point at which to evaluate the density.
 * @param {number} a - First shape parameter.
 * @param {number} b - Second shape parameter.
 * @returns {number} ln f(p; a, b) for 0 < p < 1. Returns -Infinity for
 *   p <= 0 or p >= 1, endpoints included, whatever the shape parameters.
 */
function logBetaPdf(p, a, b) {
  const outsideOpenInterval = p <= 0 || p >= 1;
  if (outsideOpenInterval) {
    return -Infinity;
  }

  // ln B(a, b) = lnGamma(a) + lnGamma(b) - lnGamma(a + b)
  const logNormalizer = logGamma(a) + logGamma(b) - logGamma(a + b);

  // ln of the kernel p^(a-1) * (1-p)^(b-1), one factor at a time.
  const logLeftFactor = (a - 1) * Math.log(p);
  const logRightFactor = (b - 1) * Math.log(1 - p);

  return logLeftFactor + logRightFactor - logNormalizer;
}

/**
 * Fill `target` with Beta(a, b) density values sampled at the centres of
 * `target.length` equal-width cells covering [0, 1], scaled so that the
 * midpoint-rule integral (sum of value * cellWidth) equals 1.
 *
 * The analytic normaliser 1 / B(a, b) is deliberately not evaluated: any
 * constant factor cancels in the peak-relative exponentiation and the numeric
 * renormalisation below, so only the kernel p^(a-1) * (1-p)^(b-1) is needed.
 *
 * @param {Float64Array|number[]} target - Output buffer, overwritten in place.
 * @param {number} a - First shape parameter.
 * @param {number} b - Second shape parameter.
 * @returns {void}
 */
function fillGrid(target, a, b) {
  const n = target.length;
  if (n === 0) return;
  const dx = 1 / n;

  // Pass 1: log of the unnormalised density at each cell centre. Centres lie
  // strictly inside (0, 1), so both logarithms are finite.
  let maxLog = -Infinity;
  for (let i = 0; i < n; i++) {
    const p = (i + 0.5) / n;
    const lp = (a - 1) * Math.log(p) + (b - 1) * Math.log(1 - p);
    target[i] = lp;
    if (lp > maxLog) maxLog = lp;
  }

  // Pass 2: exponentiate relative to the peak (largest cell becomes exactly 1,
  // so nothing overflows) while accumulating the midpoint-rule integral.
  let mass = 0;
  for (let i = 0; i < n; i++) {
    const v = Math.exp(target[i] - maxLog);
    target[i] = v;
    mass += v * dx;
  }

  // Pass 3: rescale so the grid integrates to 1.
  if (mass > 0) {
    for (let i = 0; i < n; i++) target[i] /= mass;
  }
}

/**
 * Equal-tailed 95% interval of a density sampled on equal-width cells of [0, 1].
 *
 * The density is treated as constant within each cell, so the CDF is linear
 * across a cell. The 2.5% and 97.5% quantiles are located by linear
 * interpolation inside the cell where the running CDF crosses the target,
 * rather than snapping to that cell's centre.
 *
 * @param {Float64Array|number[]} densArr - Density per cell; expected to
 *   integrate to 1 under the midpoint rule (sum of value / length).
 * @returns {{lo: number, hi: number}} Interval bounds. `lo` stays 0 and/or
 *   `hi` stays 1 when the running CDF never reaches the respective target
 *   (e.g. empty or NaN-filled input).
 */
function credibleInterval95(densArr) {
  const n = densArr.length;
  const dx = 1 / n;
  const LOWER = 0.025;
  const UPPER = 0.975;

  let lo = 0;
  let hi = 1;
  let gotLo = false;
  let cdf = 0;

  for (let i = 0; i < n; i++) {
    const prev = cdf; // CDF at the left edge of cell i
    cdf += densArr[i] * dx; // CDF at the right edge of cell i

    // At the first crossing prev < target <= cdf, so cdf - prev > 0 and the
    // interpolation fraction lies in (0, 1].
    if (!gotLo && cdf >= LOWER) {
      lo = (i + (LOWER - prev) / (cdf - prev)) / n;
      gotLo = true;
    }
    if (cdf >= UPPER) {
      hi = (i + (UPPER - prev) / (cdf - prev)) / n;
      break;
    }
  }
  return { lo, hi };
}

/**
 * Record the canvas size and, on the first call only, clear the observations.
 *
 * Later calls update W/H but leave heads, tails and history untouched.
 *
 * @param {{width: number, height: number}} size - Canvas dimensions.
 * @returns {void}
 */
function init({ width, height }) {
  W = width;
  H = height;

  // Already initialised once: keep whatever has been observed so far.
  if (initialized) return;

  initialized = true;
  history = [];
  heads = 0;
  tails = 0;
}

/**
 * Discard all observations and push the click-flash timer back to "long ago".
 *
 * The prior parameters and lastSide are not touched here.
 *
 * @returns {void}
 */
function reset() {
  history = [];
  heads = 0;
  tails = 0;
  lastEventAge = 1e9;
}

/**
 * Per-frame update and render.
 *
 * Reads input (optional reset key, mouseY for prior strength, queued clicks
 * for observations), recomputes the prior and posterior grids, then redraws
 * the whole canvas: side bands, plot panel, prior/posterior curves, 95%
 * interval, mode line, axis ticks, history strip and HUD text.
 *
 * @param {object} frame
 * @param {object} frame.ctx - 2D drawing context (fillRect, stroke, arc,
 *   setLineDash, fillText, ... are called on it).
 * @param {number} frame.dt - Time since the previous tick; accumulated into
 *   lastEventAge. The click flash fades over 0.4 of these units (presumably
 *   seconds — confirm against the host).
 * @param {number} frame.width - Current canvas width.
 * @param {number} frame.height - Current canvas height.
 * @param {object} frame.input - Host input object. Reads `mouseY`, `mouseX`,
 *   and, when present, `justPressed(key)` and `consumeClicks()`.
 * @returns {void}
 */
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) { W = width; H = height; }

  // Reset key (justPressed is optional on the input object).
  if (input.justPressed && (input.justPressed("r") || input.justPressed("R"))) reset();

  // Prior strength from mouseY: top = strong (30), bottom = weak (0.5).
  // Clamp to the canvas; if mouseY is missing fall back to the middle.
  const my = (input.mouseY != null) ? Math.max(0, Math.min(H, input.mouseY)) : H * 0.5;
  const tNorm = my / Math.max(1, H);
  const priorStrength = 0.5 + (1 - tNorm) * (30 - 0.5);
  alpha0 = priorStrength;
  beta0 = priorStrength;

  // Click handling: left half -> tails, right half -> heads.
  // consumeClicks() returns an array of {x,y,button}; the per-click x is used
  // so a quick L then R doesn't both register on whichever side the cursor
  // currently happens to sit.
  const MAX_HISTORY = 240;
  const clicks = (input.consumeClicks ? input.consumeClicks() : []);
  if (clicks.length > 0) {
    for (let i = 0; i < clicks.length; i++) {
      const cx = (clicks[i] && clicks[i].x != null)
        ? clicks[i].x
        : (input.mouseX != null ? input.mouseX : W * 0.5);
      if (cx < W * 0.5) {
        tails++;
        history.push("T");
        lastSide = -1;
      } else {
        heads++;
        history.push("H");
        lastSide = 1;
      }
    }
    lastEventAge = 0;
    // Only the most recent entries are kept for the strip; the heads/tails
    // counters keep the full totals.
    if (history.length > MAX_HISTORY) history = history.slice(history.length - MAX_HISTORY);
  }
  lastEventAge += dt;

  // Posterior parameters
  const a = alpha0 + heads;
  const b = beta0 + tails;

  // Compute grids
  fillGrid(grid, a, b);
  fillGrid(priorGrid, alpha0, beta0);

  // Background
  ctx.fillStyle = "#0a0a14";
  ctx.fillRect(0, 0, W, H);

  // Layout
  const pad = 28;
  const stripH = 56; // bottom strip with H/T markers
  const plotX0 = pad;
  const plotX1 = W - pad;
  const plotY0 = 56;
  const plotY1 = H - stripH - 36;
  const pw = plotX1 - plotX0;
  const ph = plotY1 - plotY0;

  // Click-side hint bands: very faint, plus a flash on the side of the most
  // recent click that fades linearly to zero.
  const flash = Math.max(0, 0.18 * (1 - lastEventAge / 0.4));
  const flashL = (lastSide === -1) ? flash : 0;
  ctx.fillStyle = `rgba(255,120,140,${(0.04 + flashL).toFixed(3)})`;
  ctx.fillRect(0, 0, W * 0.5, H);
  const flashR = (lastSide === 1) ? flash : 0;
  ctx.fillStyle = `rgba(120,200,255,${(0.04 + flashR).toFixed(3)})`;
  ctx.fillRect(W * 0.5, 0, W * 0.5, H);

  // Center divider
  ctx.strokeStyle = "rgba(255,255,255,0.06)";
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(W * 0.5, 0);
  ctx.lineTo(W * 0.5, H);
  ctx.stroke();

  // Plot panel
  ctx.fillStyle = "#13131e";
  ctx.fillRect(plotX0, plotY0, pw, ph);

  // y-scale: tallest value of either curve, floored at 1 so a flat
  // distribution does not fill the panel, plus 8% headroom.
  let yMax = 0;
  for (let i = 0; i < GRID; i++) {
    if (grid[i] > yMax) yMax = grid[i];
    if (priorGrid[i] > yMax) yMax = priorGrid[i];
  }
  if (yMax < 1) yMax = 1;
  yMax *= 1.08;

  const toX = (p) => plotX0 + p * pw;
  const toY = (d) => plotY1 - (d / yMax) * ph;

  // Grid lines at p = 0, .25, .5, .75, 1
  ctx.strokeStyle = "rgba(255,255,255,0.08)";
  ctx.lineWidth = 1;
  ctx.setLineDash([2, 4]);
  for (let k = 0; k <= 4; k++) {
    const x = toX(k / 4);
    ctx.beginPath();
    ctx.moveTo(x, plotY0);
    ctx.lineTo(x, plotY1);
    ctx.stroke();
  }
  ctx.setLineDash([]);
  // baseline
  ctx.strokeStyle = "rgba(255,255,255,0.18)";
  ctx.beginPath();
  ctx.moveTo(plotX0, plotY1);
  ctx.lineTo(plotX1, plotY1);
  ctx.stroke();

  // Prior curve (faint, stroke only)
  ctx.strokeStyle = "rgba(200,200,220,0.32)";
  ctx.lineWidth = 1.5;
  ctx.beginPath();
  for (let i = 0; i < GRID; i++) {
    const x = toX((i + 0.5) / GRID);
    const y = toY(priorGrid[i]);
    if (i === 0) ctx.moveTo(x, y);
    else ctx.lineTo(x, y);
  }
  ctx.stroke();

  // Posterior curve: filled area first, then the outline on top.
  ctx.fillStyle = "rgba(120,255,170,0.22)";
  ctx.beginPath();
  ctx.moveTo(toX(0), plotY1);
  for (let i = 0; i < GRID; i++) {
    ctx.lineTo(toX((i + 0.5) / GRID), toY(grid[i]));
  }
  ctx.lineTo(toX(1), plotY1);
  ctx.closePath();
  ctx.fill();
  ctx.strokeStyle = "rgba(120,255,170,0.95)";
  ctx.lineWidth = 2;
  ctx.beginPath();
  for (let i = 0; i < GRID; i++) {
    const x = toX((i + 0.5) / GRID);
    const y = toY(grid[i]);
    if (i === 0) ctx.moveTo(x, y);
    else ctx.lineTo(x, y);
  }
  ctx.stroke();
  ctx.lineWidth = 1;

  // Mode and 95% CI
  let mode = null;
  if (a > 1 && b > 1) {
    mode = (a - 1) / (a + b - 2);
  } else if (a >= 1 && b >= 1 && (a > 1 || b > 1)) {
    mode = a > b ? 1 : 0; // edge mode: one parameter is exactly 1
  } // else flat, bimodal or unbounded — skip the mode line
  const mean = a / (a + b);
  const { lo, hi } = credibleInterval95(grid);

  // CI verticals
  ctx.strokeStyle = "rgba(255,200,80,0.85)";
  ctx.setLineDash([5, 4]);
  ctx.lineWidth = 1.5;
  for (const xv of [lo, hi]) {
    const x = toX(xv);
    ctx.beginPath();
    ctx.moveTo(x, plotY0);
    ctx.lineTo(x, plotY1);
    ctx.stroke();
  }
  ctx.setLineDash([]);
  // CI label
  ctx.fillStyle = "rgba(255,200,80,0.95)";
  ctx.font = "11px monospace";
  ctx.fillText(`95% CI`, toX(lo) + 3, plotY0 + 12);

  // Mode vertical (solid)
  if (mode != null) {
    ctx.strokeStyle = "rgba(120,255,170,1)";
    ctx.lineWidth = 1.5;
    ctx.beginPath();
    ctx.moveTo(toX(mode), plotY0);
    ctx.lineTo(toX(mode), plotY1);
    ctx.stroke();
    ctx.fillStyle = "rgba(120,255,170,1)";
    ctx.font = "11px monospace";
    ctx.fillText(`mode = ${mode.toFixed(3)}`, toX(mode) + 3, plotY0 + 26);
  }

  // Axis ticks (p values)
  ctx.fillStyle = "#99a";
  ctx.font = "10px monospace";
  for (let k = 0; k <= 4; k++) {
    const p = k / 4;
    ctx.fillText(p.toFixed(2), toX(p) - 10, plotY1 + 12);
  }

  // ---- bottom strip: H/T history ----
  const sX0 = pad;
  const sX1 = W - pad;
  const sY = H - stripH - 6;
  const sH = stripH - 14;
  ctx.fillStyle = "#13131e";
  ctx.fillRect(sX0, sY, sX1 - sX0, sH);
  // Per-marker width: shrink as the history grows, but never below 3px.
  const count = history.length;
  const availW = sX1 - sX0 - 8;
  const mw = Math.max(3, Math.min(14, availW / Math.max(40, Math.max(1, count))));
  // Once mw hits its 3px floor the whole history may no longer fit. Show the
  // NEWEST markers in that case: compute how many slots fit and start drawing
  // from the matching offset, so fresh clicks are always visible. When
  // everything fits, `first` is 0 and the layout is unchanged.
  const slots = Math.max(0, Math.floor((availW + 2) / mw));
  const first = Math.max(0, count - slots);
  const baseY = sY + sH / 2;
  for (let i = first; i < count; i++) {
    const x = sX0 + 4 + (i - first) * mw + mw / 2;
    if (x > sX1 - 2) break; // safety net; the slot count should prevent this
    if (history[i] === "H") {
      ctx.fillStyle = "rgba(120,200,255,0.95)";
      ctx.beginPath();
      ctx.arc(x, baseY - 3, Math.min(4, mw * 0.4), 0, Math.PI * 2);
      ctx.fill();
    } else {
      ctx.fillStyle = "rgba(255,140,160,0.95)";
      ctx.fillRect(x - Math.min(4, mw * 0.4), baseY + 1, Math.min(8, mw * 0.8), Math.min(8, mw * 0.8));
    }
  }
  ctx.strokeStyle = "rgba(255,255,255,0.15)";
  ctx.strokeRect(sX0, sY, sX1 - sX0, sH);
  ctx.fillStyle = "#99a";
  ctx.font = "10px monospace";
  ctx.fillText(`observations: H=${heads}  T=${tails}  (n=${heads + tails})`, sX0, sY - 4);

  // ---- header / HUD ----
  ctx.fillStyle = "#e8e8f0";
  ctx.font = "bold 16px monospace";
  ctx.fillText("Beta-Binomial Conjugate Update", pad, 24);

  ctx.font = "12px monospace";
  ctx.fillStyle = "#aab";
  const aStr = a < 100 ? a.toFixed(2) : a.toFixed(1);
  const bStr = b < 100 ? b.toFixed(2) : b.toFixed(1);
  ctx.fillText(`posterior: Beta(α=${aStr}, β=${bStr})   mean=${mean.toFixed(3)}   95% CI=[${lo.toFixed(3)}, ${hi.toFixed(3)}]`, pad, 44);

  ctx.font = "11px monospace";
  ctx.fillStyle = "#9cf";
  const p0Str = alpha0.toFixed(2);
  ctx.fillText(`prior strength α₀=β₀=${p0Str}  (drag mouseY: top=strong, bottom=weak)`, pad, plotY0 - 6);

  // Side labels
  ctx.fillStyle = "rgba(255,140,160,0.85)";
  ctx.font = "bold 13px monospace";
  ctx.fillText("← click here for TAILS  (β+1)", pad + 4, plotY0 - 22);
  ctx.fillStyle = "rgba(120,200,255,0.9)";
  const rText = "click here for HEADS (α+1) →";
  // Right-align using the real rendered width (the font is already set
  // above). The non-ASCII glyphs may come from a fallback font, so a fixed
  // px-per-char guess can be off; it is kept only as a fallback for contexts
  // without measureText.
  const measured = (typeof ctx.measureText === "function") ? ctx.measureText(rText).width : NaN;
  const rw = Number.isFinite(measured) ? measured : rText.length * 7.8;
  ctx.fillText(rText, W - pad - rw, plotY0 - 22);

  // Footer hint
  ctx.fillStyle = "#778";
  ctx.font = "10px monospace";
  ctx.fillText("click L/R to add tails/heads · move mouseY for prior strength · [R] reset", pad, H - 6);
}

Comments (0)

Log in to comment.