8
Bayesian Linear Regression: Credible Bands
tap to add a point · drag Y for prior · clear/reseed buttons below
idle
287 lines · vanilla
view source
// Bayesian linear regression with credible bands.
//
// Model: y = m x + b + eps, eps ~ N(0, sigma^2)
// Prior: [m, b] ~ N(0, tau^2 I)
// Closed-form posterior over w = [m, b]:
// Sigma_post = ( X^T X / sigma^2 + I / tau^2 )^{-1}
// mu_post = Sigma_post X^T y / sigma^2
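// Variance of the regression line at x: var(f(x)) = phi^T Sigma_post phi, phi = [x, 1].
// (A new observation y* at x has predictive variance phi^T Sigma_post phi + sigma^2.)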
//
// Render: data, posterior-mean line, 95% credible band, faint posterior-sample lines.
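// Interaction: click/tap inside the plot to add a point; the cursor's vertical
// position anywhere on the canvas scrubs the prior precision 1/tau^2 (top = strong);
// the "clear"/"reseed" buttons (or the c/r keys) reset the data.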
// ---- Mutable simulation state ----
// Canvas width/height as last passed to init() / tick().
let W = 0;
let H = 0;
// Observations in data coordinates: { x, y }.
let pts = [];
// True once init() has generated the initial data set.
let initialized = false;
// Time since the most recently added point; drives the green pulse ring.
// Starts at Infinity so no pulse is drawn on startup, when the last point is
// merely an arbitrary seeded one rather than something the user added.
let pulseAge = Infinity;
// Number of tick() calls so far.
let frameCount = 0;
// Sampling RNG state for the faint posterior-sample lines (mulberry32, advanced
// by rand01()); tick() reseeds it so the lines don't change on every frame.
let sampleSeed = 1;
// Uniform pseudo-random number in [0, 1) from the mulberry32 generator.
// Every call advances the module-level 32-bit state `sampleSeed`, so the
// sequence produced is fully determined by the value last assigned to it.
function rand01() {
// mulberry32: step the state by a fixed odd increment, then scramble it with
// multiply / xor-shift rounds.
sampleSeed = (sampleSeed + 0x6D2B79F5) | 0;
let t = sampleSeed;
t = Math.imul(t ^ (t >>> 15), t | 1);
t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
// `>>> 0` reinterprets as unsigned 32-bit; dividing by 2^32 maps it into [0, 1).
return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
}
// Standard-normal draw via the Box-Muller transform (cosine branch only; the
// paired sine sample is discarded). Consumes two rand01() values, radius first.
function randn() {
  // Floor u1 at 1e-9 so log(u1) stays finite if rand01() returns exactly 0.
  const radius = Math.sqrt(-2 * Math.log(Math.max(1e-9, rand01())));
  const angle = 2 * Math.PI * rand01();
  return radius * Math.cos(angle);
}
// ---- Fixed model / view constants ----
// Visible data window: both axes span 0..10.
const X_MIN = 0;
const X_MAX = 10;
const Y_MIN = 0;
const Y_MAX = 10;
// Observation-noise standard deviation: used to generate the seed data and in
// the likelihood of the posterior.
const SIGMA = 0.9;
// Ground-truth line behind the seed data, also drawn as the dashed reference.
const TRUE_M = 0.55;
const TRUE_B = 2.1;
// Prior precision 1/tau^2 is stored as log10(1/tau^2); the cursor scrubs it between
//   -3.0 (weak:   1/tau^2 = 0.001, tau ≈ 31.6)
//   +1.5 (strong: 1/tau^2 ≈ 31.6,  tau ≈ 0.18)
// Initial value is moderately weak: 1/tau^2 ≈ 0.032, tau ≈ 5.6.
let logPriorPrec = -1.5;
// Replace `pts` with synthetic observations drawn from the true line
// y = TRUE_M * x + TRUE_B plus Gaussian noise of std SIGMA.
//
// @param {number} [count=14] - number of points to generate.
// @param {() => number} [rng=Math.random] - uniform [0, 1) source. The default
//   Math.random is NOT seeded, so every reload / reseed yields a new data set;
//   pass a seeded generator to get reproducible data.
function seedPoints(count = 14, rng = Math.random) {
  pts = [];
  for (let i = 0; i < count; i++) {
    // x spread over [0.6, 9.4) so points stay clear of the plot's side edges.
    const x = 0.6 + rng() * 8.8;
    // Box-Muller (cosine branch) for a standard-normal noise draw.
    const u1 = Math.max(1e-9, rng());
    const u2 = rng();
    const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    const y = TRUE_M * x + TRUE_B + SIGMA * z;
    // Clamp y just inside the visible range so every point is drawable.
    pts.push({ x, y: Math.max(Y_MIN + 0.05, Math.min(Y_MAX - 0.05, y)) });
  }
}
// Host setup hook: record the canvas size and, on the first call only,
// generate the initial synthetic data set.
function init({ width, height }) {
  W = width;
  H = height;
  if (initialized) return;
  seedPoints();
  initialized = true;
}
// Plot area in canvas pixels, inset from the canvas edges. Narrow canvases
// (W < 480) get tighter side/top margins; the bottom margin leaves room for
// the readout lines, the clear/reseed buttons and the footer hint.
function plotRect() {
  const narrow = W < 480;
  const left = narrow ? 38 : 50;
  const right = narrow ? 18 : 26;
  const top = narrow ? 58 : 64;
  const bottom = narrow ? 108 : 104;
  return { x0: left, y0: top, x1: W - right, y1: H - bottom };
}
// Map a data-space point (x, y) to canvas pixels inside plot rect r.
// Data y grows upward while canvas y grows downward, hence measuring from r.y1.
function dataToPx(x, y, r) {
  const fx = (x - X_MIN) / (X_MAX - X_MIN);
  const fy = (y - Y_MIN) / (Y_MAX - Y_MIN);
  return [r.x0 + fx * (r.x1 - r.x0), r.y1 - fy * (r.y1 - r.y0)];
}
// Inverse of dataToPx: canvas pixel (px, py) inside plot rect r -> data coords.
function pxToData(px, py, r) {
  const { x0, y0, x1, y1 } = r;
  const u = (px - x0) / (x1 - x0);
  const v = (y1 - py) / (y1 - y0);
  return [X_MIN + u * (X_MAX - X_MIN), Y_MIN + v * (Y_MAX - Y_MIN)];
}
// Closed-form posterior over w = [m, b] for y = m x + b + eps, eps ~ N(0, sigma^2),
// with prior w ~ N(0, I / priorPrec). Each observation contributes phi_i = [x_i, 1]:
//
//   A = X^T X / sigma^2 + priorPrec * I      (posterior precision, 2x2)
//   c = X^T y / sigma^2
//   Sigma_post = A^{-1},  mu_post = Sigma_post c
//
// @param {number} priorPrec - prior precision 1/tau^2 (tick passes 10^logPriorPrec > 0).
// @param {Array<{x: number, y: number}>} [data=pts] - observations; defaults to the live points.
// @param {number} [sigma=SIGMA] - observation-noise standard deviation.
// @returns {{muM: number, muB: number, S11: number, S12: number, S22: number,
//            priorPrec: number, tau: number}} posterior mean (muM, muB), the entries
//   of the symmetric covariance [[S11, S12], [S12, S22]], the prior precision
//   echoed back, and tau = 1 / sqrt(priorPrec).
function posterior(priorPrec, data = pts, sigma = SIGMA) {
  const s2 = sigma * sigma;
  // Sufficient statistics of the data.
  let sxx = 0;
  let sx = 0;
  let sn = 0;
  let sxy = 0;
  let sy = 0;
  for (const p of data) {
    sxx += p.x * p.x;
    sx += p.x;
    sn += 1;
    sxy += p.x * p.y;
    sy += p.y;
  }
  const A11 = sxx / s2 + priorPrec;
  const A12 = sx / s2;
  const A22 = sn / s2 + priorPrec;
  const c1 = sxy / s2;
  const c2 = sy / s2;
  // With priorPrec > 0, A is positive definite and det > 0. The cutoff is only
  // reached when priorPrec is tiny or non-positive AND the data leave A
  // (near-)singular (e.g. no points, or all points at nearly the same x). The
  // zero covariance/mean then returned means "undefined", not "certain".
  const det = A11 * A22 - A12 * A12;
  const invDet = det > 1e-12 ? 1 / det : 0;
  // Closed-form 2x2 inverse: Sigma_post = (1/det) * [[A22, -A12], [-A12, A11]].
  const S11 = A22 * invDet;
  const S12 = -A12 * invDet;
  const S22 = A11 * invDet;
  // mu_post = Sigma_post c
  const muM = S11 * c1 + S12 * c2;
  const muB = S12 * c1 + S22 * c2;
  return { muM, muB, S11, S12, S22, priorPrec, tau: 1 / Math.sqrt(priorPrec) };
}
// Standard deviation of the regression line f(x) = m x + b under the posterior:
// sqrt(phi^T Sigma_post phi) with phi = [x, 1]. This is the spread of the *mean*
// line (what the 95% credible band uses), not of a new observation. Pass
// noiseVar = sigma^2 to get the predictive std of a new observation y* at x.
//
// @param {{S11: number, S12: number, S22: number}} post - posterior covariance entries.
// @param {number} x - input location.
// @param {number} [noiseVar=0] - observation-noise variance to add (expected >= 0).
// @returns {number} the standard deviation.
function predStd(post, x, noiseVar = 0) {
  const v = x * x * post.S11 + 2 * x * post.S12 + post.S22;
  // Clamp tiny negative values from floating-point round-off before adding noise.
  return Math.sqrt(Math.max(0, v) + noiseVar);
}
// Draw one (m, b) from the posterior N(mu, Sigma) using the lower-triangular
// 2x2 Cholesky factor L (Sigma = L L^T):  w = mu + L z,  z1, z2 from randn().
// Negative radicands from round-off are clamped to 0, and a vanishing L11
// zeroes the off-diagonal term instead of dividing by ~0.
function sampleW(post) {
  const { muM, muB, S11, S12, S22 } = post;
  const l11 = Math.sqrt(Math.max(0, S11));
  let l21 = 0;
  if (l11 > 1e-9) {
    l21 = S12 / l11;
  }
  const l22 = Math.sqrt(Math.max(0, S22 - l21 * l21));
  const [z1, z2] = [randn(), randn()];
  const m = muM + l11 * z1;
  const b = muB + l21 * z1 + l22 * z2;
  return { m, b };
}
// Per-frame update + render.
//
// Handles input (c/r keys, on-canvas clear/reseed buttons, clicks inside the
// plot, vertical cursor position for the prior), recomputes the posterior, and
// draws the plot, credible band, sample/true/mean lines, data, readouts, prior
// gauge, buttons and footer hint.
//
// @param {object} frame - host frame state: ctx (2D canvas context), dt (time
//   since last frame; presumably seconds given the 0.6 pulse timing — TODO
//   confirm), width/height (canvas size) and input (mouseX/mouseY,
//   justPressed(key), optional consumeClicks() returning an array of {x, y, button}).
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) { W = width; H = height; }
  frameCount++;
  // Keyboard shortcuts mirror the on-canvas buttons.
  if (input.justPressed("r") || input.justPressed("R")) { seedPoints(); }
  if (input.justPressed("c") || input.justPressed("C")) { pts = []; }

  const compact = W < 480; // narrow-canvas layout
  const r = plotRect();
  const mx = input.mouseX;
  const my = input.mouseY;

  // Scrub range for log10(1/tau^2): top of canvas = strong prior, bottom = weak.
  const logPrecStrong = 1.5;
  const logPrecWeak = -3.0;

  // Tap-targets just below the plot, side by side from its left edge
  // (30px tall on narrow canvases, 26px otherwise).
  const btnH = compact ? 30 : 26;
  const btnGap = 8;
  const btnY = r.y1 + (compact ? 56 : 60);
  const btnW = Math.min(120, (r.x1 - r.x0 - btnGap) / 2);
  const clearBtn = { x: r.x0, y: btnY, w: btnW, h: btnH };
  const reseedBtn = { x: r.x0 + btnW + btnGap, y: btnY, w: btnW, h: btnH };
  const inBtn = (b, p) =>
    p.x >= b.x && p.x <= b.x + b.w && p.y >= b.y && p.y <= b.y + b.h;

  // The cursor anywhere over the canvas (not just the plot) scrubs the prior.
  if (mx >= 0 && mx <= W && my >= 0 && my <= H) {
    const t = Math.max(0, Math.min(1, my / Math.max(1, H)));
    logPriorPrec = logPrecStrong + (logPrecWeak - logPrecStrong) * t;
  }
  const priorPrec = Math.pow(10, logPriorPrec);

  // consumeClicks() returns an ARRAY of {x, y, button}; use each click's own
  // coordinates (not the live cursor) so taps land where the user touched.
  const clicks = input.consumeClicks ? input.consumeClicks() : [];
  for (const c of clicks) {
    if (inBtn(clearBtn, c)) { pts = []; continue; }
    if (inBtn(reseedBtn, c)) { seedPoints(); continue; }
    const inPlot = c.x >= r.x0 && c.x <= r.x1 && c.y >= r.y0 && c.y <= r.y1;
    if (inPlot) {
      const [dx, dy] = pxToData(c.x, c.y, r);
      pts.push({ x: dx, y: dy });
      pulseAge = 0;
    }
  }
  pulseAge += dt;

  const post = posterior(priorPrec);

  // -------- background --------
  ctx.fillStyle = "#0a0a10";
  ctx.fillRect(0, 0, W, H);
  ctx.fillStyle = "#13131c";
  ctx.fillRect(r.x0, r.y0, r.x1 - r.x0, r.y1 - r.y0);

  // grid: one line per data unit
  ctx.strokeStyle = "rgba(255,255,255,0.06)";
  ctx.lineWidth = 1;
  for (let gx = 0; gx <= 10; gx++) {
    const [px] = dataToPx(gx, 0, r);
    ctx.beginPath(); ctx.moveTo(px, r.y0); ctx.lineTo(px, r.y1); ctx.stroke();
  }
  for (let gy = 0; gy <= 10; gy++) {
    const [, py] = dataToPx(0, gy, r);
    ctx.beginPath(); ctx.moveTo(r.x0, py); ctx.lineTo(r.x1, py); ctx.stroke();
  }

  // axis tick labels every 2 units
  ctx.fillStyle = "#667";
  ctx.font = "10px monospace";
  for (let gx = 0; gx <= 10; gx += 2) {
    const [px] = dataToPx(gx, 0, r);
    ctx.fillText(String(gx), px - 4, r.y1 + 14);
  }
  for (let gy = 0; gy <= 10; gy += 2) {
    const [, py] = dataToPx(0, gy, r);
    ctx.fillText(String(gy), r.x0 - 18, py + 4);
  }

  // Band and lines are clipped to the plot area.
  ctx.save();
  ctx.beginPath();
  ctx.rect(r.x0, r.y0, r.x1 - r.x0, r.y1 - r.y0);
  ctx.clip();

  // Stroke the line y = m x + b across the full data x-range.
  const strokeLine = (m, b) => {
    const [x0p, y0p] = dataToPx(X_MIN, m * X_MIN + b, r);
    const [x1p, y1p] = dataToPx(X_MAX, m * X_MAX + b, r);
    ctx.beginPath();
    ctx.moveTo(x0p, y0p);
    ctx.lineTo(x1p, y1p);
    ctx.stroke();
  };

  // -------- credible band (95%) for the mean line --------
  // Sample the band edges yhat(x) +- 1.96 * predStd(x) once, then trace the
  // upper edge left-to-right and the lower edge back right-to-left.
  const nSweep = Math.max(40, Math.min(160, Math.floor((r.x1 - r.x0) / 4)));
  const band = [];
  for (let i = 0; i <= nSweep; i++) {
    const x = X_MIN + (i / nSweep) * (X_MAX - X_MIN);
    const yhat = post.muM * x + post.muB;
    const half = 1.96 * predStd(post, x);
    band.push({ x, hi: yhat + half, lo: yhat - half });
  }
  ctx.beginPath();
  band.forEach(({ x, hi }, i) => {
    const [px, py] = dataToPx(x, hi, r);
    if (i === 0) ctx.moveTo(px, py); else ctx.lineTo(px, py);
  });
  for (let i = band.length - 1; i >= 0; i--) {
    const [px, py] = dataToPx(band[i].x, band[i].lo, r);
    ctx.lineTo(px, py);
  }
  ctx.closePath();
  ctx.fillStyle = "rgba(120,200,255,0.16)";
  ctx.fill();

  // -------- posterior-sample lines (faint) --------
  // rand01() advances sampleSeed on every call, so rewind it each frame to a
  // seed that only changes every ~0.6 s. Without the rewind every frame would
  // continue the stream, draw fresh samples, and the lines would flicker.
  sampleSeed = Math.imul(Math.floor(Date.now() / 600), 0x9E3779B1);
  const nSamples = 30;
  ctx.strokeStyle = "rgba(180,220,255,0.10)";
  ctx.lineWidth = 1;
  for (let s = 0; s < nSamples; s++) {
    const w = sampleW(post);
    strokeLine(w.m, w.b);
  }

  // true line (thin dashed gray) — pedagogical reference
  ctx.setLineDash([4, 4]);
  ctx.strokeStyle = "rgba(255,255,255,0.22)";
  ctx.lineWidth = 1;
  strokeLine(TRUE_M, TRUE_B);
  ctx.setLineDash([]);

  // posterior-mean line (solid blue)
  ctx.strokeStyle = "rgba(120,200,255,1)";
  ctx.lineWidth = 2;
  strokeLine(post.muM, post.muB);
  ctx.lineWidth = 1;
  ctx.restore();

  // -------- data points --------
  ctx.fillStyle = "rgba(230,230,240,0.95)";
  for (const p of pts) {
    const [px, py] = dataToPx(p.x, p.y, r);
    ctx.beginPath();
    ctx.arc(px, py, 3.5, 0, Math.PI * 2);
    ctx.fill();
  }

  // pulse ring on the most recently added point
  if (pulseAge < 0.6 && pts.length > 0) {
    const p = pts[pts.length - 1];
    const [px, py] = dataToPx(p.x, p.y, r);
    const rad = 5 + pulseAge * 30;
    ctx.strokeStyle = `rgba(120,255,150,${(1 - pulseAge / 0.6).toFixed(3)})`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.arc(px, py, rad, 0, Math.PI * 2);
    ctx.stroke();
    ctx.lineWidth = 1;
  }

  // -------- header --------
  ctx.fillStyle = "#e8e8f0";
  ctx.font = `bold ${compact ? 13 : 16}px monospace`;
  ctx.fillText("Bayesian Linear Regression", r.x0, compact ? 22 : 28);
  ctx.font = `${compact ? 10 : 11}px monospace`;
  ctx.fillStyle = "#aab";
  // Use the long subtitle only when it actually fits between r.x0 and the right edge.
  const subtitle = `posterior mean (blue) · 95% credible band · ${nSamples} posterior samples`;
  const subtitleY = compact ? 38 : 46;
  if (ctx.measureText(subtitle).width <= W - r.x0 - 4) {
    ctx.fillText(subtitle, r.x0, subtitleY);
  } else {
    ctx.fillText("mean · band · samples", r.x0, subtitleY);
  }

  // -------- readouts --------
  ctx.font = `${compact ? 11 : 13}px monospace`;
  const lineY = r.y1 + 30;
  const lineY2 = r.y1 + (compact ? 46 : 50);
  // posterior marginal stds
  const stdM = Math.sqrt(Math.max(0, post.S11));
  const stdB = Math.sqrt(Math.max(0, post.S22));
  const mText = `m = ${post.muM.toFixed(3)} ± ${stdM.toFixed(3)}`;
  const bText = `b = ${post.muB.toFixed(3)} ± ${stdB.toFixed(3)}`;
  const tauText = `τ = ${post.tau.toFixed(2)} (1/τ² = ${priorPrec.toFixed(3)})`;
  // The single-row layout puts τ at r.x0 + 400. Narrow canvases always use the
  // stacked layout; wider ones fall back to it when that τ column would run
  // past the right edge (e.g. widths just above 480px).
  const wideRow = !compact && r.x0 + 400 + ctx.measureText(tauText).width <= W;
  ctx.fillStyle = "#9cf";
  if (wideRow) {
    ctx.fillText(mText, r.x0, lineY);
    ctx.fillText(bText, r.x0 + 200, lineY);
    ctx.fillStyle = "#fc6";
    ctx.fillText(tauText, r.x0 + 400, lineY);
    ctx.fillStyle = "#aab";
    ctx.fillText(`n = ${pts.length} σ = ${SIGMA.toFixed(2)} true: m=${TRUE_M}, b=${TRUE_B}`, r.x0, lineY2);
  } else {
    ctx.fillText(mText, r.x0, lineY);
    ctx.fillText(bText, r.x0, lineY2);
    ctx.fillStyle = "#fc6";
    // Keep the τ column clear of the slope readout even for wide numbers.
    const tauX = r.x0 + Math.max(150, ctx.measureText(mText).width + 16);
    ctx.fillText(`τ = ${post.tau.toFixed(2)} n = ${pts.length}`, tauX, lineY);
  }

  // -------- prior strength gauge (right edge; top = strong, bottom = weak) --------
  const gaugeX = r.x1 + 4;
  if (gaugeX + 8 < W) {
    ctx.fillStyle = "rgba(255,255,255,0.06)";
    ctx.fillRect(gaugeX, r.y0, 6, r.y1 - r.y0);
    // Place the marker from the actual prior value, not the raw cursor: the
    // prior maps the whole canvas height (the gauge only the plot height) and
    // keeps its value after the cursor leaves the canvas.
    const strength = (logPriorPrec - logPrecWeak) / (logPrecStrong - logPrecWeak);
    const markerY = r.y1 - strength * (r.y1 - r.y0);
    ctx.fillStyle = "rgba(255,200,120,0.9)";
    ctx.fillRect(gaugeX - 1, markerY - 1, 8, 3);
  }

  // -------- tap-target buttons --------
  const drawBtn = (b, label) => {
    ctx.fillStyle = "rgba(255,255,255,0.06)";
    ctx.strokeStyle = "rgba(255,255,255,0.18)";
    ctx.lineWidth = 1;
    ctx.fillRect(b.x, b.y, b.w, b.h);
    ctx.strokeRect(b.x + 0.5, b.y + 0.5, b.w - 1, b.h - 1);
    ctx.fillStyle = "#ddd";
    ctx.font = "12px monospace";
    const tw = ctx.measureText(label).width;
    ctx.fillText(label, b.x + (b.w - tw) / 2, b.y + b.h / 2 + 4);
  };
  drawBtn(clearBtn, "clear");
  drawBtn(reseedBtn, "reseed");

  // -------- footer --------
  ctx.fillStyle = "#778";
  ctx.font = `${compact ? 9 : 10}px monospace`;
  const hint = compact
    ? "tap plot to add · drag Y for prior"
    : "click to add a point · move cursor vertically to scrub prior strength";
  ctx.fillText(hint, r.x0, H - 12);
}
Comments (0)
Log in to comment.