23
Sinkhorn Iterations: Entropic OT
paint μ on top, ν on right · drag Y to scrub ε
idle
466 lines · vanilla
view source
// Sinkhorn iterations for entropy-regularized 1D optimal transport.
// mu (top, blue) and nu (right, red) are histograms drawn on the axes of
// the unit square. The square holds the N x N coupling matrix P, started
// as K_ij = exp(-c_ij / eps) with quadratic cost
// c_ij = (i/(N-1) - j/(N-1))^2 (bins placed on [0,1]), then alternately
// row-normalized to mu and column-normalized to nu.
// Rendered with a viridis heatmap. Inspired by Gabriel Peyré.
//
// Interactive:
// - Click/drag on the top strip to paint Gaussian bumps into mu.
// - Click/drag on the right strip to paint into nu.
// - Drag the cursor along Y to scrub eps in [0.005, 0.1] (higher Y = lower eps).
// - If idle for ~6s, auto-cycle through preset shape pairs.
const N = 96; // histogram bins; K and P are N x N, stored row-major
const STEP_MS = 150; // ms of frame time per animated Sinkhorn step
const HOLD_MS = 1800; // ms spent in the 'hold' phase before morphing
const IDLE_MS = 6000; // resume auto-cycle after this much idle time
const MORPH_MS = 900; // duration of the smoothstep blend to the next shapes
const SHAPES = ['gauss', 'two-modes', 'asym', 'sharp', 'wide']; // preset names understood by shape()
const EPS_MIN = 0.005; // eps at the bottom of the canvas
const EPS_MAX = 0.1; // eps at the top of the canvas
const PAINT_SIGMA_BINS = 4.5; // Gaussian bump width in bins
const PAINT_AMOUNT = 0.020; // bump height added per painting frame (before renormalizing)
let W, H; // canvas size in px (refreshed on resize in tick)
let plotX, plotY, plotS; // heatmap square: top-left corner and side length in canvas px
let marginTop, marginRight; // strip thickness for histograms
let gap; // gap between plot and strips
let eps; // current entropy regularization
let K; // Float64Array length N*N — Gibbs kernel exp(-c/eps)
let P; // Float64Array length N*N — current coupling
let mu, nu; // Float64Array length N — current marginals (each sums to 1)
let mu2, nu2; // scratch buffers allocated in init; not read by the code shown here
let muT, nuT; // morph targets (next preset shapes, filled by pickNextShapes)
let iter; // Sinkhorn sweeps since the last restart from K
let entropy; // cached H(P) for the HUD
let stepAcc; // ms accumulated toward the next Sinkhorn step
let holdAcc; // ms spent in the 'hold' phase
let phase; // 'step' | 'hold' | 'morph' | 'user'
let morphAcc; // ms elapsed in the current morph
let shapeIdx; // set to 0 in init; not otherwise read by the code shown here
let muPrev, nuPrev; // previous shapes for morph blend
let idleAcc; // ms since last user interaction
let userActive; // true until IDLE_MS has passed since the last interaction
let img; // ImageData for the heatmap (N x N)
let offC; // OffscreenCanvas N x N
let offCtx; // 2D context of offC
function shape(name, out) {
  // Fill out[0..N-1] with the preset density `name` sampled at x = i/(N-1)
  // on [0,1]. Add a 1e-6 floor to every bin so all bins stay strictly
  // positive, then normalize to unit mass. Unrecognized names give a
  // uniform histogram.
  const gaussian = (x, center, sigma) =>
    Math.exp(-((x - center) ** 2) / (2 * sigma * sigma));
  let density;
  switch (name) {
    case 'gauss':
      density = (x) => gaussian(x, 0.5, 0.10);
      break;
    case 'two-modes':
      density = (x) => gaussian(x, 0.25, 0.06) + 0.85 * gaussian(x, 0.78, 0.06);
      break;
    case 'asym':
      // skewed, gamma-like: x^2.5 * e^(-6x)
      density = (x) => Math.pow(x, 2.5) * Math.exp(-6.0 * x);
      break;
    case 'sharp':
      density = (x) => gaussian(x, 0.6, 0.045);
      break;
    case 'wide':
      density = (x) => gaussian(x, 0.4, 0.22);
      break;
    default:
      density = () => 1;
  }
  for (let bin = 0; bin < N; bin++) {
    out[bin] = density(bin / (N - 1)) + 1e-6;
  }
  normalize(out);
}
function normalize(arr) {
  // Rescale the first N entries of `arr` in place so they sum to 1.
  // Leaves the array untouched when that total is not positive.
  let total = 0;
  for (let k = 0; k < N; k += 1) {
    total += arr[k];
  }
  if (!(total > 0)) return;
  for (let k = 0; k < N; k += 1) {
    arr[k] /= total;
  }
}
function buildKernel() {
  // Fill the row-major N x N Gibbs kernel K for the current eps:
  // K_ij = exp(-c_ij / eps), c_ij = (x_i - x_j)^2 with x_k = k/(N-1) on [0,1].
  const spacing = N - 1;
  let idx = 0;
  for (let row = 0; row < N; row++) {
    const xRow = row / spacing;
    for (let col = 0; col < N; col++, idx++) {
      const d = xRow - col / spacing;
      K[idx] = Math.exp(-(d * d) / eps);
    }
  }
}
function resetCouplingToK() {
  // Restart the iteration from the bare kernel: P <- K (all N*N entries)
  // and the sweep counter back to zero.
  P.set(K.subarray(0, N * N));
  iter = 0;
}
function restartSinkhorn() {
  // Start the solver over after mu, nu or eps changed: zero the step/hold
  // timers, rebuild K for the current eps, reset P to K (iter = 0) and
  // refresh the cached entropy of that fresh coupling.
  stepAcc = 0;
  holdAcc = 0;
  buildKernel();
  resetCouplingToK();
  entropy = computeEntropy();
}
function sinkhornStep() {
  // One full Sinkhorn sweep over the row-major N x N coupling P:
  //   1. scale each row i so it sums to mu[i];
  //   2. scale each column j so it sums to nu[j].
  // A row or column whose current sum is not positive is left as is.
  // Increments `iter`.
  for (let row = 0, base = 0; row < N; row++, base += N) {
    let rowSum = 0;
    for (let col = 0; col < N; col++) rowSum += P[base + col];
    if (!(rowSum > 0)) continue;
    const scale = mu[row] / rowSum;
    for (let col = 0; col < N; col++) P[base + col] *= scale;
  }

  const colSum = new Float64Array(N);
  for (let base = 0; base < N * N; base += N) {
    for (let col = 0; col < N; col++) colSum[col] += P[base + col];
  }
  colSum.forEach((sum, col) => {
    if (!(sum > 0)) return;
    const scale = nu[col] / sum;
    // Walk down column `col` with stride N.
    for (let idx = col; idx < N * N; idx += N) P[idx] *= scale;
  });

  iter++;
}
function computeEntropy() {
  // Shannon entropy H(P) = -sum_ij P_ij log P_ij over all N*N entries.
  // Entries not above 1e-30 are skipped: they contribute ~0, and log(0)
  // would be -Infinity.
  const count = N * N;
  let total = 0;
  for (let idx = 0; idx < count; idx++) {
    const p = P[idx];
    if (!(p > 1e-30)) continue;
    total -= p * Math.log(p);
  }
  return total;
}
// viridis approximation: 5-stop piecewise-linear interpolation through the
// colormap's samples at t = 0, 0.25, 0.5, 0.75 and 1.
const VIRIDIS = [
  [68, 1, 84], // 0.0
  [59, 82, 139], // 0.25
  [33, 145, 140], // 0.5
  [94, 201, 98], // 0.75
  [253, 231, 37], // 1.0
];

/**
 * Map a scalar to an approximate viridis colour.
 *
 * @param {number} t - position on the colormap. Values are clamped to [0, 1];
 *   NaN maps to 0 (the darkest stop) instead of throwing.
 * @param {number[]} out - receives [r, g, b] as (possibly fractional)
 *   0..255 channel values.
 */
function viridis(t, out) {
  // `!(u > 0)` also catches NaN, which would otherwise index VIRIDIS[NaN]
  // (undefined) and throw a TypeError.
  let u = t;
  if (!(u > 0)) u = 0;
  else if (u > 1) u = 1;
  const s = u * (VIRIDIS.length - 1);
  // Clamp the segment so u === 1 interpolates the last segment with f = 1.
  const i = Math.min(VIRIDIS.length - 2, Math.floor(s));
  const f = s - i;
  const a = VIRIDIS[i];
  const b = VIRIDIS[i + 1];
  for (let c = 0; c < 3; c++) out[c] = a[c] + (b[c] - a[c]) * f;
}
function layout() {
  // Fit the square heatmap into the W x H canvas, with the mu strip above
  // it and the nu strip to its right, keeping fixed padding (wider on the
  // left/bottom for labels and the HUD). Writes plotX, plotY, plotS,
  // marginTop, marginRight and gap.
  const pad = { top: 18, right: 18, bottom: 36, left: 36 };
  const strip = Math.max(28, Math.min(W, H) * 0.08);
  const stripGap = 6;
  const freeW = W - pad.left - pad.right - strip - stripGap;
  const freeH = H - pad.top - pad.bottom - strip - stripGap;
  plotS = Math.max(40, Math.min(freeW, freeH));
  // Center the square in whichever direction has spare room.
  const slackX = Math.max(0, (freeW - plotS) * 0.5);
  const slackY = Math.max(0, (freeH - plotS) * 0.5);
  plotX = pad.left + slackX;
  plotY = pad.top + strip + stripGap + slackY;
  marginTop = strip;
  marginRight = strip;
  gap = stripGap;
}
function init({ canvas, ctx, width, height }) {
  // One-time setup: canvas size, solver buffers, the initial
  // gauss -> two-modes problem at eps = 0.02, phase/idle state, the N x N
  // offscreen surface the heatmap is rasterized into, and a cleared
  // background. `canvas` is part of the host signature but unused here.
  W = width;
  H = height;

  // Solver buffers: two N x N matrices, then the length-N histograms
  // (current, morph start, morph target, scratch).
  const square = () => new Float64Array(N * N);
  const vec = () => new Float64Array(N);
  K = square();
  P = square();
  [mu, nu, muPrev, nuPrev, muT, nuT, mu2, nu2] = Array.from({ length: 8 }, vec);

  // Initial problem.
  eps = 0.02;
  shapeIdx = 0;
  shape('gauss', mu);
  shape('two-modes', nu);
  muPrev.set(mu);
  nuPrev.set(nu);
  restartSinkhorn();

  // Phase machine starts animating Sinkhorn; the user counts as inactive.
  morphAcc = 0;
  phase = 'step';
  idleAcc = 0;
  userActive = false;

  offC = new OffscreenCanvas(N, N);
  offCtx = offC.getContext('2d');
  img = offCtx.createImageData(N, N);

  layout();
  ctx.fillStyle = '#0a0b10';
  ctx.fillRect(0, 0, W, H);
}
// Ordered "mu|nu" preset pair chosen by the previous pickNextShapes call.
// The next pick avoids it so consecutive auto-cycle morphs always change
// something.
let lastShapePair = null;

/**
 * Choose the next auto-cycle preset pair and render it into muT (mu target)
 * and nuT (nu target).
 *
 * The two presets always differ from each other. The ordered pair also
 * differs from the previously picked one, within a bounded number of
 * redraws. Falls back to 'gauss' / 'asym' if SHAPES has fewer than two
 * entries.
 */
function pickNextShapes() {
  const count = SHAPES.length;
  let muName = 'gauss';
  let nuName = 'asym';
  if (count >= 2) {
    for (let tries = 0; tries < 10; tries++) {
      const a = (Math.random() * count) | 0;
      // An offset in [1, count-1] guarantees b !== a without rejection, and
      // keeps (a, b) uniform over ordered pairs of distinct presets.
      const b = (a + 1 + ((Math.random() * (count - 1)) | 0)) % count;
      muName = SHAPES[a];
      nuName = SHAPES[b];
      if (`${muName}|${nuName}` !== lastShapePair) break;
    }
  }
  lastShapePair = `${muName}|${nuName}`;
  shape(muName, muT);
  shape(nuName, nuT);
}
function paintHeatmap(ctx) {
  // Rasterize P into the N x N ImageData. Row i becomes pixel row i and
  // column j becomes pixel column j. Each entry is scaled by this frame's
  // largest entry (so structure stays visible), given a 0.55 gamma to lift
  // the faint tails around the ridge, and colour-mapped with viridis.
  // The result is then blitted, unsmoothed, over the plot square.
  const cells = N * N;
  let peak = 0;
  for (let idx = 0; idx < cells; idx++) {
    if (P[idx] > peak) peak = P[idx];
  }
  const scale = peak <= 0 ? 1 : peak;

  const px = img.data;
  const rgb = [0, 0, 0];
  for (let idx = 0, o = 0; idx < cells; idx++, o += 4) {
    viridis(Math.pow(P[idx] / scale, 0.55), rgb);
    px[o] = rgb[0] | 0;
    px[o + 1] = rgb[1] | 0;
    px[o + 2] = rgb[2] | 0;
    px[o + 3] = 255;
  }

  offCtx.putImageData(img, 0, 0);
  ctx.imageSmoothingEnabled = false;
  ctx.drawImage(offC, plotX, plotY, plotS, plotS);
}
function drawMuStrip(ctx) {
  // Top strip (above the plot, plotS wide, marginTop tall): mu drawn as a
  // filled area plus outline. It is scaled so the largest bin rises to 4px
  // below the strip's top edge. The background brightens while mu is the
  // active paint target.
  const left = plotX;
  const top = plotY - gap - marginTop;
  const bottom = top + marginTop;
  const painting = userActive && lastPaintTarget === 'mu';
  ctx.fillStyle = painting ? 'rgba(40, 70, 110, 0.55)' : 'rgba(30, 50, 80, 0.35)';
  ctx.fillRect(left, top, plotS, marginTop);

  let peak = 0;
  for (let i = 0; i < N; i++) {
    if (mu[i] > peak) peak = mu[i];
  }
  if (peak <= 0) peak = 1;
  const rise = marginTop - 4;
  const xs = [];
  const ys = [];
  for (let i = 0; i < N; i++) {
    xs.push(left + (i / (N - 1)) * plotS);
    ys.push(bottom - (mu[i] / peak) * rise);
  }

  // Filled area between the profile and the strip's bottom edge.
  ctx.beginPath();
  ctx.moveTo(left, bottom);
  for (let i = 0; i < N; i++) ctx.lineTo(xs[i], ys[i]);
  ctx.lineTo(left + plotS, bottom);
  ctx.closePath();
  ctx.fillStyle = 'rgba(110, 170, 230, 0.28)';
  ctx.fill();

  // Outline of the profile alone.
  ctx.beginPath();
  ctx.moveTo(xs[0], ys[0]);
  for (let i = 1; i < N; i++) ctx.lineTo(xs[i], ys[i]);
  ctx.strokeStyle = 'rgba(140, 195, 240, 0.85)';
  ctx.lineWidth = 1.25;
  ctx.stroke();

  // Corner label.
  ctx.fillStyle = 'rgba(170, 200, 235, 0.75)';
  ctx.font = '11px ui-monospace, monospace';
  ctx.textBaseline = 'top';
  ctx.textAlign = 'left';
  ctx.fillText('mu (paint)', left + 4, top + 2);
}
function drawNuStrip(ctx) {
  // Right strip (beside the plot, marginRight wide, plotS tall): nu drawn as
  // a filled area plus outline, growing rightward from the strip's left
  // edge. It is scaled so the largest bin reaches 4px short of the right
  // edge. The background brightens while nu is the active paint target.
  const left = plotX + plotS + gap;
  const top = plotY;
  const painting = userActive && lastPaintTarget === 'nu';
  ctx.fillStyle = painting ? 'rgba(110, 40, 40, 0.55)' : 'rgba(80, 30, 30, 0.35)';
  ctx.fillRect(left, top, marginRight, plotS);

  let peak = 0;
  for (let j = 0; j < N; j++) peak = nu[j] > peak ? nu[j] : peak;
  const scale = peak <= 0 ? 1 : peak;
  const reach = marginRight - 4;
  const pts = Array.from({ length: N }, (_, j) => [
    left + (nu[j] / scale) * reach,
    top + (j / (N - 1)) * plotS,
  ]);

  // Filled area between the strip's left edge and the profile.
  ctx.beginPath();
  ctx.moveTo(left, top);
  for (const [x, y] of pts) ctx.lineTo(x, y);
  ctx.lineTo(left, top + plotS);
  ctx.closePath();
  ctx.fillStyle = 'rgba(230, 110, 110, 0.26)';
  ctx.fill();

  // Outline of the profile alone.
  ctx.beginPath();
  pts.forEach(([x, y], j) => {
    if (j === 0) ctx.moveTo(x, y);
    else ctx.lineTo(x, y);
  });
  ctx.strokeStyle = 'rgba(240, 150, 150, 0.85)';
  ctx.lineWidth = 1.25;
  ctx.stroke();

  // Corner label.
  ctx.fillStyle = 'rgba(235, 180, 180, 0.75)';
  ctx.font = '11px ui-monospace, monospace';
  ctx.textBaseline = 'top';
  ctx.textAlign = 'left';
  ctx.fillText('nu (paint)', left + 4, top + 2);
}
function drawFrame(ctx) {
  // Thin 1px border around the heatmap square, inset by half a pixel (the
  // usual offset for crisp 1px canvas strokes).
  const inset = 0.5;
  const side = plotS - 2 * inset;
  ctx.strokeStyle = 'rgba(180, 190, 210, 0.35)';
  ctx.lineWidth = 1;
  ctx.strokeRect(plotX + inset, plotY + inset, side, side);
}
function drawHud(ctx) {
  // Status line 22px below the plot: iteration count (zero-padded to two
  // digits), entropy H(P) and eps at fixed x offsets, plus a dimmer,
  // right-aligned hint for the eps drag gesture.
  const x = plotX;
  const y = plotY + plotS + 22;
  const iterLabel = `iter ${String(iter).padStart(2, '0')}`;
  const entropyLabel = `H(P) ${entropy.toFixed(3)}`;
  const epsLabel = `eps ${eps.toFixed(4)}`;

  ctx.fillStyle = 'rgba(220, 225, 240, 0.85)';
  ctx.font = '12px ui-monospace, monospace';
  ctx.textBaseline = 'alphabetic';
  ctx.textAlign = 'left';
  ctx.fillText(iterLabel, x, y);
  ctx.fillText(entropyLabel, x + 80, y);
  ctx.fillText(epsLabel, x + 180, y);

  ctx.fillStyle = 'rgba(150, 160, 185, 0.6)';
  ctx.font = '10px ui-monospace, monospace';
  ctx.textAlign = 'right';
  ctx.fillText('drag Y: eps', x + plotS, y);
}
// ----- Interaction --------------------------------------------------------
// Strip painted during the current press ('mu' | 'nu'); reset to null while
// the button is up. drawMuStrip/drawNuStrip read it to highlight the strip.
let lastPaintTarget = null; // 'mu' | 'nu' | null
// Button state as of the previous handleInput call.
let prevMouseDown = false;
// Consecutive handleInput calls with the button held; reset to 0 on release.
let dragCount = 0;
function addBump(arr, centerBin, amount) {
  // Add amount * exp(-(i - centerBin)^2 / (2 sigma^2)) to each bin i within
  // 3 sigma of centerBin (sigma = PAINT_SIGMA_BINS), clipped to [0, N-1].
  // centerBin may be fractional. The caller is responsible for renormalizing.
  const denom = 2 * PAINT_SIGMA_BINS * PAINT_SIGMA_BINS;
  const reach = 3 * PAINT_SIGMA_BINS;
  const first = Math.max(0, Math.floor(centerBin - reach));
  const last = Math.min(N - 1, Math.ceil(centerBin + reach));
  for (let bin = first; bin <= last; bin++) {
    const offset = bin - centerBin;
    arr[bin] += amount * Math.exp(-(offset * offset) / denom);
  }
}
function muStripRect() {
  // Canvas-space hit box of the mu strip: as wide as the plot, marginTop
  // tall, ending `gap` pixels above the plot's top edge.
  const top = plotY - gap - marginTop;
  return { x0: plotX, y0: top, w: plotS, h: marginTop };
}
function nuStripRect() {
  // Canvas-space hit box of the nu strip: as tall as the plot, marginRight
  // wide, starting `gap` pixels right of the plot's right edge.
  const left = plotX + plotS + gap;
  return { x0: left, y0: plotY, w: marginRight, h: plotS };
}
function pointInRect(px, py, r, pad) {
  // True when (px, py) lies inside rectangle r ({x0, y0, w, h}) grown by
  // `pad` pixels on every side, edges inclusive. A falsy pad counts as 0.
  const grow = pad || 0;
  const insideX = px >= r.x0 - grow && px <= r.x0 + r.w + grow;
  if (!insideX) return false;
  return py >= r.y0 - grow && py <= r.y0 + r.h + grow;
}
// Gesture latched when the button goes down: 'mu' or 'nu' when the press
// starts on that strip (paint), 'eps' anywhere else (scrub). null while the
// button is up.
let dragMode = null;

/**
 * Per-frame input handling: paint into mu/nu, scrub eps, and track idleness.
 *
 * The gesture is decided on the press edge. A press on the mu/nu strip
 * paints that histogram (while the cursor stays over the strip) and never
 * touches eps. A press anywhere else scrubs eps from the cursor's Y
 * position for the whole drag. Any change restarts Sinkhorn from K and
 * switches the animation to the 'user' phase.
 *
 * @param {{mouseX: number, mouseY: number, mouseDown: boolean}} input -
 *   pointer state; assumed to be canvas-space pixels (TODO confirm with host).
 * @param {number} dtMs - frame time in milliseconds, used for the idle timer.
 */
function handleInput(input, dtMs) {
  const mx = input.mouseX;
  const my = input.mouseY;
  const down = !!input.mouseDown;
  const muR = muStripRect();
  const nuR = nuStripRect();

  // Latch the gesture on the press edge, or if a drag is in progress with
  // no mode yet. A small pad lets clicks on a strip's edge register.
  if (!down) {
    dragMode = null;
  } else if (!prevMouseDown || dragMode === null) {
    if (pointInRect(mx, my, muR, 4)) dragMode = 'mu';
    else if (pointInRect(mx, my, nuR, 4)) dragMode = 'nu';
    else dragMode = 'eps';
  }

  // eps scrub: map mouseY across the canvas onto [EPS_MAX, EPS_MIN], so the
  // top of the canvas gives max eps (blurry) and the bottom min eps (sharp).
  // Only drags that did not start on a strip scrub; otherwise a vertical
  // paint stroke on nu would also drag eps across its whole range.
  let epsChanged = false;
  if (dragMode === 'eps' && my >= 0 && my <= H) {
    const tY = Math.min(1, Math.max(0, my / H));
    const newEps = EPS_MAX + (EPS_MIN - EPS_MAX) * tY;
    if (Math.abs(newEps - eps) > 1e-5) {
      eps = newEps;
      epsChanged = true;
    }
  }

  // Painting: add a Gaussian bump under the cursor while it is over the
  // strip the drag started on, then renormalize that histogram.
  let painted = false;
  if (dragMode === 'mu' && pointInRect(mx, my, muR, 4)) {
    const t = (mx - muR.x0) / muR.w;
    addBump(mu, Math.max(0, Math.min(N - 1, t * (N - 1))), PAINT_AMOUNT);
    normalize(mu);
    painted = true;
  } else if (dragMode === 'nu' && pointInRect(mx, my, nuR, 4)) {
    const t = (my - nuR.y0) / nuR.h;
    addBump(nu, Math.max(0, Math.min(N - 1, t * (N - 1))), PAINT_AMOUNT);
    normalize(nu);
    painted = true;
  }
  lastPaintTarget = dragMode === 'mu' || dragMode === 'nu' ? dragMode : null;

  if (down) dragCount++;
  else dragCount = 0;

  if (painted || epsChanged) {
    idleAcc = 0;
    userActive = true;
    phase = 'user';
    // Changing mu, nu or eps invalidates the current iterates, so restart
    // Sinkhorn from the (rebuilt) kernel.
    restartSinkhorn();
  } else {
    // A held button is a gesture in progress and keeps the idle timer at 0,
    // but only an actual change restarts the solver.
    idleAcc = down ? 0 : idleAcc + dtMs;
    userActive = idleAcc < IDLE_MS;
  }

  prevMouseDown = down;
}
// ----- Main loop ----------------------------------------------------------
function tick({ ctx, dt, width, height, input }) {
if (width !== W || height !== H) {
W = width; H = height;
layout();
}
// background
ctx.fillStyle = '#0a0b10';
ctx.fillRect(0, 0, W, H);
const dtMs = dt * 1000;
handleInput(input, dtMs);
// Determine whether to auto-cycle. We only run morph/hold logic when
// the user has been idle long enough to not be in the middle of edits.
const idle = idleAcc >= IDLE_MS;
if (phase === 'user') {
// While the user is recently active, just iterate Sinkhorn on the
// current mu/nu/eps. If they go idle long enough, hand back to the
// auto-cycle.
stepAcc += dtMs;
while (stepAcc >= STEP_MS && iter < 80) {
stepAcc -= STEP_MS;
sinkhornStep();
entropy = computeEntropy();
}
if (idle) {
// Restart the auto-cycle gracefully: snapshot current mu/nu as the
// "previous" of an upcoming morph.
for (let i = 0; i < N; i++) {
muPrev[i] = mu[i];
nuPrev[i] = nu[i];
}
pickNextShapes();
morphAcc = 0;
phase = 'morph';
}
} else if (phase === 'step') {
stepAcc += dtMs;
while (stepAcc >= STEP_MS && iter < 60) {
stepAcc -= STEP_MS;
sinkhornStep();
entropy = computeEntropy();
}
if (iter >= 40) {
phase = 'hold';
holdAcc = 0;
}
} else if (phase === 'hold') {
holdAcc += dtMs;
if (holdAcc >= HOLD_MS) {
// start morph to new shapes
for (let i = 0; i < N; i++) {
muPrev[i] = mu[i];
nuPrev[i] = nu[i];
}
pickNextShapes();
phase = 'morph';
morphAcc = 0;
}
} else if (phase === 'morph') {
morphAcc += dtMs;
const t = Math.min(1, morphAcc / MORPH_MS);
const s = t * t * (3 - 2 * t); // smoothstep
let suMu = 0, suNu = 0;
for (let i = 0; i < N; i++) {
mu[i] = (1 - s) * muPrev[i] + s * muT[i];
nu[i] = (1 - s) * nuPrev[i] + s * nuT[i];
suMu += mu[i];
suNu += nu[i];
}
if (suMu > 0) for (let i = 0; i < N; i++) mu[i] /= suMu;
if (suNu > 0) for (let i = 0; i < N; i++) nu[i] /= suNu;
if (t >= 1) {
restartSinkhorn();
phase = 'step';
}
}
paintHeatmap(ctx);
drawFrame(ctx);
drawMuStrip(ctx);
drawNuStrip(ctx);
drawHud(ctx);
}
Comments (0)
Log in to comment.