8

Bayesian Linear Regression: Credible Bands

tap to add a point · drag Y for prior · clear/reseed buttons below

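Instead of one best-fit line, Bayesian linear regression carries a whole *distribution* of plausible lines. With likelihood $y_i \mid \mathbf{w} \sim \mathcal{N}(\mathbf{w}^\top \boldsymbol{\phi}(x_i), \sigma^2)$, where $\boldsymbol{\phi}(x) = [x, 1]^\top$ and $\mathbf{w} = [m, b]^\top$, and Gaussian prior $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \tau^2 I)$ on $\mathbf{w}$, the posterior is also Gaussian, $\mathbf{w} \mid \mathbf{y} \sim \mathcal{N}(\boldsymbol{\mu}, \Sigma)$, with $\Sigma = \left(X^\top X / \sigma^2 + I / \tau^2\right)^{-1}$ and $\boldsymbol{\mu} = \Sigma X^\top \mathbf{y} / \sigma^2$ (the rows of $X$ are $\boldsymbol{\phi}(x_i)^\top$). The solid blue line is the posterior mean $\hat{y}(x) = \boldsymbol{\mu}^\top \boldsymbol{\phi}(x)$; the shaded band is the 95% credible interval for the line, $\hat{y}(x) \pm 1.96\, s(x)$, where $s(x)^2 = \boldsymbol{\phi}(x)^\top \Sigma\, \boldsymbol{\phi}(x)$ (uncertainty in the line itself, not including the observation noise $\sigma$); the 30 faint lines behind it are independent draws $\mathbf{w} = \boldsymbol{\mu} + L\mathbf{z}$, $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, I)$, via a Cholesky factor $L$ of $\Sigma$ ($L L^\top = \Sigma$). **Click** anywhere in the plot to add a data point and watch the band tighten where you added it. **Move the cursor vertically** to scrub the prior precision $1/\tau^2$: at the top, a strong prior squeezes the band narrow everywhere but biases the slope and intercept toward zero; at the bottom, the prior is nearly flat and the data speak for themselves — the band fans out at the extremes where you have few observations. The dashed gray line is the data-generating ground truth.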

idle
287 lines · vanilla
view source
// Bayesian linear regression with credible bands.
//
// Model: y = m x + b + eps,  eps ~ N(0, sigma^2)
// Prior: [m, b] ~ N(0, tau^2 I)
// Closed-form posterior over w = [m, b]:
//   Sigma_post = ( X^T X / sigma^2  +  I / tau^2 )^{-1}
//   mu_post    = Sigma_post  X^T y / sigma^2
// Variance of the fitted line at x (excludes observation noise sigma^2):
//   var(m x + b | data) = phi^T Sigma_post phi,   phi=[x,1]
//
// Render: data, posterior-mean line, 95% credible band, faint posterior-sample lines.
// Interaction: click to add a point. mouseY scrubs prior precision 1/tau^2.

// Canvas width/height as last passed to init() / tick().
let W = 0;
let H = 0;
// Observations in data coordinates: [{ x, y }, ...].
let pts = [];
// Becomes true after the first init() has seeded the demo data.
let initialized = false;
// Time (same units as tick's dt) since the user last added a point; drives the
// green pulse ring. Starts at Infinity so no ring is drawn on a randomly seeded
// point before the user has added anything (a click resets it to 0).
let pulseAge = Infinity;
// Number of tick() calls so far.
let frameCount = 0;

// Seeded uniform generator for the posterior-sample lines. Its whole state is
// the module-level `sampleSeed`; assigning to it restarts the stream.
let sampleSeed = 1;
function rand01() {
  // mulberry32: step the 32-bit state by a Weyl constant, then mix its bits.
  sampleSeed = (sampleSeed + 0x6D2B79F5) | 0;
  const s = sampleSeed;
  const mixed = Math.imul(s ^ (s >>> 15), s | 1);
  const out = mixed ^ (mixed + Math.imul(mixed ^ (mixed >>> 7), mixed | 61));
  // Fold to an unsigned 32-bit integer and scale into [0, 1).
  return ((out ^ (out >>> 14)) >>> 0) / 0x100000000;
}
// One standard-normal draw via the Box-Muller transform on two rand01()
// uniforms (consumed in order: u1, then u2). u1 is floored at 1e-9 so
// log(u1) stays finite.
function randn() {
  const u1 = Math.max(1e-9, rand01());
  const u2 = rand01();
  const radius = Math.sqrt(-2 * Math.log(u1));
  const angle = 2 * Math.PI * u2;
  return radius * Math.cos(angle);
}

// Plot window in data coordinates.
const X_MIN = 0;
const X_MAX = 10;
const Y_MIN = 0;
const Y_MAX = 10;

// Observation-noise standard deviation (fixed, not inferred).
const SIGMA = 0.9;

// Ground-truth line: generates the seed data and is drawn as a dashed reference.
const TRUE_M = 0.55;
const TRUE_B = 2.1;

// Prior precision 1/tau^2, stored as log10 so mouseY can scrub it linearly:
//   log10(1/tau^2) = -3.0  ->  tau ≈ 31.6  (weak, nearly flat prior)
//   log10(1/tau^2) = +1.5  ->  tau ≈ 0.18  (strong prior)
// Starts at -1.5 (tau ≈ 5.6), a moderately weak prior.
let logPriorPrec = -1.5;

/**
 * Replace `pts` with synthetic observations from the ground-truth model
 *   y = TRUE_M * x + TRUE_B + SIGMA * z,   z ~ N(0, 1).
 *
 * Uses Math.random (not the seeded sample RNG), so every call, and every page
 * load, draws a different data set. x is uniform on [0.6, 9.4); y is clamped
 * 0.05 inside [Y_MIN, Y_MAX] so every point stays visible in the plot.
 *
 * @param {number} [count=14] - number of points to generate.
 */
function seedPoints(count = 14) {
  pts = [];
  for (let i = 0; i < count; i++) {
    const x = 0.6 + Math.random() * 8.8;
    // Box-Muller transform; u1 is floored away from 0 so log(u1) stays finite.
    const u1 = Math.max(1e-9, Math.random());
    const u2 = Math.random();
    const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
    const y = TRUE_M * x + TRUE_B + SIGMA * z;
    pts.push({ x, y: Math.max(Y_MIN + 0.05, Math.min(Y_MAX - 0.05, y)) });
  }
}

// Host setup hook: record the canvas size and seed demo data on the first call only.
function init({ width, height }) {
  W = width;
  H = height;
  if (initialized) return;
  seedPoints();
  initialized = true;
}

// Pixel rectangle of the plot area for the current canvas size.
// Left margin holds y tick labels, top holds the title lines, bottom holds the
// two readout rows, the clear/reseed buttons and the footer hint.
function plotRect() {
  const narrow = W < 480;
  const pad = narrow
    ? { left: 38, right: 18, top: 58, bottom: 108 }
    : { left: 50, right: 26, top: 64, bottom: 104 };
  return { x0: pad.left, y0: pad.top, x1: W - pad.right, y1: H - pad.bottom };
}

// Map data-space (x, y) to canvas pixels inside plot rect r.
// Data y grows upward but canvas y grows downward, so y is measured from r.y1.
function dataToPx(x, y, r) {
  const fx = (x - X_MIN) / (X_MAX - X_MIN);
  const fy = (y - Y_MIN) / (Y_MAX - Y_MIN);
  return [r.x0 + fx * (r.x1 - r.x0), r.y1 - fy * (r.y1 - r.y0)];
}
// Inverse of dataToPx: canvas pixels inside plot rect r -> data coordinates.
function pxToData(px, py, r) {
  const fx = (px - r.x0) / (r.x1 - r.x0);
  const fy = (r.y1 - py) / (r.y1 - r.y0);
  return [X_MIN + fx * (X_MAX - X_MIN), Y_MIN + fy * (Y_MAX - Y_MIN)];
}

// Closed-form Gaussian posterior over w = [m, b] for the current `pts`,
// observation noise SIGMA and prior precision priorPrec = 1/tau^2.
// With design rows phi_i = [x_i, 1]:
//   A = X^T X / sigma^2 + priorPrec * I   (2x2 posterior precision)
//   c = X^T y / sigma^2
//   Sigma_post = A^{-1},   mu_post = Sigma_post c
// Returns { muM, muB, S11, S12, S22, priorPrec, tau }: the posterior mean, the
// entries of the symmetric Sigma_post, the input precision and
// tau = 1/sqrt(priorPrec). If det(A) <= 1e-12 the inverse is replaced by zeros.
function posterior(priorPrec) {
  const noiseVar = SIGMA * SIGMA;

  // Sufficient statistics of the data.
  let sumXX = 0;
  let sumX = 0;
  let count = 0;
  let sumXY = 0;
  let sumY = 0;
  for (const { x, y } of pts) {
    sumXX += x * x;
    sumX += x;
    count += 1;
    sumXY += x * y;
    sumY += y;
  }

  // Posterior precision A (symmetric) and right-hand side c.
  const a11 = sumXX / noiseVar + priorPrec;
  const a12 = sumX / noiseVar;
  const a22 = count / noiseVar + priorPrec;
  const rhsM = sumXY / noiseVar;
  const rhsB = sumY / noiseVar;

  // Explicit 2x2 inverse: A^{-1} = adj(A) / det(A).
  const det = a11 * a22 - a12 * a12;
  const scale = det > 1e-12 ? 1 / det : 0;
  const S11 = a22 * scale;
  const S12 = -a12 * scale;
  const S22 = a11 * scale;

  return {
    muM: S11 * rhsM + S12 * rhsB,
    muB: S12 * rhsM + S22 * rhsB,
    S11,
    S12,
    S22,
    priorPrec,
    tau: 1 / Math.sqrt(priorPrec),
  };
}

// Posterior standard deviation of the line's value f(x) = m*x + b at x:
//   sqrt(phi^T Sigma_post phi),   phi = [x, 1].
// This is uncertainty in the line itself; it does not include the observation
// noise SIGMA. Small negative round-off in the variance is clamped to 0.
function predStd(post, x) {
  const { S11, S12, S22 } = post;
  const variance = x * x * S11 + 2 * x * S12 + S22;
  return Math.sqrt(Math.max(0, variance));
}

// Draw one (m, b) ~ N(mu_post, Sigma_post) as mu + L z, z ~ N(0, I), where L is
// the lower Cholesky factor of the 2x2 covariance:
//   L = [[l11, 0], [l21, l22]],  l11 = sqrt(S11),  l21 = S12 / l11,
//   l22 = sqrt(S22 - l21^2).
// Negative radicands from round-off are clamped to 0; l21 falls back to 0 when
// l11 is numerically zero. Consumes two randn() draws, z1 then z2.
function sampleW(post) {
  const { muM, muB, S11, S12, S22 } = post;
  const l11 = Math.sqrt(Math.max(0, S11));
  let l21 = 0;
  if (l11 > 1e-9) l21 = S12 / l11;
  const l22 = Math.sqrt(Math.max(0, S22 - l21 * l21));
  const z1 = randn();
  const z2 = randn();
  return { m: muM + l11 * z1, b: muB + l21 * z1 + l22 * z2 };
}

/**
 * Per-frame update and render.
 *
 * Update: r/R reseeds and c/C clears the data; clicks on the clear/reseed
 * buttons do the same; clicks inside the plot add a point; a cursor anywhere
 * over the canvas sets the prior precision from its vertical position.
 * Render: plot frame and grid, 95% credible band for the line, faint
 * posterior-sample lines, dashed ground-truth line, solid posterior-mean line,
 * data points (with a pulse on the newest), readouts, prior gauge, buttons and
 * footer hint.
 *
 * @param {object} frame
 * @param {object} frame.ctx    - 2D canvas drawing context.
 * @param {number} frame.dt     - time since the previous frame; presumably
 *   seconds (pulseAge accumulates it and the pulse lasts 0.6) — confirm with host.
 * @param {number} frame.width  - current canvas width.
 * @param {number} frame.height - current canvas height.
 * @param {object} frame.input  - host input; reads mouseX, mouseY,
 *   justPressed(key) and, if present, consumeClicks().
 */
function tick({ ctx, dt, width, height, input }) {
  W = width;
  H = height;
  frameCount++;

  // log10(1/tau^2) at the top (strong prior) and bottom (weak prior) of the
  // canvas. Shared by the mouse mapping and the gauge marker so they agree.
  const LOG_PREC_STRONG = 1.5;
  const LOG_PREC_WEAK = -3.0;
  // The sample RNG is re-seeded to the same value for this many frames
  // (~0.6 s at 60 fps), so the faint lines follow the posterior smoothly and
  // only re-randomize occasionally.
  const SAMPLE_HOLD_FRAMES = 36;

  if (input.justPressed("r") || input.justPressed("R")) { seedPoints(); }
  if (input.justPressed("c") || input.justPressed("C")) { pts = []; }

  const r = plotRect();
  const mx = input.mouseX;
  const my = input.mouseY;

  // Tap-targets in the band below the plot (30px tall when W < 480, else
  // 26px) so keyboard-less users can still clear / reseed.
  const btnH = W < 480 ? 30 : 26;
  const btnGap = 8;
  const btnY = r.y1 + (W < 480 ? 56 : 60);
  const btnW = Math.min(120, (r.x1 - r.x0 - btnGap) / 2);
  const clearBtn = { x: r.x0, y: btnY, w: btnW, h: btnH };
  const reseedBtn = { x: r.x0 + btnW + btnGap, y: btnY, w: btnW, h: btnH };
  // Inclusive point-in-rectangle test for the button boxes.
  const inBox = (px, py, b) =>
    px >= b.x && px <= b.x + b.w && py >= b.y && py <= b.y + b.h;

  // mouseY scrubs the prior whenever the cursor is over the canvas (not just
  // the plot): top -> LOG_PREC_STRONG, bottom -> LOG_PREC_WEAK. Off-canvas the
  // last value is kept.
  if (mx >= 0 && mx <= W && my >= 0 && my <= H) {
    const t = Math.max(0, Math.min(1, my / Math.max(1, H)));
    logPriorPrec = LOG_PREC_STRONG + (LOG_PREC_WEAK - LOG_PREC_STRONG) * t;
  }
  const priorPrec = Math.pow(10, logPriorPrec);

  // consumeClicks() (when the host provides it) returns an array of
  // {x, y, button} click records. Each click's own coordinates are used, not
  // the live cursor, so taps land where the user touched, even on mobile.
  const clicks = input.consumeClicks ? input.consumeClicks() : [];
  for (const c of clicks) {
    if (inBox(c.x, c.y, clearBtn)) { pts = []; continue; }
    if (inBox(c.x, c.y, reseedBtn)) { seedPoints(); continue; }
    const inPlotClick =
      c.x >= r.x0 && c.x <= r.x1 && c.y >= r.y0 && c.y <= r.y1;
    if (inPlotClick) {
      const [dx, dy] = pxToData(c.x, c.y, r);
      pts.push({ x: dx, y: dy });
      pulseAge = 0;
    }
  }
  pulseAge += dt;

  const post = posterior(priorPrec);

  // Stroke the line y = m*x + b across the full data x-range.
  const strokeDataLine = (m, b) => {
    const [x0p, y0p] = dataToPx(X_MIN, m * X_MIN + b, r);
    const [x1p, y1p] = dataToPx(X_MAX, m * X_MAX + b, r);
    ctx.beginPath();
    ctx.moveTo(x0p, y0p);
    ctx.lineTo(x1p, y1p);
    ctx.stroke();
  };

  // -------- background --------
  ctx.fillStyle = "#0a0a10";
  ctx.fillRect(0, 0, W, H);

  // plot bg
  ctx.fillStyle = "#13131c";
  ctx.fillRect(r.x0, r.y0, r.x1 - r.x0, r.y1 - r.y0);

  // grid
  ctx.strokeStyle = "rgba(255,255,255,0.06)";
  ctx.lineWidth = 1;
  for (let gx = 0; gx <= 10; gx++) {
    const [px] = dataToPx(gx, 0, r);
    ctx.beginPath(); ctx.moveTo(px, r.y0); ctx.lineTo(px, r.y1); ctx.stroke();
  }
  for (let gy = 0; gy <= 10; gy++) {
    const [, py] = dataToPx(0, gy, r);
    ctx.beginPath(); ctx.moveTo(r.x0, py); ctx.lineTo(r.x1, py); ctx.stroke();
  }

  // axis tick labels
  ctx.fillStyle = "#667";
  ctx.font = "10px monospace";
  for (let gx = 0; gx <= 10; gx += 2) {
    const [px] = dataToPx(gx, 0, r);
    ctx.fillText(String(gx), px - 4, r.y1 + 14);
  }
  for (let gy = 0; gy <= 10; gy += 2) {
    const [, py] = dataToPx(0, gy, r);
    ctx.fillText(String(gy), r.x0 - 18, py + 4);
  }

  // -------- credible band (95%) --------
  // Sweep x across the plot, compute y_hat = muM*x + muB and sigma_line(x).
  // Band edges at +-1.96*sigma_line(x) (uncertainty of the line, no noise term).
  const NSWEEP = Math.max(40, Math.min(160, Math.floor((r.x1 - r.x0) / 4)));
  ctx.save();
  ctx.beginPath();
  ctx.rect(r.x0, r.y0, r.x1 - r.x0, r.y1 - r.y0);
  ctx.clip();

  // upper edge then back along lower
  ctx.beginPath();
  for (let i = 0; i <= NSWEEP; i++) {
    const x = X_MIN + (i / NSWEEP) * (X_MAX - X_MIN);
    const yhat = post.muM * x + post.muB;
    const s = predStd(post, x);
    const yhi = yhat + 1.96 * s;
    const [px, py] = dataToPx(x, yhi, r);
    if (i === 0) ctx.moveTo(px, py); else ctx.lineTo(px, py);
  }
  for (let i = NSWEEP; i >= 0; i--) {
    const x = X_MIN + (i / NSWEEP) * (X_MAX - X_MIN);
    const yhat = post.muM * x + post.muB;
    const s = predStd(post, x);
    const ylo = yhat - 1.96 * s;
    const [px, py] = dataToPx(x, ylo, r);
    ctx.lineTo(px, py);
  }
  ctx.closePath();
  ctx.fillStyle = "rgba(120,200,255,0.16)";
  ctx.fill();

  // posterior-sample lines (faint). rand01() advances sampleSeed on every
  // call, so the seed is reset here every frame; it only changes when the
  // frame counter enters a new SAMPLE_HOLD_FRAMES window. Without this reset
  // each frame would continue the stream and draw a different set of lines.
  const sampleEpoch = Math.floor(frameCount / SAMPLE_HOLD_FRAMES);
  sampleSeed = Math.imul(sampleEpoch + 1, 0x9E3779B1) | 1;
  const NSAMP = 30;
  ctx.strokeStyle = "rgba(180,220,255,0.10)";
  ctx.lineWidth = 1;
  for (let s = 0; s < NSAMP; s++) {
    const w = sampleW(post);
    strokeDataLine(w.m, w.b);
  }

  // true line (thin dashed gray) — pedagogical reference
  ctx.setLineDash([4, 4]);
  ctx.strokeStyle = "rgba(255,255,255,0.22)";
  ctx.lineWidth = 1;
  strokeDataLine(TRUE_M, TRUE_B);
  ctx.setLineDash([]);

  // posterior-mean line (solid blue)
  ctx.strokeStyle = "rgba(120,200,255,1)";
  ctx.lineWidth = 2;
  strokeDataLine(post.muM, post.muB);
  ctx.lineWidth = 1;
  ctx.restore();

  // -------- data points --------
  for (let i = 0; i < pts.length; i++) {
    const p = pts[i];
    const [px, py] = dataToPx(p.x, p.y, r);
    ctx.fillStyle = "rgba(230,230,240,0.95)";
    ctx.beginPath();
    ctx.arc(px, py, 3.5, 0, Math.PI * 2);
    ctx.fill();
  }
  // pulse on most recent
  if (pulseAge < 0.6 && pts.length > 0) {
    const p = pts[pts.length - 1];
    const [px, py] = dataToPx(p.x, p.y, r);
    const rad = 5 + pulseAge * 30;
    ctx.strokeStyle = `rgba(120,255,150,${(1 - pulseAge / 0.6).toFixed(3)})`;
    ctx.lineWidth = 2;
    ctx.beginPath();
    ctx.arc(px, py, rad, 0, Math.PI * 2);
    ctx.stroke();
    ctx.lineWidth = 1;
  }

  // -------- header --------
  ctx.fillStyle = "#e8e8f0";
  ctx.font = `bold ${W < 480 ? 13 : 16}px monospace`;
  ctx.fillText("Bayesian Linear Regression", r.x0, W < 480 ? 22 : 28);
  ctx.font = `${W < 480 ? 10 : 11}px monospace`;
  ctx.fillStyle = "#aab";
  if (W >= 380) {
    ctx.fillText("posterior mean (blue) · 95% credible band · 30 posterior samples", r.x0, W < 480 ? 38 : 46);
  } else {
    ctx.fillText("mean · band · samples", r.x0, 38);
  }

  // -------- readouts --------
  ctx.font = `${W < 480 ? 11 : 13}px monospace`;
  const lineY = r.y1 + 30;
  const lineY2 = r.y1 + (W < 480 ? 46 : 50);

  // posterior marginal stds
  const stdM = Math.sqrt(Math.max(0, post.S11));
  const stdB = Math.sqrt(Math.max(0, post.S22));

  if (W < 480) {
    ctx.fillStyle = "#9cf";
    ctx.fillText(`m = ${post.muM.toFixed(3)} ± ${stdM.toFixed(3)}`, r.x0, lineY);
    ctx.fillText(`b = ${post.muB.toFixed(3)} ± ${stdB.toFixed(3)}`, r.x0, lineY2);
    ctx.fillStyle = "#fc6";
    ctx.fillText(`τ = ${post.tau.toFixed(2)}  n = ${pts.length}`, r.x0 + 150, lineY);
  } else {
    ctx.fillStyle = "#9cf";
    ctx.fillText(`m = ${post.muM.toFixed(3)} ± ${stdM.toFixed(3)}`, r.x0, lineY);
    ctx.fillText(`b = ${post.muB.toFixed(3)} ± ${stdB.toFixed(3)}`, r.x0 + 200, lineY);
    ctx.fillStyle = "#fc6";
    ctx.fillText(`τ = ${post.tau.toFixed(2)}  (1/τ² = ${priorPrec.toFixed(3)})`, r.x0 + 400, lineY);
    ctx.fillStyle = "#aab";
    ctx.fillText(`n = ${pts.length}   σ = ${SIGMA.toFixed(2)}   true: m=${TRUE_M}, b=${TRUE_B}`, r.x0, lineY2);
  }

  // -------- prior strength gauge (right edge, top → strong, bot → weak) --------
  // The marker is placed from logPriorPrec itself, so it shows the prior that
  // is actually in use (the scrub maps the whole canvas height, not just the
  // plot) and stays put when the cursor leaves the canvas.
  const gaugeX = r.x1 + 4;
  const gaugeTop = r.y0;
  const gaugeBot = r.y1;
  // (only draw if there's room)
  if (gaugeX + 8 < W) {
    ctx.fillStyle = "rgba(255,255,255,0.06)";
    ctx.fillRect(gaugeX, gaugeTop, 6, gaugeBot - gaugeTop);
    const tPrior = (logPriorPrec - LOG_PREC_STRONG) / (LOG_PREC_WEAK - LOG_PREC_STRONG);
    const ty = gaugeTop + Math.max(0, Math.min(1, tPrior)) * (gaugeBot - gaugeTop);
    ctx.fillStyle = "rgba(255,200,120,0.9)";
    ctx.fillRect(gaugeX - 1, ty - 1, 8, 3);
  }

  // -------- tap-target buttons --------
  function drawBtn(b, label) {
    ctx.fillStyle = "rgba(255,255,255,0.06)";
    ctx.strokeStyle = "rgba(255,255,255,0.18)";
    ctx.lineWidth = 1;
    ctx.fillRect(b.x, b.y, b.w, b.h);
    ctx.strokeRect(b.x + 0.5, b.y + 0.5, b.w - 1, b.h - 1);
    ctx.fillStyle = "#ddd";
    ctx.font = "12px monospace";
    const tw = ctx.measureText(label).width;
    ctx.fillText(label, b.x + (b.w - tw) / 2, b.y + b.h / 2 + 4);
  }
  drawBtn(clearBtn, "clear");
  drawBtn(reseedBtn, "reseed");

  // -------- footer --------
  ctx.fillStyle = "#778";
  ctx.font = `${W < 480 ? 9 : 10}px monospace`;
  const hint = W < 480
    ? "tap plot to add · drag Y for prior"
    : "click to add a point  ·  move cursor vertically to scrub prior strength";
  ctx.fillText(hint, r.x0, H - 12);
}

Comments (0)

Log in to comment.