16
Wasserstein Barycenter: Three Shapes
The Wasserstein barycenter of three shapes $\mu_1, \mu_2, \mu_3$ with weights $\lambda = (\lambda_1, \lambda_2, \lambda_3)$ on the 2-simplex is $\bar{\mu}_\lambda = \arg\min_{\mu} \sum_{i=1}^{3} \lambda_i \, W_2^2(\mu, \mu_i)$, where $W_2$ is the quadratic optimal-transport distance. Unlike a pixelwise convex combination — which would simply superimpose the shapes — the barycenter respects mass transport: each particle of $\bar{\mu}_\lambda$ is a weighted average of three coupled particles, one from each shape. The result is a genuine *morph*. We approximate it cheaply via displacement interpolation: sample $N$ Lloyd-relaxed points per shape, couple them by a common radial canonical order, and average positions before splatting back to a density field rendered in viridis. The weight $\lambda$ traces the boundary of the 2-simplex over a 12-second loop; a small triangle widget shows its current location. Inspired by Gabriel Peyré's optimal-transport visualizations.
idle
374 lines · vanilla
view source
// Wasserstein Barycenter morph between three binary shapes.
// Cheap displacement-interpolation approximation:
// - Sample each shape as N points (Lloyd-relaxed for even coverage).
// - Sort all three point sets in a common 1D order (angular sector around
//   each set's centroid, then distance from it) so points correspond across
//   shapes — gives a coarse but mass-preserving coupling.
// - For weights (w1,w2,w3), barycenter point_k = w1*A_k + w2*B_k + w3*C_k.
// - Splat all N points with a small Gaussian kernel onto a 128x128 grid.
// - Display via viridis.
//
// We precompute K=40 barycenters along the triangle-edge path once in init.
// Per frame, tick only crossfades two neighbouring precomputed density
// fields; no splatting and no density-field allocation happen per frame.
const GRID = 128; // density grid resolution (GRID x GRID cells)
const N_POINTS = 900; // sample points per shape
const K_FRAMES = 40; // precomputed barycenter frames around the loop
const LOOP_SECONDS = 12.0; // duration of one trip around the simplex edges
const SPLAT_SIGMA = 1.15; // Gaussian splat sigma, in grid cells
const SPLAT_RADIUS = 3; // splat kernel truncation radius, in grid cells
// Precomputed Gaussian splat stamp, (2 * SPLAT_RADIUS + 1)^2 taps (7x7); built in init.
let SPLAT_STAMP = null;
let W; // canvas width in pixels
let H; // canvas height in pixels
let frames; // Array of K_FRAMES Float32Array density fields, each GRID*GRID
let frameMax; // Float32Array: per-frame peak density, used for normalization
let imgBuf; // ImageData covering the whole canvas
let imgData; // imgBuf.data: the RGBA Uint8ClampedArray the blitters write into
let triCorners; // widget-vertex slot; init sets it to null and tick computes vertices locally
let cornerThumbs; // Array of 3 Float32Array density fields (GRID*GRID), one per source shape
let cornerMax; // Float32Array(3): peak of each corner thumbnail, for normalization
let timeAcc = 0; // animation time accumulated by tick, in seconds
// ---------- viridis colormap (11 stops, piecewise-linear) ----------
// Anchor colours approximating the viridis map, from dark purple (t = 0) to
// yellow (t = 1). Frozen because viridis() returns the end stops by reference,
// so a caller mutating its result must not be able to corrupt the table.
const VIRIDIS = Object.freeze([
  [68, 1, 84], [72, 35, 116], [64, 67, 135], [52, 94, 141],
  [41, 120, 142], [32, 144, 140], [34, 167, 132], [68, 190, 112],
  [121, 209, 81], [189, 222, 38], [253, 231, 36],
].map((stop) => Object.freeze(stop)));

/**
 * Map a scalar to an RGB triple by linear interpolation between VIRIDIS stops.
 * @param {number} t - position on the map; values <= 0 (and NaN) give the
 *   first stop, values >= 1 give the last stop.
 * @returns {number[]} [r, g, b] with possibly fractional channels in 0..255
 *   (callers truncate with `| 0`).
 */
function viridis(t) {
  // `!(t > 0)` also catches NaN, which would otherwise index VIRIDIS[NaN]
  // and throw on `a[0]`.
  if (!(t > 0)) return VIRIDIS[0];
  if (t >= 1) return VIRIDIS[VIRIDIS.length - 1];
  const f = t * (VIRIDIS.length - 1);
  const i = Math.floor(f);
  const u = f - i;
  const a = VIRIDIS[i];
  const b = VIRIDIS[i + 1];
  return [
    a[0] + (b[0] - a[0]) * u,
    a[1] + (b[1] - a[1]) * u,
    a[2] + (b[2] - a[2]) * u,
  ];
}
// ---------- binary shapes on a 128x128 grid ----------
// Each shape builder returns a GRID*GRID Uint8Array mask (row-major, 1 = inside).

// Axis-aligned square covering 28%..72% of the grid on both axes (inclusive).
function shapeSquare() {
  const mask = new Uint8Array(GRID * GRID);
  const lo = GRID * 0.28;
  const hi = GRID * 0.72;
  const within = (v) => v >= lo && v <= hi;
  for (let row = 0; row < GRID; row++) {
    if (!within(row)) continue;
    const base = row * GRID;
    for (let col = 0; col < GRID; col++) {
      if (within(col)) mask[base + col] = 1;
    }
  }
  return mask;
}
// Filled disk centred on the grid with radius 26% of the grid size
// (a cell is inside when its integer coordinates lie within the radius).
function shapeDisk() {
  const out = new Uint8Array(GRID * GRID);
  const centre = GRID * 0.5;
  const radius = GRID * 0.26;
  const limit = radius * radius;
  for (let y = 0; y < GRID; y++) {
    const dy = y - centre;
    const dySq = dy * dy;
    for (let x = 0; x < GRID; x++) {
      const dx = x - centre;
      if (dx * dx + dySq <= limit) {
        out[y * GRID + x] = 1;
      }
    }
  }
  return out;
}
// L shape: a vertical bar (34%..50% wide, 22%..78% tall) plus a foot running
// right along its bottom (34%..74% wide, 62%..78% tall), in grid fractions.
function shapeL() {
  const mask = new Uint8Array(GRID * GRID);
  // Rectangles as [x0, x1, y0, y1] in grid units, bounds inclusive.
  const rects = [
    [GRID * 0.34, GRID * 0.50, GRID * 0.22, GRID * 0.78], // vertical bar
    [GRID * 0.34, GRID * 0.74, GRID * 0.62, GRID * 0.78], // foot
  ];
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      const covered = rects.some(
        ([x0, x1, y0, y1]) => x >= x0 && x <= x1 && y >= y0 && y <= y1,
      );
      if (covered) mask[y * GRID + x] = 1;
    }
  }
  return mask;
}
// ---------- sample N points inside a binary mask ----------
// Deterministic, evenly strided sampling: interior pixels are listed in raster
// order and point i takes pixel floor((i + 0.5) / n * count), so the n points
// are spread proportionally over the interior. Each point is then offset
// inside its pixel by the (i + 1)-th Halton (base 2, base 3) pair.
//
// mask    - GRID*GRID array; non-zero marks interior pixels.
// n       - number of points to produce.
// returns - Float32Array of n interleaved (x, y) positions in grid units,
//           where pixel (x, y) spans [x, x+1) x [y, y+1).
// throws  - Error when n > 0 and the mask has no interior pixels (the strided
//           index would be -1 and every coordinate NaN).
function samplePoints(mask, n) {
  // Radical inverse of k in `base`: one coordinate of the Halton sequence.
  const radicalInverse = (k, base) => {
    let rest = k;
    let value = 0;
    let f = 1 / base;
    while (rest > 0) {
      value += f * (rest % base);
      rest = Math.floor(rest / base);
      f /= base;
    }
    return value;
  };

  // Collect interior pixel coordinates in raster order.
  const inside = [];
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      if (mask[y * GRID + x]) inside.push(x, y);
    }
  }
  const pixCount = inside.length / 2;
  if (n > 0 && pixCount === 0) {
    throw new Error('samplePoints: mask has no interior pixels');
  }

  const pts = new Float32Array(n * 2);
  for (let i = 0; i < n; i++) {
    const t = (i + 0.5) / n;
    const j = Math.min(pixCount - 1, Math.floor(t * pixCount));
    pts[i * 2] = inside[j * 2] + radicalInverse(i + 1, 2);
    pts[i * 2 + 1] = inside[j * 2 + 1] + radicalInverse(i + 1, 3);
  }
  return pts;
}
// ---------- Lloyd relaxation (a few iterations) for even coverage ----------
// Discrete Lloyd's algorithm restricted to the mask: every interior pixel
// centre is assigned to its nearest point, then each point moves to the
// centroid of the pixel centres it owns. Because each point follows the mean
// of its own cell, neighbouring points spread apart. Pulling all points of a
// coarse bucket toward one shared mean would instead contract every bucket
// into a clump.
//
// pts   - Float32Array of interleaved (x, y) positions in grid units, where
//         pixel (x, y) spans [x, x+1) x [y, y+1). Updated in place.
// mask  - GRID*GRID array; non-zero marks interior pixels.
// iters - number of Lloyd passes.
//
// Nearest-point search is brute force: O(iters * interiorPixels * points),
// a one-off cost paid in init.
function lloydRelax(pts, mask, iters) {
  const n = Math.floor(pts.length / 2);

  // Interior pixel centres, interleaved (x, y).
  const centres = [];
  for (let y = 0; y < GRID; y++) {
    for (let x = 0; x < GRID; x++) {
      if (mask[y * GRID + x]) centres.push(x + 0.5, y + 0.5);
    }
  }
  if (n === 0 || centres.length === 0) return;

  const sumX = new Float64Array(n);
  const sumY = new Float64Array(n);
  const owned = new Int32Array(n);
  for (let it = 0; it < iters; it++) {
    sumX.fill(0);
    sumY.fill(0);
    owned.fill(0);

    // Assignment step: each pixel centre goes to its nearest point.
    for (let c = 0; c < centres.length; c += 2) {
      const qx = centres[c];
      const qy = centres[c + 1];
      let best = 0;
      let bestDist = Infinity;
      for (let i = 0; i < n; i++) {
        const dx = pts[i * 2] - qx;
        const dy = pts[i * 2 + 1] - qy;
        const d = dx * dx + dy * dy;
        if (d < bestDist) {
          bestDist = d;
          best = i;
        }
      }
      sumX[best] += qx;
      sumY[best] += qy;
      owned[best]++;
    }

    // Update step: move each point to the centroid of its cell. Points that
    // own no pixels stay put. A centroid outside a non-convex mask (possible
    // near the L's inner corner) is rejected so points never leave the shape.
    for (let i = 0; i < n; i++) {
      if (owned[i] === 0) continue;
      const nx = sumX[i] / owned[i];
      const ny = sumY[i] / owned[i];
      const ix = Math.max(0, Math.min(GRID - 1, Math.floor(nx)));
      const iy = Math.max(0, Math.min(GRID - 1, Math.floor(ny)));
      if (mask[iy * GRID + ix]) {
        pts[i * 2] = nx;
        pts[i * 2 + 1] = ny;
      }
    }
  }
}
// ---------- Sort point set in a common canonical order ----------
// Orders points by angular sector around the set's own centroid (64 equal
// sectors, starting at angle -pi), then by distance from the centroid within
// a sector. Giving all three shapes this same ordering couples their k-th
// points, which approximates 2D optimal transport for radially-similar shapes.
//
// pts     - Float32Array of interleaved (x, y) positions.
// returns - a new Float32Array holding the same points in canonical order.
function canonicalSort(pts) {
  const SECTORS = 64;
  const n = pts.length / 2;

  // Centroid of the point set.
  let cx = 0;
  let cy = 0;
  for (let i = 0; i < n; i++) {
    cx += pts[i * 2];
    cy += pts[i * 2 + 1];
  }
  cx /= n;
  cy /= n;

  const sector = new Int32Array(n);
  const radius = new Float64Array(n);
  for (let i = 0; i < n; i++) {
    const dx = pts[i * 2] - cx;
    const dy = pts[i * 2 + 1] - cy;
    // atan2 returns angles in [-pi, pi]; +pi is the same direction as -pi,
    // so the modulo folds it into sector 0 instead of a 65th sector.
    const raw = Math.floor(((Math.atan2(dy, dx) + Math.PI) / (2 * Math.PI)) * SECTORS);
    sector[i] = raw % SECTORS;
    radius[i] = Math.sqrt(dx * dx + dy * dy);
  }

  const order = Array.from({ length: n }, (_, i) => i);
  order.sort((a, b) => sector[a] - sector[b] || radius[a] - radius[b]);

  const sorted = new Float32Array(pts.length);
  order.forEach((src, dst) => {
    sorted[dst * 2] = pts[src * 2];
    sorted[dst * 2 + 1] = pts[src * 2 + 1];
  });
  return sorted;
}
// ---------- Build a Gaussian splat stamp ----------
// Returns a row-major (2R+1) x (2R+1) Float32Array (R = SPLAT_RADIUS) whose
// entry for integer offset (dx, dy) is exp(-(dx^2 + dy^2) / (2 * SPLAT_SIGMA^2));
// the centre tap is exactly 1 and the kernel is not normalised.
function buildStamp() {
  const radius = SPLAT_RADIUS;
  const side = radius * 2 + 1;
  const sigmaSq = SPLAT_SIGMA * SPLAT_SIGMA;
  const out = new Float32Array(side * side);
  for (let k = 0; k < out.length; k++) {
    const ox = (k % side) - radius;
    const oy = Math.floor(k / side) - radius;
    out[k] = Math.exp(-(ox * ox + oy * oy) / (2 * sigmaSq));
  }
  return out;
}
// ---------- Splat point set to density field ----------
// Clears `density` (GRID*GRID, row-major), then for every interleaved (x, y)
// point adds SPLAT_STAMP centred on the point's nearest grid cell
// (Math.round). Stamp taps falling outside the grid are dropped.
function splat(pts, density) {
  density.fill(0);
  const r = SPLAT_RADIUS;
  const side = 2 * r + 1;
  const count = pts.length / 2;
  for (let p = 0; p < count; p++) {
    const cx = Math.round(pts[p * 2]);
    const cy = Math.round(pts[p * 2 + 1]);
    for (let oy = -r; oy <= r; oy++) {
      const gy = cy + oy;
      if (gy < 0 || gy >= GRID) continue;
      const outRow = gy * GRID;
      const stampRow = (oy + r) * side + r;
      for (let ox = -r; ox <= r; ox++) {
        const gx = cx + ox;
        if (gx >= 0 && gx < GRID) {
          density[outRow + gx] += SPLAT_STAMP[stampRow + ox];
        }
      }
    }
  }
}
// ---------- Triangle path of weights (w1,w2,w3), closed loop ----------
// Maps loop phase u in [0, 1) to simplex weights travelling A -> B -> C -> A
// along the triangle's edges, each edge taking one third of the loop and
// interpolated linearly. Returns a fresh [w1, w2, w3] array.
function weightsAt(u) {
  const pos = (u * 3) % 3;
  const edge = Math.floor(pos);
  const s = pos - edge;
  switch (edge) {
    case 0:
      return [1 - s, s, 0]; // A -> B
    case 1:
      return [0, 1 - s, s]; // B -> C
    default:
      return [s, 0, 1 - s]; // C -> A
  }
}
// ---------- Precompute K barycenter density frames ----------
// For each of K_FRAMES evenly spaced loop phases u = k / K_FRAMES, blends the
// three coupled point sets with weightsAt(u), splats the blended points into
// a fresh GRID*GRID field stored in frames[k], and records the field's peak
// in frameMax[k] (1 for an all-zero field, so later divisions are safe).
// frames and frameMax must already be allocated (init does this).
//
// ptsA, ptsB, ptsC - canonically sorted Float32Arrays of N_POINTS interleaved
//                    (x, y) positions; index k is coupled across the three.
function precomputeFrames(ptsA, ptsB, ptsC) {
  const blended = new Float32Array(N_POINTS * 2);
  for (let k = 0; k < K_FRAMES; k++) {
    const [wA, wB, wC] = weightsAt(k / K_FRAMES);
    // x and y components blend identically, so one flat loop covers both.
    for (let i = 0; i < N_POINTS * 2; i++) {
      blended[i] = wA * ptsA[i] + wB * ptsB[i] + wC * ptsC[i];
    }
    const field = new Float32Array(GRID * GRID);
    splat(blended, field);
    frames[k] = field;

    let peak = 0;
    for (const v of field) if (v > peak) peak = v;
    frameMax[k] = peak || 1;
  }
}
// ---------- Render a density field into a region of imgData ----------
// Draws `density` (GRID*GRID), divided by `dmax` and clamped to [0, 1], through
// viridis into the dstW x dstH rectangle whose top-left is (dstX, dstY) in
// canvas pixels, using nearest-neighbour sampling. The rectangle is clipped to
// the W x H canvas: without clipping, a partly off-screen destination wraps
// into neighbouring rows or indexes past the RGBA buffer.
function blitDensity(density, dmax, dstX, dstY, dstW, dstH) {
  const x0 = Math.max(0, dstX);
  const y0 = Math.max(0, dstY);
  const x1 = Math.min(W, dstX + dstW);
  const y1 = Math.min(H, dstY + dstH);
  for (let cy = y0; cy < y1; cy++) {
    const sy = Math.min(GRID - 1, Math.floor((cy - dstY) * GRID / dstH));
    for (let cx = x0; cx < x1; cx++) {
      const sx = Math.min(GRID - 1, Math.floor((cx - dstX) * GRID / dstW));
      const t = Math.min(1, Math.max(0, density[sy * GRID + sx] / dmax));
      const c = viridis(t);
      const i = (cy * W + cx) * 4;
      imgData[i] = c[0] | 0;
      imgData[i + 1] = c[1] | 0;
      imgData[i + 2] = c[2] | 0;
      imgData[i + 3] = 255;
    }
  }
}
// ---------- Crossfade blit: blend two density frames ----------
// Normalises frame A by mA and frame B by mB, mixes them as
// (1 - alpha) * A + alpha * B, clamps to [0, 1] and writes the viridis colour
// into the dstW x dstH rectangle at (dstX, dstY) of imgData, sampling the
// GRID*GRID fields nearest-neighbour. No clipping: the caller keeps the
// rectangle on-canvas.
function blitCrossfade(dA, mA, dB, mB, alpha, dstX, dstY, dstW, dstH) {
  const keep = 1 - alpha;
  for (let row = 0; row < dstH; row++) {
    const srcRow = Math.min(GRID - 1, Math.floor(row * GRID / dstH)) * GRID;
    const outRow = (dstY + row) * W + dstX;
    for (let col = 0; col < dstW; col++) {
      const src = srcRow + Math.min(GRID - 1, Math.floor(col * GRID / dstW));
      const mixed = keep * (dA[src] / mA) + alpha * (dB[src] / mB);
      const rgb = viridis(Math.min(1, Math.max(0, mixed)));
      const o = (outRow + col) * 4;
      imgData[o] = rgb[0] | 0;
      imgData[o + 1] = rgb[1] | 0;
      imgData[o + 2] = rgb[2] | 0;
      imgData[o + 3] = 255;
    }
  }
}
/**
 * One-time setup: builds the three shape point sets, the corner thumbnails,
 * the K_FRAMES precomputed barycenter frames and the full-canvas ImageData,
 * then paints the background.
 * @param {{ctx: CanvasRenderingContext2D, width: number, height: number}} env
 */
function init({ ctx, width, height }) {
  W = width;
  H = height;
  SPLAT_STAMP = buildStamp();

  // Peak of a density field, or 1 for an all-zero field so that later
  // normalisation never divides by zero.
  const peakOrOne = (field) => {
    let peak = 0;
    for (const v of field) if (v > peak) peak = v;
    return peak || 1;
  };

  // Per shape (A = square, B = disk, C = L): sample points, relax them for
  // even coverage, then sort into the shared canonical order that couples
  // the k-th point of every shape.
  const masks = [shapeSquare(), shapeDisk(), shapeL()];
  const [sA, sB, sC] = masks.map((mask) => {
    const pts = samplePoints(mask, N_POINTS);
    lloydRelax(pts, mask, 2);
    return canonicalSort(pts);
  });

  // Corner thumbnails: density fields of the three source shapes.
  cornerThumbs = [sA, sB, sC].map((pts) => {
    const field = new Float32Array(GRID * GRID);
    splat(pts, field);
    return field;
  });
  cornerMax = Float32Array.from(cornerThumbs, peakOrOne);

  // Precompute the K barycenter frames along the loop.
  frames = new Array(K_FRAMES);
  frameMax = new Float32Array(K_FRAMES);
  precomputeFrames(sA, sB, sC);

  // ImageData for full-canvas blits.
  imgBuf = ctx.createImageData(W, H);
  imgData = imgBuf.data;

  // Not used for drawing: tick derives the widget vertices from the current
  // layout on every frame.
  triCorners = null;

  ctx.fillStyle = '#05060a';
  ctx.fillRect(0, 0, W, H);
}
// Fills the whole RGBA buffer with the background colour rgb(5, 6, 10), the
// same '#05060a' that init paints.
function paintBackground() {
  for (let i = 0; i < imgData.length; i += 4) {
    imgData[i] = 5;
    imgData[i + 1] = 6;
    imgData[i + 2] = 10;
    imgData[i + 3] = 255;
  }
}

// Draws a GRID*GRID density field, normalised by `peak`, as a size x size
// viridis thumbnail with top-left (tx, ty); clipped to the canvas.
function paintThumb(field, peak, tx, ty, size) {
  const x0 = Math.max(0, tx);
  const y0 = Math.max(0, ty);
  const x1 = Math.min(W, tx + size);
  const y1 = Math.min(H, ty + size);
  for (let py = y0; py < y1; py++) {
    const sy = Math.min(GRID - 1, Math.floor((py - ty) * GRID / size));
    for (let px = x0; px < x1; px++) {
      const sx = Math.min(GRID - 1, Math.floor((px - tx) * GRID / size));
      const t = Math.min(1, Math.max(0, field[sy * GRID + sx] / peak));
      const c = viridis(t);
      const i = (py * W + px) * 4;
      imgData[i] = c[0] | 0;
      imgData[i + 1] = c[1] | 0;
      imgData[i + 2] = c[2] | 0;
      imgData[i + 3] = 255;
    }
  }
}

// Strokes the weight triangle through `verts` (A, B, C), marks the barycentric
// point for weights `w` with a haloed dot, and prints the weights centred at
// (labelX, labelY) in a font scaled by widgetSize.
function drawWeightWidget(ctx, verts, w, labelX, labelY, widgetSize) {
  ctx.strokeStyle = 'rgba(220, 230, 240, 0.55)';
  ctx.lineWidth = 1.2;
  ctx.beginPath();
  ctx.moveTo(verts[0][0], verts[0][1]);
  ctx.lineTo(verts[1][0], verts[1][1]);
  ctx.lineTo(verts[2][0], verts[2][1]);
  ctx.closePath();
  ctx.stroke();

  const dotX = w[0] * verts[0][0] + w[1] * verts[1][0] + w[2] * verts[2][0];
  const dotY = w[0] * verts[0][1] + w[1] * verts[1][1] + w[2] * verts[2][1];
  // Glow halo
  ctx.fillStyle = 'rgba(253, 231, 36, 0.18)';
  ctx.beginPath();
  ctx.arc(dotX, dotY, 9, 0, Math.PI * 2);
  ctx.fill();
  // Dot
  ctx.fillStyle = 'rgba(253, 231, 36, 1.0)';
  ctx.beginPath();
  ctx.arc(dotX, dotY, 3.5, 0, Math.PI * 2);
  ctx.fill();

  // Weight readout under the triangle.
  ctx.fillStyle = 'rgba(220, 230, 240, 0.8)';
  ctx.font = `${Math.max(10, Math.floor(widgetSize * 0.07))}px ui-monospace, monospace`;
  ctx.textAlign = 'center';
  const wTxt = `${w[0].toFixed(2)} ${w[1].toFixed(2)} ${w[2].toFixed(2)}`;
  ctx.fillText(wTxt, labelX, labelY);
}

/**
 * Per-frame entry point. Crossfades the two precomputed barycenter frames
 * around the current loop phase into a centred square. It then draws the
 * triangle widget in the upper-right: corner-shape thumbnails, edges, the
 * current-weight dot and the weight readout.
 * @param {{ctx: CanvasRenderingContext2D, dt: number, width: number, height: number}} env
 */
function tick({ ctx, dt, width, height }) {
  // Only the phase within the loop matters, so keep the clock wrapped to one
  // period (the second line maps a negative remainder back into range).
  timeAcc = (timeAcc + dt) % LOOP_SECONDS;
  if (timeAcc < 0) timeAcc += LOOP_SECONDS;

  // createImageData throws IndexSizeError for a zero-sized canvas (e.g. a
  // collapsed container); skip drawing until the canvas has an area again.
  if (!(width > 0 && height > 0)) return;

  // Handle resize: rebuild the ImageData buffer only.
  if (width !== W || height !== H) {
    W = width;
    H = height;
    imgBuf = ctx.createImageData(W, H);
    imgData = imgBuf.data;
  }

  const u = timeAcc / LOOP_SECONDS;

  // Pick the two neighbouring precomputed frames and crossfade.
  const fIdx = u * K_FRAMES;
  const base = Math.floor(fIdx);
  const k0 = base % K_FRAMES;
  const k1 = (k0 + 1) % K_FRAMES;
  const alpha = fIdx - base;

  // Main barycenter area: centred square taking up the bulk of the canvas.
  const minSide = Math.min(W, H);
  const margin = Math.floor(minSide * 0.06);
  const mainSize = minSide - 2 * margin;
  const mainX = Math.floor((W - mainSize) / 2);
  const mainY = Math.floor((H - mainSize) / 2);

  paintBackground();
  blitCrossfade(
    frames[k0], frameMax[k0],
    frames[k1], frameMax[k1],
    alpha,
    mainX, mainY, mainSize, mainSize,
  );

  // Triangle widget in the upper-right: shape thumbs at the vertices of an
  // equilateral triangle (top = A, bottom-right = B, bottom-left = C).
  const widgetSize = Math.floor(minSide * 0.22);
  const wx = W - widgetSize - margin;
  const wy = margin;
  const cx = wx + widgetSize / 2;
  const cy = wy + widgetSize / 2;
  const tr = widgetSize * 0.42;
  const verts = [
    [cx, cy - tr],
    [cx + tr * Math.cos(Math.PI / 6), cy + tr * Math.sin(Math.PI / 6)],
    [cx - tr * Math.cos(Math.PI / 6), cy + tr * Math.sin(Math.PI / 6)],
  ];
  const thumbS = Math.floor(widgetSize * 0.20);
  for (let v = 0; v < 3; v++) {
    paintThumb(
      cornerThumbs[v], cornerMax[v],
      Math.round(verts[v][0] - thumbS / 2),
      Math.round(verts[v][1] - thumbS / 2),
      thumbS,
    );
  }

  // Push pixels first; vector overlays are drawn on top afterwards.
  ctx.putImageData(imgBuf, 0, 0);
  drawWeightWidget(ctx, verts, weightsAt(u), cx, wy + widgetSize - 4, widgetSize);
}
Comments (0)
Log in to comment.