4
Beta-Binomial Conjugate Update
click left/right half to add tails/heads · move mouse up/down for prior strength
idle
278 lines · vanilla
view source
// Beta-Binomial conjugate update.
// Prior: Beta(alpha0, beta0) with alpha0 = beta0 set by mouseY (0.5..30).
// Each click on the LEFT half adds a tails observation (beta += 1).
// Each click on the RIGHT half adds a heads observation (alpha += 1).
// Posterior: Beta(alpha0 + heads, beta0 + tails) over the bias p in [0,1].
// Renders posterior density curve, faint prior curve, mode, 95% credible
// interval (dashed verticals), a strip of H/T markers, and live counts.
// [R] clears all observations.
// NOTE: `H` below is the canvas HEIGHT, not the heads count (that is `heads`).
let W = 0; // canvas width in px
let H = 0; // canvas height in px
let alpha0 = 2; // prior alpha; re-derived from mouseY every tick
let beta0 = 2; // prior beta; set equal to alpha0 every tick
let heads = 0; // number of heads observations
let tails = 0; // number of tails observations
let history = []; // sequence of 'H' / 'T', oldest first (tick caps it at 240)
let lastEventAge = 1e9; // time since the last recorded click, in dt units (large = no flash)
let lastSide = 0; // side of the most recent click: -1 left (tails), +1 right (heads)
let initialized = false; // true once init() has cleared the counts
const GRID = 401; // density grid cells over [0,1], sampled at cell centres
const grid = new Float64Array(GRID); // posterior density (normalized)
const priorGrid = new Float64Array(GRID); // prior density (normalized)
// Lanczos approximation coefficients (g = 7, 9 terms), hoisted to module scope
// so they are built once instead of on every logGamma call.
const LANCZOS_G = 7;
const LANCZOS_COEFFS = Object.freeze([
  0.99999999999980993, 676.5203681218851, -1259.1392167224028,
  771.32342877765313, -176.61502916214059, 12.507343278686905,
  -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
]);

/**
 * Natural log of |Gamma(z)| via the Lanczos approximation. Arguments below 0.5
 * use Euler's reflection formula Gamma(z) * Gamma(1 - z) = pi / sin(pi * z).
 * @param {number} z - real argument; non-positive integers are poles, where the
 *   result is not meaningful
 * @returns {number} ln|Gamma(z)|
 */
function logGamma(z) {
  if (z < 0.5) {
    // Math.abs: for some z < 0 both sin(pi*z) and Gamma(z) are negative. ln|Gamma|
    // needs the log of the unsigned ratio (the signed one would give NaN).
    return Math.log(Math.PI / Math.abs(Math.sin(Math.PI * z))) - logGamma(1 - z);
  }
  const x0 = z - 1;
  let series = LANCZOS_COEFFS[0];
  for (let i = 1; i < LANCZOS_COEFFS.length; i++) series += LANCZOS_COEFFS[i] / (x0 + i);
  const t = x0 + LANCZOS_G + 0.5;
  return 0.5 * Math.log(2 * Math.PI) + (x0 + 0.5) * Math.log(t) - t + Math.log(series);
}
/**
 * Log of the Beta(a, b) probability density at p:
 *   ln f(p) = (a-1)*ln(p) + (b-1)*ln(1-p) - ln B(a, b),
 *   ln B(a, b) = lnGamma(a) + lnGamma(b) - lnGamma(a + b).
 * @param {number} p - point at which to evaluate
 * @param {number} a - alpha shape parameter
 * @param {number} b - beta shape parameter
 * @returns {number} log-density; -Infinity for p <= 0 or p >= 1
 */
function logBetaPdf(p, a, b) {
  const outside = p <= 0 || p >= 1;
  if (outside) return -Infinity;
  const logNormalizer = logGamma(a) + logGamma(b) - logGamma(a + b);
  const logKernel = (a - 1) * Math.log(p) + (b - 1) * Math.log(1 - p);
  return logKernel - logNormalizer;
}
/**
 * Fills `target` with the Beta(a, b) density sampled at the centres of
 * target.length equal cells covering [0, 1]. The result is rescaled so the
 * midpoint-rule integral sum(target[i] * dx) equals 1 (up to rounding).
 *
 * Only the unnormalised log-kernel (a-1)*ln(p) + (b-1)*ln(1-p) is evaluated.
 * The constant ln B(a, b) cancels in the max-subtraction and renormalisation
 * below, so computing it per cell (three logGamma calls each) would be wasted work.
 *
 * @param {Float64Array|number[]} target - output buffer, overwritten in place;
 *   its length sets the grid resolution
 * @param {number} a - alpha shape parameter (callers pass values >= 0.5)
 * @param {number} b - beta shape parameter (callers pass values >= 0.5)
 */
function fillGrid(target, a, b) {
  const n = target.length;
  if (n === 0) return;
  const dx = 1 / n;
  // Pass 1: log-kernel at each cell centre. Centres are never exactly 0 or 1,
  // so both logs stay finite.
  let maxLog = -Infinity;
  for (let i = 0; i < n; i++) {
    const p = (i + 0.5) / n;
    const lp = (a - 1) * Math.log(p) + (b - 1) * Math.log(1 - p);
    target[i] = lp;
    if (lp > maxLog) maxLog = lp;
  }
  // Pass 2: exponentiate relative to the peak while accumulating the total mass.
  // The largest cell becomes exactly 1, so nothing overflows even for sharply
  // peaked posteriors.
  let mass = 0;
  for (let i = 0; i < n; i++) {
    target[i] = Math.exp(target[i] - maxLog);
    mass += target[i] * dx;
  }
  // Pass 3: normalise to unit area.
  if (mass > 0) {
    for (let i = 0; i < n; i++) target[i] /= mass;
  }
}
/**
 * Equal-tailed 95% credible interval from a gridded density.
 *
 * `densArr[i]` is the density on cell i of densArr.length equal cells over
 * [0, 1], as produced by fillGrid. The density is treated as constant within
 * each cell, so the CDF is piecewise linear. Each bound is found by linear
 * interpolation inside the cell where the CDF crosses 0.025 / 0.975, so the
 * bounds are not quantized to cell centres.
 *
 * @param {ArrayLike<number>} densArr - non-negative densities, ideally integrating to 1
 * @returns {{lo: number, hi: number}} bounds in [0, 1]; a bound whose threshold is
 *   never reached stays at its default (lo = 0, hi = 1)
 */
function credibleInterval95(densArr) {
  const n = densArr.length;
  const dx = 1 / n;
  // p where the CDF reaches q inside cell i, given the CDF at the cell's left edge
  // and the cell's probability mass.
  const crossing = (i, cdfAtStart, mass, q) => {
    const frac = mass > 0 ? (q - cdfAtStart) / mass : 0.5;
    return (i + Math.min(1, Math.max(0, frac))) * dx;
  };
  let cdf = 0;
  let lo = 0;
  let hi = 1;
  let gotLo = false;
  for (let i = 0; i < n; i++) {
    const mass = densArr[i] * dx;
    const next = cdf + mass;
    if (!gotLo && next >= 0.025) {
      lo = crossing(i, cdf, mass, 0.025);
      gotLo = true;
    }
    // Both thresholds may fall in the same cell; that is why this is not an else-if.
    if (next >= 0.975) {
      hi = crossing(i, cdf, mass, 0.975);
      break;
    }
    cdf = next;
  }
  return { lo, hi };
}
/**
 * Host setup hook. Records the canvas size; on the very first call it also
 * starts from zero observations. Later calls only update the size.
 * @param {{width: number, height: number}} size - canvas size in px
 */
function init({ width, height }) {
  W = width;
  H = height;
  if (initialized) return;
  initialized = true;
  history = [];
  tails = 0;
  heads = 0;
}
/** Clears every observation and silences the click flash (bound to the [R] key in tick). */
function reset() {
  lastEventAge = 1e9;
  history = [];
  tails = 0;
  heads = 0;
}
/**
 * Prior strength alpha0 = beta0 from the cursor's vertical position:
 * top of the canvas -> 30 (strong prior), bottom -> 0.5 (weak prior).
 * A missing or non-finite mouseY falls back to the vertical middle.
 * @param {number|null|undefined} mouseY - cursor y in canvas pixels
 * @param {number} height - canvas height in pixels
 * @returns {number} prior strength, in [0.5, 30] for height >= 1
 */
function priorStrengthFromMouse(mouseY, height) {
  const y = Number.isFinite(mouseY) ? Math.max(0, Math.min(height, mouseY)) : height * 0.5;
  const tNorm = y / Math.max(1, height);
  return 0.5 + (1 - tNorm) * (30 - 0.5);
}

/**
 * Mode of Beta(a, b), or null when there is no single finite-density peak to mark.
 * That covers the uniform case a = b = 1 and any case with a < 1 or b < 1
 * (density unbounded at an edge, or U-shaped).
 * @param {number} a - alpha shape parameter
 * @param {number} b - beta shape parameter
 * @returns {number|null} mode in [0, 1], or null
 */
function posteriorMode(a, b) {
  if (a > 1 && b > 1) return (a - 1) / (a + b - 2);
  // One parameter is exactly 1 and the other > 1: monotone density peaking at an edge.
  if (a >= 1 && b >= 1 && (a > 1 || b > 1)) return a > b ? 1 : 0;
  return null;
}

/**
 * Layout of the H/T history strip.
 * @param {number} count - number of recorded observations
 * @param {number} availW - usable strip width in px
 * @returns {{mw: number, start: number}} marker pitch in px (3..14) and the index
 *   of the first observation to draw. When not everything fits, the OLDEST
 *   markers are dropped so the newest clicks stay visible.
 */
function historyLayout(count, availW) {
  const mw = Math.max(3, Math.min(14, availW / Math.max(40, count)));
  const capacity = Math.max(0, Math.floor(availW / mw));
  return { mw, start: Math.max(0, count - capacity) };
}

/**
 * Drains pending clicks from `input` and records one observation per click.
 * A click left of the canvas centre is tails (beta + 1); otherwise it is heads
 * (alpha + 1). Each click's own x is used, so a quick L-then-R pair registers on
 * the sides actually clicked. The current mouseX (else the centre) is only a
 * fallback for clicks without an x. When anything was recorded, the flash timer
 * restarts and the history is trimmed to its 240 most recent entries.
 * @param {object} input - host input; consumeClicks() presumably returns an
 *   array of {x, y, button} (per the host API; not verified here)
 */
function ingestClicks(input) {
  const clicks = (input.consumeClicks && input.consumeClicks()) || [];
  if (clicks.length === 0) return;
  for (const click of clicks) {
    let cx = W * 0.5;
    if (click && click.x != null) cx = click.x;
    else if (input.mouseX != null) cx = input.mouseX;
    if (cx < W * 0.5) {
      tails++;
      history.push("T");
      lastSide = -1;
    } else {
      heads++;
      history.push("H");
      lastSide = 1;
    }
  }
  lastEventAge = 0;
  if (history.length > 240) history = history.slice(history.length - 240);
}

/**
 * Per-frame entry point: reads input, applies the conjugate update and redraws
 * the whole scene.
 * @param {object} frame
 * @param {object} frame.ctx - 2D drawing context (only CanvasRenderingContext2D methods are used)
 * @param {number} frame.dt - time since the previous tick; accumulated into lastEventAge
 * @param {number} frame.width - current canvas width in px
 * @param {number} frame.height - current canvas height in px
 * @param {object} frame.input - host input: mouseX, mouseY, and optionally
 *   justPressed(key) and consumeClicks()
 */
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) {
    W = width;
    H = height;
  }
  // [R] clears all observations.
  if (input.justPressed && (input.justPressed("r") || input.justPressed("R"))) reset();

  // Symmetric prior, re-derived from the cursor height every frame.
  alpha0 = priorStrengthFromMouse(input.mouseY, H);
  beta0 = alpha0;

  ingestClicks(input);
  lastEventAge += dt;

  // Conjugate update: Beta(alpha0, beta0) prior + heads/tails -> Beta(a, b) posterior.
  const a = alpha0 + heads;
  const b = beta0 + tails;
  fillGrid(grid, a, b);
  fillGrid(priorGrid, alpha0, beta0);

  // Background
  ctx.fillStyle = "#0a0a14";
  ctx.fillRect(0, 0, W, H);

  // Layout. Each header line gets its own baseline above the plot so the title,
  // posterior summary, prior-strength line and side labels never overlap.
  const pad = 28;
  const stripH = 56; // bottom strip with H/T markers
  const titleY = 22;
  const summaryY = 40;
  const priorY = 56;
  const sideLabelY = 76;
  const plotX0 = pad;
  const plotX1 = W - pad;
  const plotY0 = 88;
  const plotY1 = H - stripH - 36;
  const pw = plotX1 - plotX0;
  const ph = plotY1 - plotY0;

  // Click-side hint bands: faint tint on each half, plus a flash on the side of
  // the most recent click that fades out over 0.4 (dt units).
  const flash = Math.max(0, 0.18 * (1 - lastEventAge / 0.4));
  const flashL = lastSide === -1 ? flash : 0;
  const flashR = lastSide === 1 ? flash : 0;
  ctx.fillStyle = `rgba(255,120,140,${(0.04 + flashL).toFixed(3)})`;
  ctx.fillRect(0, 0, W * 0.5, H);
  ctx.fillStyle = `rgba(120,200,255,${(0.04 + flashR).toFixed(3)})`;
  ctx.fillRect(W * 0.5, 0, W * 0.5, H);

  // Center divider
  ctx.strokeStyle = "rgba(255,255,255,0.06)";
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(W * 0.5, 0);
  ctx.lineTo(W * 0.5, H);
  ctx.stroke();

  // Plot panel
  ctx.fillStyle = "#13131e";
  ctx.fillRect(plotX0, plotY0, pw, ph);

  // y-scale: tallest point of posterior or prior (at least 1), plus 8% headroom.
  // There is no upper cap, so a very peaked posterior visually flattens the prior.
  let yMax = 1;
  for (let i = 0; i < GRID; i++) {
    if (grid[i] > yMax) yMax = grid[i];
    if (priorGrid[i] > yMax) yMax = priorGrid[i];
  }
  yMax *= 1.08;
  const toX = (p) => plotX0 + p * pw;
  const toY = (d) => plotY1 - (d / yMax) * ph;
  // Adds a density curve sampled at cell centres to the current path. With
  // join = true the first point continues the existing subpath instead of starting one.
  const traceCurve = (dens, join) => {
    for (let i = 0; i < GRID; i++) {
      const x = toX((i + 0.5) / GRID);
      const y = toY(dens[i]);
      if (i === 0 && !join) ctx.moveTo(x, y);
      else ctx.lineTo(x, y);
    }
  };
  // Draws `text` just right of x, or just left of it when it would run past the plot's right edge.
  const labelBeside = (text, x, y) => {
    const w = ctx.measureText(text).width;
    ctx.fillText(text, x + 3 + w > plotX1 ? x - 3 - w : x + 3, y);
  };

  // Dashed grid lines at p = 0, .25, .5, .75, 1
  ctx.strokeStyle = "rgba(255,255,255,0.08)";
  ctx.lineWidth = 1;
  ctx.setLineDash([2, 4]);
  for (let k = 0; k <= 4; k++) {
    const x = toX(k / 4);
    ctx.beginPath();
    ctx.moveTo(x, plotY0);
    ctx.lineTo(x, plotY1);
    ctx.stroke();
  }
  ctx.setLineDash([]);

  // Baseline
  ctx.strokeStyle = "rgba(255,255,255,0.18)";
  ctx.beginPath();
  ctx.moveTo(plotX0, plotY1);
  ctx.lineTo(plotX1, plotY1);
  ctx.stroke();

  // Prior curve (faint, stroke only)
  ctx.strokeStyle = "rgba(200,200,220,0.32)";
  ctx.lineWidth = 1.5;
  ctx.beginPath();
  traceCurve(priorGrid, false);
  ctx.stroke();

  // Posterior curve: filled area down to the baseline, then a bright outline.
  ctx.fillStyle = "rgba(120,255,170,0.22)";
  ctx.beginPath();
  ctx.moveTo(toX(0), plotY1);
  traceCurve(grid, true);
  ctx.lineTo(toX(1), plotY1);
  ctx.closePath();
  ctx.fill();
  ctx.strokeStyle = "rgba(120,255,170,0.95)";
  ctx.lineWidth = 2;
  ctx.beginPath();
  traceCurve(grid, false);
  ctx.stroke();
  ctx.lineWidth = 1;

  // Mode, mean and 95% credible interval
  const mode = posteriorMode(a, b);
  const mean = a / (a + b);
  const { lo, hi } = credibleInterval95(grid);

  // CI verticals (dashed) + label
  ctx.strokeStyle = "rgba(255,200,80,0.85)";
  ctx.setLineDash([5, 4]);
  ctx.lineWidth = 1.5;
  for (const xv of [lo, hi]) {
    const x = toX(xv);
    ctx.beginPath();
    ctx.moveTo(x, plotY0);
    ctx.lineTo(x, plotY1);
    ctx.stroke();
  }
  ctx.setLineDash([]);
  ctx.fillStyle = "rgba(255,200,80,0.95)";
  ctx.font = "11px monospace";
  labelBeside("95% CI", toX(lo), plotY0 + 12);

  // Mode vertical (solid) + label
  if (mode != null) {
    const mx = toX(mode);
    ctx.strokeStyle = "rgba(120,255,170,1)";
    ctx.lineWidth = 1.5;
    ctx.beginPath();
    ctx.moveTo(mx, plotY0);
    ctx.lineTo(mx, plotY1);
    ctx.stroke();
    ctx.fillStyle = "rgba(120,255,170,1)";
    ctx.font = "11px monospace";
    labelBeside(`mode = ${mode.toFixed(3)}`, mx, plotY0 + 26);
  }

  // Axis tick labels (p values)
  ctx.fillStyle = "#99a";
  ctx.font = "10px monospace";
  for (let k = 0; k <= 4; k++) {
    const p = k / 4;
    ctx.fillText(p.toFixed(2), toX(p) - 10, plotY1 + 12);
  }

  // ---- bottom strip: H/T history (newest markers always visible) ----
  const sX0 = pad;
  const sX1 = W - pad;
  const sY = H - stripH - 6;
  const sH = stripH - 14;
  ctx.fillStyle = "#13131e";
  ctx.fillRect(sX0, sY, sX1 - sX0, sH);
  const { mw, start } = historyLayout(history.length, sX1 - sX0 - 8);
  const baseY = sY + sH / 2;
  const r = Math.min(4, mw * 0.4); // heads dot radius; tails square side is 2r
  for (let i = start; i < history.length; i++) {
    const x = sX0 + 4 + (i - start) * mw + mw / 2;
    if (history[i] === "H") {
      ctx.fillStyle = "rgba(120,200,255,0.95)";
      ctx.beginPath();
      ctx.arc(x, baseY - 3, r, 0, Math.PI * 2);
      ctx.fill();
    } else {
      ctx.fillStyle = "rgba(255,140,160,0.95)";
      ctx.fillRect(x - r, baseY + 1, 2 * r, 2 * r);
    }
  }
  ctx.strokeStyle = "rgba(255,255,255,0.15)";
  ctx.strokeRect(sX0, sY, sX1 - sX0, sH);
  ctx.fillStyle = "#99a";
  ctx.font = "10px monospace";
  ctx.fillText(`observations: H=${heads} T=${tails} (n=${heads + tails})`, sX0, sY - 4);

  // ---- header / HUD ----
  ctx.fillStyle = "#e8e8f0";
  ctx.font = "bold 16px monospace";
  ctx.fillText("Beta-Binomial Conjugate Update", pad, titleY);
  ctx.font = "12px monospace";
  ctx.fillStyle = "#aab";
  const aStr = a < 100 ? a.toFixed(2) : a.toFixed(1);
  const bStr = b < 100 ? b.toFixed(2) : b.toFixed(1);
  ctx.fillText(`posterior: Beta(α=${aStr}, β=${bStr}) mean=${mean.toFixed(3)} 95% CI=[${lo.toFixed(3)}, ${hi.toFixed(3)}]`, pad, summaryY);
  ctx.font = "11px monospace";
  ctx.fillStyle = "#9cf";
  ctx.fillText(`prior strength α₀=β₀=${alpha0.toFixed(2)} (move mouseY: top=strong, bottom=weak)`, pad, priorY);

  // Side labels
  ctx.fillStyle = "rgba(255,140,160,0.85)";
  ctx.font = "bold 13px monospace";
  ctx.fillText("← click here for TAILS (β+1)", pad + 4, sideLabelY);
  ctx.fillStyle = "rgba(120,200,255,0.9)";
  const rText = "click here for HEADS (α+1) →";
  // Right-align against the panel edge using the actual rendered width.
  ctx.fillText(rText, W - pad - ctx.measureText(rText).width, sideLabelY);

  // Footer hint
  ctx.fillStyle = "#778";
  ctx.font = "10px monospace";
  ctx.fillText("click L/R to add tails/heads · move mouseY for prior strength · [R] reset", pad, H - 6);
}
Comments (0)
Log in to comment.