16

Wasserstein Barycenter: Three Shapes

Given three probability measures $\mu_1, \mu_2, \mu_3$ on the plane (here a square, a disk, and an L-tetromino, each as a uniform density on a $128 \times 128$ grid), their *Wasserstein barycenter* with weights $\lambda = (\lambda_1, \lambda_2, \lambda_3)$, $\lambda_i \ge 0$, $\sum_i \lambda_i = 1$, is the measure

$$\nu_\lambda = \operatorname*{arg\,min}_{\nu} \sum_{i=1}^{3} \lambda_i \, W_2^2(\nu, \mu_i),$$

where $W_2$ is the quadratic optimal-transport distance. Unlike a pixelwise convex combination — which would simply superimpose the shapes — the barycenter respects mass transport: each particle of $\nu_\lambda$ is a $\lambda$-weighted average of three coupled particles, one from each shape. The result is a genuine *morph*. We approximate it cheaply via displacement interpolation: sample $N = 900$ Lloyd-relaxed points per shape, couple them by a common canonical order (angle about the centroid, then radius), and average positions before splatting back to a density field rendered in viridis. The weight vector $\lambda$ traces the boundary of the 2-simplex over a 12-second loop; a small triangle widget shows its current location. Inspired by Gabriel Peyré's optimal-transport visualizations.

idle
374 lines · vanilla
view source
// Wasserstein Barycenter morph between three binary shapes.
// Cheap displacement-interpolation approximation:
//   - Sample each shape as N points (Lloyd-relaxed for even coverage).
//   - Sort all three point sets in a common 1D order (angular sector around
//     each set's centroid, then radius) so points correspond across shapes —
//     gives a coarse but mass-preserving (one-to-one) coupling.
//   - For weights (w1,w2,w3), barycenter point_k = w1*A_k + w2*B_k + w3*C_k.
//   - Splat all N points with a small Gaussian kernel onto a 128x128 grid.
//   - Display via viridis.
//
// We precompute K=40 barycenters along the triangle-edge path once in init.
// Per-frame: crossfade two neighbouring precomputed density fields, so no
// sampling, coupling or splatting work happens in tick.

// ---------- tunables ----------
const GRID = 128;          // density grid is GRID x GRID cells
const N_POINTS = 900;      // particles sampled per shape
const K_FRAMES = 40;       // barycenters precomputed around the loop
const LOOP_SECONDS = 12.0; // seconds per trip around the simplex boundary
const SPLAT_SIGMA = 1.15;  // Gaussian splat std-dev, in grid cells
const SPLAT_RADIUS = 3;    // kernel truncation radius, in grid cells

// Gaussian splat stamp of (2*SPLAT_RADIUS + 1)^2 weights (7x7); built in init.
let SPLAT_STAMP = null;

// ---------- mutable module state (populated in init / tick) ----------
let W;            // canvas width in pixels
let H;            // canvas height in pixels
let frames;       // Array of K_FRAMES Float32Array(GRID*GRID) density fields
let frameMax;     // Float32Array(K_FRAMES): per-frame max, for normalisation
let imgBuf;       // canvas-sized ImageData used for blits
let imgData;      // imgBuf.data (RGBA Uint8ClampedArray)
let triCorners;   // set to null in init; not read elsewhere in this file
let cornerThumbs; // Array of 3 Float32Array(GRID*GRID): corner-shape densities
let cornerMax;    // Float32Array(3): per-thumbnail max, for normalisation

let timeAcc = 0;  // animation clock: sum of tick dt values

// ---------- viridis (11 stops, piecewise-linear) ----------
const VIRIDIS = [
  [68, 1, 84], [72, 35, 116], [64, 67, 135], [52, 94, 141],
  [41, 120, 142], [32, 144, 140], [34, 167, 132], [68, 190, 112],
  [121, 209, 81], [189, 222, 38], [253, 231, 36],
];

/**
 * Map a normalised value to an RGB triple by linear interpolation between the
 * VIRIDIS stops.
 *
 * Values <= 0 map to the first stop, values >= 1 to the last. NaN also maps to
 * the first stop rather than indexing VIRIDIS[NaN] and throwing.
 *
 * @param {number} t - Value expected in [0, 1].
 * @returns {number[]} A new [r, g, b] array with (possibly fractional)
 *   components in 0..255. It is always a fresh array, so callers cannot
 *   mutate the palette.
 */
function viridis(t) {
  // `!(t > 0)` is true for t <= 0 and for NaN.
  if (!(t > 0)) return VIRIDIS[0].slice();
  if (t >= 1) return VIRIDIS[VIRIDIS.length - 1].slice();
  const f = t * (VIRIDIS.length - 1);
  const i = Math.floor(f);
  const u = f - i;
  const a = VIRIDIS[i];
  const b = VIRIDIS[i + 1];
  return [
    a[0] + (b[0] - a[0]) * u,
    a[1] + (b[1] - a[1]) * u,
    a[2] + (b[2] - a[2]) * u,
  ];
}

// ---------- binary shapes on a 128x128 grid ----------
// Each shape builder returns a GRID*GRID Uint8Array in row-major order
// (index y*GRID + x) holding 1 inside the shape and 0 outside.

/** Square spanning [0.28*GRID, 0.72*GRID] on both axes, bounds inclusive. */
function shapeSquare() {
  const lo = GRID * 0.28;
  const hi = GRID * 0.72;
  const mask = new Uint8Array(GRID * GRID);
  for (let idx = 0; idx < mask.length; idx++) {
    const col = idx % GRID;
    const row = (idx - col) / GRID;
    if (col < lo || col > hi || row < lo || row > hi) continue;
    mask[idx] = 1;
  }
  return mask;
}
/** Disk of radius 0.26*GRID centred at (GRID/2, GRID/2), boundary included. */
function shapeDisk() {
  const centre = GRID * 0.5;
  const radius = GRID * 0.26;
  const rSq = radius * radius;
  const mask = new Uint8Array(GRID * GRID);
  for (let y = 0; y < GRID; y++) {
    const dySq = (y - centre) * (y - centre);
    const rowStart = y * GRID;
    for (let x = 0; x < GRID; x++) {
      const dx = x - centre;
      mask[rowStart + x] = dx * dx + dySq <= rSq ? 1 : 0;
    }
  }
  return mask;
}
/**
 * L-tetromino-like shape: a vertical bar plus a foot extending right from its
 * bottom. Both parts are inclusive axis-aligned rectangles given as fractions
 * of GRID.
 */
function shapeL() {
  // [x0, x1, y0, y1] for the bar, then the foot.
  const rects = [
    [0.34, 0.50, 0.22, 0.78],
    [0.34, 0.74, 0.62, 0.78],
  ].map((fracs) => fracs.map((f) => GRID * f));
  const mask = new Uint8Array(GRID * GRID);
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      const covered = rects.some(
        ([x0, x1, y0, y1]) => x >= x0 && x <= x1 && y >= y0 && y <= y1,
      );
      if (covered) mask[y * GRID + x] = 1;
    }
  }
  return mask;
}

// ---------- sample N points inside a binary mask ----------
/**
 * Deterministically place `n` points inside `mask`.
 *
 * Inside pixels are listed in raster order. Point i takes the pixel at
 * quantile (i + 0.5) / n of that list, an even stride with no randomness.
 * It is then offset within its pixel by the (base-2, base-3) Halton pair of
 * index i + 1, so it lies in [x, x + 1) x [y, y + 1).
 *
 * @param {Uint8Array} mask - GRID*GRID row-major mask; non-zero means inside.
 * @param {number} n - Number of points to return.
 * @returns {Float32Array} Interleaved coordinates [x0, y0, x1, y1, ...] of length 2n.
 * @throws {Error} If n > 0 but the mask has no inside pixels.
 */
function samplePoints(mask, n) {
  const inside = [];
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      if (mask[y * GRID + x]) inside.push(x, y);
    }
  }
  const pixCount = inside.length / 2;
  if (n > 0 && pixCount === 0) {
    // Without this guard the loop would read inside[-2] and emit NaN coordinates.
    throw new Error('samplePoints: mask has no inside pixels');
  }

  // Radical inverse of k in `base`: the k-th Halton coordinate, in [0, 1).
  const radicalInverse = (k, base) => {
    let rest = k;
    let value = 0;
    let f = 1 / base;
    while (rest > 0) {
      value += f * (rest % base);
      rest = Math.floor(rest / base);
      f /= base;
    }
    return value;
  };

  const pts = new Float32Array(n * 2);
  for (let i = 0; i < n; i++) {
    const t = (i + 0.5) / n;
    const j = Math.min(pixCount - 1, Math.floor(t * pixCount));
    pts[i * 2] = inside[j * 2] + radicalInverse(i + 1, 2);
    pts[i * 2 + 1] = inside[j * 2 + 1] + radicalInverse(i + 1, 3);
  }
  return pts;
}

// ---------- Lloyd relaxation for even coverage ----------
/**
 * Discrete Lloyd relaxation of `pts` over the inside pixels of `mask`, in place.
 *
 * Each iteration does two steps:
 *   1. Assign every inside pixel, sampled at its centre (x + 0.5, y + 0.5),
 *      to its nearest point.
 *   2. Move each point to the centroid of the pixels assigned to it, i.e. its
 *      Voronoi cell clipped to the mask.
 * This spreads points toward an even, centroidal layout; nudging points toward
 * the mean of a coarse bucket instead would contract them into clumps.
 * A point keeps its position if it captured no pixels, or if its new centroid
 * lands on a pixel outside the mask (possible for non-convex masks).
 *
 * Cost is O(iters * insidePixels * points), using a brute-force nearest
 * search. It is intended for one-off use during setup.
 *
 * @param {Float32Array} pts - Interleaved [x, y] coordinates; updated in place.
 * @param {Uint8Array} mask - GRID*GRID row-major mask; non-zero means inside.
 * @param {number} iters - Number of relaxation iterations.
 */
function lloydRelax(pts, mask, iters) {
  const n = pts.length >> 1;
  if (n === 0 || !(iters > 0)) return;

  const centres = [];
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      if (mask[y * GRID + x]) centres.push(x + 0.5, y + 0.5);
    }
  }

  const sumX = new Float64Array(n);
  const sumY = new Float64Array(n);
  const cnt = new Int32Array(n);
  for (let it = 0; it < iters; it++) {
    sumX.fill(0);
    sumY.fill(0);
    cnt.fill(0);

    // Assignment step: each pixel contributes to its nearest point's cell.
    for (let c = 0; c < centres.length; c += 2) {
      const qx = centres[c];
      const qy = centres[c + 1];
      let best = 0;
      let bestD = Infinity;
      for (let i = 0; i < n; i++) {
        const dx = pts[i * 2] - qx;
        const dy = pts[i * 2 + 1] - qy;
        const d = dx * dx + dy * dy;
        if (d < bestD) {
          bestD = d;
          best = i;
        }
      }
      sumX[best] += qx;
      sumY[best] += qy;
      cnt[best]++;
    }

    // Update step: move every point at once (Jacobi-style) to its cell centroid.
    for (let i = 0; i < n; i++) {
      if (cnt[i] === 0) continue;
      const nx = sumX[i] / cnt[i];
      const ny = sumY[i] / cnt[i];
      const ix = Math.max(0, Math.min(GRID - 1, Math.floor(nx)));
      const iy = Math.max(0, Math.min(GRID - 1, Math.floor(ny)));
      if (mask[iy * GRID + ix]) {
        pts[i * 2] = nx;
        pts[i * 2 + 1] = ny;
      }
    }
  }
}

// ---------- Sort point set in a common canonical order ----------
/**
 * Reorder points so that index k refers to "the same place" in every shape.
 *
 * Points are ranked by polar angle about their own centroid and split into 64
 * consecutive groups of (as near as possible) equal size. Each group is then
 * ordered by distance from the centroid.
 *
 * The groups are equal-count rather than equal-angle. Equal-angle sectors hold
 * different numbers of points in different shapes, which shifts every later
 * index. With equal-count groups, index k lands in the same angular quantile
 * and the same radial rank within it for every shape. Pairing A_k, B_k, C_k is
 * therefore a one-to-one (mass-preserving) coupling that roughly approximates
 * 2-D optimal transport for star-shaped sets.
 *
 * @param {Float32Array} pts - Interleaved [x0, y0, x1, y1, ...].
 * @returns {Float32Array} New array holding the same points in canonical order.
 */
function canonicalSort(pts) {
  const SECTORS = 64;
  const n = pts.length >> 1;
  const sorted = new Float32Array(pts.length);
  if (n === 0) return sorted;

  let cx = 0;
  let cy = 0;
  for (let i = 0; i < n; i++) {
    cx += pts[i * 2];
    cy += pts[i * 2 + 1];
  }
  cx /= n;
  cy /= n;

  const angle = new Float64Array(n);
  const radius = new Float64Array(n);
  for (let i = 0; i < n; i++) {
    const dx = pts[i * 2] - cx;
    const dy = pts[i * 2 + 1] - cy;
    angle[i] = Math.atan2(dy, dx);
    radius[i] = Math.sqrt(dx * dx + dy * dy);
  }

  // Array.prototype.sort is stable, so exact ties keep input order (deterministic).
  const byAngle = Array.from({ length: n }, (_, i) => i).sort((a, b) => angle[a] - angle[b]);

  let out = 0;
  for (let g = 0; g < SECTORS; g++) {
    const lo = Math.floor((g * n) / SECTORS);
    const hi = Math.floor(((g + 1) * n) / SECTORS);
    const group = byAngle.slice(lo, hi).sort((a, b) => radius[a] - radius[b]);
    for (const idx of group) {
      sorted[out * 2] = pts[idx * 2];
      sorted[out * 2 + 1] = pts[idx * 2 + 1];
      out++;
    }
  }
  return sorted;
}

// ---------- Build a Gaussian splat stamp ----------
/**
 * Tabulate exp(-(dx^2 + dy^2) / (2 * SPLAT_SIGMA^2)) for dx, dy in
 * [-SPLAT_RADIUS, SPLAT_RADIUS], row-major by dy then dx. The kernel is not
 * normalised: the centre weight is 1.
 * @returns {Float32Array} (2*SPLAT_RADIUS + 1)^2 kernel weights.
 */
function buildStamp() {
  const side = 2 * SPLAT_RADIUS + 1;
  const twoSigmaSq = 2 * (SPLAT_SIGMA * SPLAT_SIGMA);
  const kernel = new Float32Array(side * side);
  for (let cell = 0; cell < kernel.length; cell++) {
    const dx = (cell % side) - SPLAT_RADIUS;
    const dy = Math.floor(cell / side) - SPLAT_RADIUS;
    kernel[cell] = Math.exp(-(dx * dx + dy * dy) / twoSigmaSq);
  }
  return kernel;
}

// ---------- Splat point set to density field ----------
/**
 * Overwrite `density` with the sum of SPLAT_STAMP kernels, one per point,
 * each centred on the point's nearest grid cell (Math.round). Kernel cells
 * that fall outside the GRID x GRID field are dropped.
 * Requires SPLAT_STAMP to have been built beforehand (see buildStamp).
 * @param {Float32Array} pts - Interleaved [x, y] coordinates in grid units.
 * @param {Float32Array} density - GRID*GRID row-major output field.
 */
function splat(pts, density) {
  density.fill(0);
  const rad = SPLAT_RADIUS;
  const side = 2 * rad + 1;
  const count = pts.length / 2;
  for (let p = 0; p < count; p++) {
    const gx = Math.round(pts[p * 2]);
    const gy = Math.round(pts[p * 2 + 1]);
    // Restrict the stamp window to rows and columns that land on the grid.
    const dyLo = Math.max(-rad, -gy);
    const dyHi = Math.min(rad, GRID - 1 - gy);
    const dxLo = Math.max(-rad, -gx);
    const dxHi = Math.min(rad, GRID - 1 - gx);
    for (let dy = dyLo; dy <= dyHi; dy++) {
      const rowOut = (gy + dy) * GRID + gx;
      const rowStamp = (dy + rad) * side + rad;
      for (let dx = dxLo; dx <= dxHi; dx++) {
        density[rowOut + dx] += SPLAT_STAMP[rowStamp + dx];
      }
    }
  }
}

// ---------- Triangle path of weights (w1,w2,w3), closed loop ----------
/**
 * Barycentric weights along the boundary of the 2-simplex, travelling
 * A -> B -> C -> A, with each edge taking one third of the loop.
 *
 * @param {number} u - Loop phase, periodic with period 1. Any finite value,
 *   including negatives and values >= 1, is wrapped into [0, 1).
 * @returns {number[]} [wA, wB, wC]: non-negative and summing to 1.
 */
function weightsAt(u) {
  let t = (u * 3) % 3;
  // `%` keeps the dividend's sign, so shift negatives into [0, 3). A tiny
  // negative t can round to exactly 3 after the shift; 3 is the same point as 0.
  if (t < 0) t += 3;
  if (t >= 3) t = 0;
  const seg = Math.floor(t);
  const f = t - seg;
  if (seg === 0) return [1 - f, f, 0];
  if (seg === 1) return [0, 1 - f, f];
  return [f, 0, 1 - f];
}

// ---------- Precompute K barycenter density frames ----------
/**
 * Fill the module-level `frames` and `frameMax` with K_FRAMES barycenter
 * densities, sampled at loop phase u = k / K_FRAMES of weightsAt().
 *
 * Barycenter point i is the weighted average of point i of each input, so the
 * three inputs must share one index correspondence (as produced in init by
 * canonicalSort).
 *
 * @param {Float32Array} ptsA - N_POINTS interleaved points of shape A.
 * @param {Float32Array} ptsB - N_POINTS interleaved points of shape B.
 * @param {Float32Array} ptsC - N_POINTS interleaved points of shape C.
 */
function precomputeFrames(ptsA, ptsB, ptsC) {
  const blended = new Float32Array(N_POINTS * 2);
  for (let k = 0; k < K_FRAMES; k++) {
    const [wa, wb, wc] = weightsAt(k / K_FRAMES);
    // x and y components are blended identically, so walk the flat array.
    for (let i = 0; i < N_POINTS * 2; i++) {
      blended[i] = wa * ptsA[i] + wb * ptsB[i] + wc * ptsC[i];
    }
    const density = new Float32Array(GRID * GRID);
    splat(blended, density);
    frames[k] = density;

    let peak = 0;
    for (const v of density) {
      if (v > peak) peak = v;
    }
    // An all-zero field normalises by 1, so later divisions by frameMax are safe.
    frameMax[k] = peak || 1;
  }
}

// ---------- Render a density field into a region of imgData ----------
/**
 * Colour `density / dmax`, clamped to [0, 1], through viridis. Write it,
 * nearest-neighbour scaled, into the dstW x dstH rectangle at (dstX, dstY) of
 * the canvas-sized imgData.
 *
 * The rectangle is clipped to the W x H canvas. Unclipped, a column past the
 * right edge would land at the start of the next pixel row.
 * NOTE(review): not called anywhere in this file; tick uses blitCrossfade.
 *
 * @param {Float32Array} density - GRID*GRID row-major field.
 * @param {number} dmax - Normaliser; expected to be non-zero.
 * @param {number} dstX - Left edge of the destination, in canvas pixels.
 * @param {number} dstY - Top edge of the destination, in canvas pixels.
 * @param {number} dstW - Destination width in pixels.
 * @param {number} dstH - Destination height in pixels.
 */
function blitDensity(density, dmax, dstX, dstY, dstW, dstH) {
  const pyStart = Math.max(0, -dstY);
  const pyEnd = Math.min(dstH, H - dstY);
  const pxStart = Math.max(0, -dstX);
  const pxEnd = Math.min(dstW, W - dstX);
  for (let py = pyStart; py < pyEnd; py++) {
    const sy = Math.min(GRID - 1, Math.floor((py * GRID) / dstH));
    const rowBase = (dstY + py) * W + dstX;
    for (let px = pxStart; px < pxEnd; px++) {
      const sx = Math.min(GRID - 1, Math.floor((px * GRID) / dstW));
      const v = density[sy * GRID + sx] / dmax;
      const c = viridis(Math.min(1, Math.max(0, v)));
      const i = (rowBase + px) * 4;
      imgData[i] = c[0] | 0;
      imgData[i + 1] = c[1] | 0;
      imgData[i + 2] = c[2] | 0;
      imgData[i + 3] = 255;
    }
  }
}

// ---------- Crossfade blit: blend two density frames ----------
/**
 * Colour the blend (1 - alpha) * dA/mA + alpha * dB/mB, clamped to [0, 1],
 * through viridis. Write it, nearest-neighbour scaled, into the dstW x dstH
 * rectangle at (dstX, dstY) of the canvas-sized imgData.
 *
 * The rectangle is clipped to the W x H canvas. Unclipped, a column past the
 * right edge would land at the start of the next pixel row.
 *
 * @param {Float32Array} dA - First GRID*GRID field.
 * @param {number} mA - Normaliser for dA; expected to be non-zero.
 * @param {Float32Array} dB - Second GRID*GRID field.
 * @param {number} mB - Normaliser for dB; expected to be non-zero.
 * @param {number} alpha - Blend factor: 0 shows dA, 1 shows dB.
 * @param {number} dstX - Left edge of the destination, in canvas pixels.
 * @param {number} dstY - Top edge of the destination, in canvas pixels.
 * @param {number} dstW - Destination width in pixels.
 * @param {number} dstH - Destination height in pixels.
 */
function blitCrossfade(dA, mA, dB, mB, alpha, dstX, dstY, dstW, dstH) {
  const pyStart = Math.max(0, -dstY);
  const pyEnd = Math.min(dstH, H - dstY);
  const pxStart = Math.max(0, -dstX);
  const pxEnd = Math.min(dstW, W - dstX);
  for (let py = pyStart; py < pyEnd; py++) {
    const sy = Math.min(GRID - 1, Math.floor((py * GRID) / dstH));
    const rowBase = (dstY + py) * W + dstX;
    for (let px = pxStart; px < pxEnd; px++) {
      const sx = Math.min(GRID - 1, Math.floor((px * GRID) / dstW));
      const s = sy * GRID + sx;
      const v = (1 - alpha) * (dA[s] / mA) + alpha * (dB[s] / mB);
      const c = viridis(Math.min(1, Math.max(0, v)));
      const i = (rowBase + px) * 4;
      imgData[i] = c[0] | 0;
      imgData[i + 1] = c[1] | 0;
      imgData[i + 2] = c[2] | 0;
      imgData[i + 3] = 255;
    }
  }
}

/**
 * One-time setup.
 *
 * Builds the three shape masks (A = square, B = disk, C = L); samples,
 * relaxes and couples their particles; precomputes the corner thumbnails and
 * the K_FRAMES barycenter densities; allocates the canvas-sized ImageData;
 * and paints the background.
 *
 * @param {{ctx: CanvasRenderingContext2D, width: number, height: number}} env
 */
function init({ ctx, width, height }) {
  W = width;
  H = height;
  SPLAT_STAMP = buildStamp();

  const masks = [shapeSquare(), shapeDisk(), shapeL()];

  // Sample, relax for nicer coverage (2 iterations), then put each set into
  // the shared canonical order so index k corresponds across the shapes.
  const [sA, sB, sC] = masks.map((mask) => {
    const pts = samplePoints(mask, N_POINTS);
    lloydRelax(pts, mask, 2);
    return canonicalSort(pts);
  });

  // Corner thumbnails: densities of the unblended shapes, plus their peaks.
  cornerThumbs = [sA, sB, sC].map((pts) => {
    const density = new Float32Array(GRID * GRID);
    splat(pts, density);
    return density;
  });
  cornerMax = new Float32Array(3);
  cornerThumbs.forEach((density, k) => {
    let peak = 0;
    for (const v of density) {
      if (v > peak) peak = v;
    }
    cornerMax[k] = peak || 1;
  });

  // Precompute the K barycenter frames along the loop.
  frames = new Array(K_FRAMES);
  frameMax = new Float32Array(K_FRAMES);
  precomputeFrames(sA, sB, sC);

  // createImageData throws IndexSizeError when a dimension is 0 (e.g. a
  // collapsed canvas). Leave the buffer unset in that case; tick rebuilds it
  // on the first frame whose size differs from W x H.
  if (W > 0 && H > 0) {
    imgBuf = ctx.createImageData(W, H);
    imgData = imgBuf.data;
  } else {
    imgBuf = null;
    imgData = null;
  }

  // Triangle widget corner positions depend on layout and are computed in tick.
  triCorners = null;

  // Paint initial background.
  ctx.fillStyle = '#05060a';
  ctx.fillRect(0, 0, W, H);
}

/**
 * Per-frame render.
 *
 * Advances the loop clock by `dt` and crossfades the two precomputed
 * barycenter frames that bracket the current loop phase into a centred square.
 * Then draws the weight widget in the upper-right corner: corner-shape
 * thumbnails, simplex edges, a marker at the current weights, and a numeric
 * readout.
 *
 * @param {{ctx: CanvasRenderingContext2D, dt: number, width: number, height: number}} frame
 */
function tick({ ctx, dt, width, height }) {
  timeAcc += dt;

  // createImageData throws IndexSizeError when a dimension is 0 (e.g. a
  // collapsed canvas). Nothing would be visible anyway, so only advance time.
  if (!(width > 0 && height > 0)) return;

  // Handle resize (or a buffer that was never allocated): rebuild ImageData only.
  if (width !== W || height !== H || !imgBuf) {
    W = width;
    H = height;
    imgBuf = ctx.createImageData(W, H);
    imgData = imgBuf.data;
  }

  const u = (timeAcc / LOOP_SECONDS) % 1;

  // Pick the two neighbouring precomputed frames and crossfade between them.
  const fIdx = u * K_FRAMES;
  const k0 = Math.floor(fIdx) % K_FRAMES;
  const k1 = (k0 + 1) % K_FRAMES;
  const alpha = fIdx - Math.floor(fIdx);

  // Main barycenter area: centred square taking up the bulk of the canvas.
  const minSide = Math.min(W, H);
  const margin = Math.floor(minSide * 0.06);
  const mainSize = minSide - 2 * margin;
  const mainX = Math.floor((W - mainSize) / 2);
  const mainY = Math.floor((H - mainSize) / 2);

  // Background (5, 6, 10) = #05060a, the same colour init paints.
  for (let i = 0; i < imgData.length; i += 4) {
    imgData[i] = 5;
    imgData[i + 1] = 6;
    imgData[i + 2] = 10;
    imgData[i + 3] = 255;
  }

  blitCrossfade(
    frames[k0], frameMax[k0],
    frames[k1], frameMax[k1],
    alpha,
    mainX, mainY, mainSize, mainSize
  );

  // Weight widget in the upper-right: an equilateral triangle whose vertices
  // (top, bottom-right, bottom-left) correspond to shapes A, B, C.
  const widgetSize = Math.floor(minSide * 0.22);
  const wx = W - widgetSize - margin;
  const wy = margin;
  const cx = wx + widgetSize / 2;
  const cy = wy + widgetSize / 2;
  const tr = widgetSize * 0.42;
  const verts = [
    [cx, cy - tr],
    [cx + tr * Math.cos(Math.PI / 6), cy + tr * Math.sin(Math.PI / 6)],
    [cx - tr * Math.cos(Math.PI / 6), cy + tr * Math.sin(Math.PI / 6)],
  ];

  paintCornerThumbs(verts, Math.floor(widgetSize * 0.20));

  // Push pixels now; the vector overlay is stroked on top afterwards.
  ctx.putImageData(imgBuf, 0, 0);

  drawWeightOverlay(ctx, verts, weightsAt(u), cx, wy + widgetSize - 4, widgetSize);
}

/**
 * Write each corner shape's density into imgData. Shape v is drawn as
 * cornerThumbs[v] / cornerMax[v], clamped and viridis-coloured, in a
 * nearest-neighbour thumbS x thumbS square centred on verts[v], clipped to
 * the canvas.
 */
function paintCornerThumbs(verts, thumbS) {
  for (let v = 0; v < 3; v++) {
    const tx = Math.round(verts[v][0] - thumbS / 2);
    const ty = Math.round(verts[v][1] - thumbS / 2);
    const x0 = Math.max(0, tx);
    const y0 = Math.max(0, ty);
    const x1 = Math.min(W, tx + thumbS);
    const y1 = Math.min(H, ty + thumbS);
    for (let py = y0; py < y1; py++) {
      const sy = Math.min(GRID - 1, Math.floor((py - ty) * GRID / thumbS));
      for (let px = x0; px < x1; px++) {
        const sx = Math.min(GRID - 1, Math.floor((px - tx) * GRID / thumbS));
        const val = cornerThumbs[v][sy * GRID + sx] / cornerMax[v];
        const c = viridis(Math.min(1, Math.max(0, val)));
        const i = (py * W + px) * 4;
        imgData[i] = c[0] | 0;
        imgData[i + 1] = c[1] | 0;
        imgData[i + 2] = c[2] | 0;
        imgData[i + 3] = 255;
      }
    }
  }
}

/**
 * Stroke the simplex triangle through `verts` and draw a glowing marker at
 * the barycentric position of weights `w`. Print the three weights as text
 * centred horizontally at textX with baseline at textY.
 */
function drawWeightOverlay(ctx, verts, w, textX, textY, widgetSize) {
  ctx.strokeStyle = 'rgba(220, 230, 240, 0.55)';
  ctx.lineWidth = 1.2;
  ctx.beginPath();
  ctx.moveTo(verts[0][0], verts[0][1]);
  ctx.lineTo(verts[1][0], verts[1][1]);
  ctx.lineTo(verts[2][0], verts[2][1]);
  ctx.closePath();
  ctx.stroke();

  // Current weights -> barycentric position inside the triangle.
  const dotX = w[0] * verts[0][0] + w[1] * verts[1][0] + w[2] * verts[2][0];
  const dotY = w[0] * verts[0][1] + w[1] * verts[1][1] + w[2] * verts[2][1];

  // Glow halo, then the dot itself.
  ctx.fillStyle = 'rgba(253, 231, 36, 0.18)';
  ctx.beginPath();
  ctx.arc(dotX, dotY, 9, 0, Math.PI * 2);
  ctx.fill();
  ctx.fillStyle = 'rgba(253, 231, 36, 1.0)';
  ctx.beginPath();
  ctx.arc(dotX, dotY, 3.5, 0, Math.PI * 2);
  ctx.fill();

  // Tiny weight readout under the triangle.
  ctx.fillStyle = 'rgba(220, 230, 240, 0.8)';
  ctx.font = `${Math.max(10, Math.floor(widgetSize * 0.07))}px ui-monospace, monospace`;
  ctx.textAlign = 'center';
  const wTxt = `${w[0].toFixed(2)}  ${w[1].toFixed(2)}  ${w[2].toFixed(2)}`;
  ctx.fillText(wTxt, textX, textY);
}

Comments (0)

Log in to comment.