23

Sinkhorn Iterations: Entropic OT

paint μ on top, ν on right · drag Y to scrub ε

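Entropy-regularized optimal transport finds the cheapest way to move mass from a source histogram $\mu$ (top, blue) to a target $\nu$ (right, red) under a quadratic ground cost $c_{ij} = (x_i - y_j)^2$, while preferring couplings of higher entropy. Concretely it solves $\min_{P \ge 0,\; P\mathbf{1} = \mu,\; P^{\top}\mathbf{1} = \nu} \langle P, C \rangle - \varepsilon H(P)$ with $H(P) = -\sum_{ij} P_{ij} \log P_{ij}$, whose optimum has the form $P = \operatorname{diag}(u)\, K \operatorname{diag}(v)$ with the Gibbs kernel $K_{ij} = e^{-c_{ij}/\varepsilon}$. The Sinkhorn algorithm initializes $P = K$ and then alternates row- and column-rescalings so the marginals match: $P \leftarrow \operatorname{diag}\!\big(\mu / (P\mathbf{1})\big)\, P$, then $P \leftarrow P \operatorname{diag}\!\big(\nu / (P^{\top}\mathbf{1})\big)$. Equivalently it iterates $u \leftarrow \mu / (K v)$, $v \leftarrow \nu / (K^{\top} u)$. The heatmap shows $(P_{ij} / \max P)^{0.55}$ in viridis at $n = 96$; you can watch the diagonal ridge sharpen as iterations converge to the entropic OT plan, then the marginals morph and the solver restarts. Inspired by Gabriel Peyré's optimal-transport visualizations.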

idle
466 lines · vanilla
view source
// Sinkhorn iterations for entropy-regularized 1D optimal transport.
// mu (top, blue) and nu (right, red) are histograms drawn on the axes of
// the unit square. The square holds the n x n coupling matrix P, started
// as K_ij = exp(-c_ij / eps) with quadratic cost c_ij = ((i - j) / (n - 1))^2,
// then alternately row-normalized to mu and column-normalized to nu.
// Rendered with a viridis heatmap. Inspired by Gabriel Peyre.
//
// Interactive:
//   - Click/drag on the top strip to paint Gaussian bumps into mu.
//   - Click/drag on the right strip to paint into nu.
//   - Hold the button and drag along Y to scrub eps in [0.005, 0.1]
//     (top of the canvas = highest eps, bottom = lowest eps).
//   - If idle for ~6s, auto-cycle through preset shape pairs.

// ----- Tunable constants --------------------------------------------------

const N = 96;                    // bins per histogram; K and P hold N * N cells
const STEP_MS = 150;             // time budget per Sinkhorn step, in ms
const HOLD_MS = 1800;            // length of the 'hold' phase, in ms
const IDLE_MS = 6000;            // resume auto-cycle after this much idle time
const MORPH_MS = 900;            // length of the marginal morph, in ms
const SHAPES = ['gauss', 'two-modes', 'asym', 'sharp', 'wide'];

const EPS_MIN = 0.005;           // eps mapped to the bottom edge of the canvas
const EPS_MAX = 0.1;             // eps mapped to the top edge of the canvas
const PAINT_SIGMA_BINS = 4.5;    // Gaussian bump width in bins
const PAINT_AMOUNT = 0.020;      // mass added per drag-sample (pre-normalize)

// ----- Layout, in canvas coordinates --------------------------------------

let W;              // canvas width
let H;              // canvas height
let plotX;          // left edge of the square plot
let plotY;          // top edge of the square plot
let plotS;          // side length of the square plot
let marginTop;      // thickness of the mu strip above the plot
let marginRight;    // thickness of the nu strip right of the plot
let gap;            // gap between plot and strips

// ----- Solver state -------------------------------------------------------

let eps;            // current entropy regularization
let K;              // Float64Array length N*N: Gibbs kernel exp(-c/eps)
let P;              // Float64Array length N*N: current coupling
let mu;             // Float64Array length N: source histogram
let nu;             // Float64Array length N: target histogram
let mu2;            // scratch, length N
let nu2;            // scratch, length N
let muT;            // morph target for mu (next shape)
let nuT;            // morph target for nu (next shape)

// ----- Animation state ----------------------------------------------------

let iter;           // Sinkhorn steps taken since the last restart
let entropy;        // last value returned by computeEntropy()
let stepAcc;        // ms since last sinkhorn step
let holdAcc;        // ms since converged
let phase;          // 'step' | 'hold' | 'morph' | 'user'
let morphAcc;       // ms during morph
let shapeIdx;
let muPrev;         // mu at the start of the current morph
let nuPrev;         // nu at the start of the current morph
let idleAcc;        // ms since last user interaction
let userActive;     // true while the user counts as recently active

// ----- Heatmap surfaces ---------------------------------------------------

let img;            // ImageData for the heatmap (N x N)
let offC;           // OffscreenCanvas N x N
let offCtx;         // 2D context of offC

// Fill out[0..N-1] with the preset profile `name`, sampled at x = i / (N - 1)
// on [0, 1], lifted by 1e-6 so no bin is exactly zero, then normalized.
// Any name outside the known presets yields a flat profile.
function shape(name, out) {
  // Unnormalized Gaussian bump centred at `center` with std-dev `s`.
  const bell = (x, center, s) => Math.exp(-((x - center) ** 2) / (2 * s * s));

  for (let bin = 0; bin < N; bin++) {
    const x = bin / (N - 1);
    let value;
    switch (name) {
      case 'gauss':
        value = bell(x, 0.5, 0.10);
        break;
      case 'two-modes':
        value = bell(x, 0.25, 0.06) + 0.85 * bell(x, 0.78, 0.06);
        break;
      case 'asym':
        // skewed (gamma-like): x^2.5 * exp(-6 x)
        value = Math.pow(x, 2.5) * Math.exp(-6.0 * x);
        break;
      case 'sharp':
        value = bell(x, 0.6, 0.045);
        break;
      case 'wide':
        value = bell(x, 0.4, 0.22);
        break;
      default:
        value = 1;
    }
    out[bin] = value + 1e-6;
  }
  normalize(out);
}

// Scale the first N entries of arr in place so they sum to 1.
// Leaves arr untouched when that sum is not strictly positive.
function normalize(arr) {
  let total = 0;
  for (let k = 0; k < N; k++) {
    total += arr[k];
  }
  if (!(total > 0)) return;
  for (let k = 0; k < N; k++) {
    arr[k] /= total;
  }
}

// Fill K (row-major, N x N) with the Gibbs kernel for the current eps:
// K[i*N + j] = exp(-((i - j) / (N - 1))^2 / eps), with bins on [0, 1].
function buildKernel() {
  const span = N - 1;
  for (let row = 0; row < N; row++) {
    const xRow = row / span;
    const base = row * N;
    for (let col = 0; col < N; col++) {
      const xCol = col / span;
      const diff = xRow - xCol;
      K[base + col] = Math.exp(-(diff * diff) / eps);
    }
  }
}

// Start the coupling over from the kernel: copy K into P and zero the
// iteration counter.
function resetCouplingToK() {
  iter = 0;
  P.set(K);
}

// Bring the solver back to iteration zero for the current mu, nu and eps:
// clear both timers, rebuild K, reset P to K, and refresh the entropy readout.
function restartSinkhorn() {
  stepAcc = 0;
  holdAcc = 0;

  buildKernel();
  resetCouplingToK();
  // Must run after the reset so the value reflects the fresh P.
  entropy = computeEntropy();
}

// One Sinkhorn sweep on P, in place: rescale every row to its mu mass, then
// every column to its nu mass, and count the iteration. Rows or columns whose
// current sum is not strictly positive are left as they are.
function sinkhornStep() {
  const cells = N * N;

  // Row pass: row r occupies P[r*N .. r*N + N - 1].
  for (let row = 0; row < N; row++) {
    const start = row * N;
    const end = start + N;
    let rowSum = 0;
    for (let k = start; k < end; k++) rowSum += P[k];
    if (!(rowSum > 0)) continue;
    const scale = mu[row] / rowSum;
    for (let k = start; k < end; k++) P[k] *= scale;
  }

  // Column pass: column c is the stride-N slice starting at index c.
  // Scaling one column never changes another column's sum, so each column
  // can be summed and rescaled in turn.
  for (let col = 0; col < N; col++) {
    let colSum = 0;
    for (let k = col; k < cells; k += N) colSum += P[k];
    if (!(colSum > 0)) continue;
    const scale = nu[col] / colSum;
    for (let k = col; k < cells; k += N) P[k] *= scale;
  }

  iter++;
}

// Entropy of the current coupling, H(P) = -sum P_ij log P_ij.
// Cells that are not above 1e-30 contribute nothing.
function computeEntropy() {
  const cells = N * N;
  let total = 0;
  for (let idx = 0; idx < cells; idx++) {
    const mass = P[idx];
    if (!(mass > 1e-30)) continue;
    total -= mass * Math.log(mass);
  }
  return total;
}

// Five evenly spaced viridis samples (t = 0, 0.25, 0.5, 0.75, 1) as RGB
// triples; viridis() interpolates linearly between neighbours.
const VIRIDIS = [
  [68, 1, 84],
  [59, 82, 139],
  [33, 145, 140],
  [94, 201, 98],
  [253, 231, 37],
];

// Write the colour for t into out[0..2]. t is clamped to [0, 1]; the channel
// values are left fractional.
function viridis(t, out) {
  const clamped = t < 0 ? 0 : (t > 1 ? 1 : t);
  const lastStop = VIRIDIS.length - 1;
  const position = clamped * lastStop;
  // Cap the segment index so t = 1 lands at the end of the final segment.
  const segment = Math.min(lastStop - 1, Math.floor(position));
  const frac = position - segment;
  const from = VIRIDIS[segment];
  const to = VIRIDIS[segment + 1];
  for (let channel = 0; channel < 3; channel++) {
    out[channel] = from[channel] + (to[channel] - from[channel]) * frac;
  }
}

// Recompute the plot square and strip sizes from the canvas size W x H.
// The square is centred in the space left after the outer padding, one strip
// and one gap are taken off each axis; mu sits above it and nu to its right.
function layout() {
  const pad = { top: 18, bottom: 36, left: 36, right: 18 };
  const stripGap = 6;
  const strip = Math.max(28, Math.min(W, H) * 0.08);

  const freeW = W - pad.left - pad.right - strip - stripGap;
  const freeH = H - pad.top - pad.bottom - strip - stripGap;

  // Never let the plot shrink below 40 px, even on a tiny canvas.
  const side = Math.max(40, Math.min(freeW, freeH));
  const slackX = Math.max(0, (freeW - side) * 0.5);
  const slackY = Math.max(0, (freeH - side) * 0.5);

  plotS = side;
  plotX = pad.left + slackX;
  plotY = pad.top + strip + stripGap + slackY;
  marginTop = strip;
  marginRight = strip;
  gap = stripGap;
}

// One-time setup: record the canvas size, allocate every buffer, seed the
// marginals with the gauss / two-modes presets, start Sinkhorn from K, create
// the N x N offscreen heatmap surface, lay out the plot, and clear the canvas.
function init({ canvas, ctx, width, height }) {
  W = width;
  H = height;

  const cellCount = N * N;
  const histogram = () => new Float64Array(N);
  K = new Float64Array(cellCount);
  P = new Float64Array(cellCount);
  mu = histogram();
  nu = histogram();
  muPrev = histogram();
  nuPrev = histogram();
  muT = histogram();
  nuT = histogram();
  mu2 = histogram();
  nu2 = histogram();

  eps = 0.02;
  shapeIdx = 0;

  // The "previous" marginals start out as copies of the initial ones.
  shape('gauss', mu);
  shape('two-modes', nu);
  muPrev.set(mu);
  nuPrev.set(nu);

  restartSinkhorn();

  phase = 'step';
  morphAcc = 0;
  idleAcc = 0;
  userActive = false;

  offC = new OffscreenCanvas(N, N);
  offCtx = offC.getContext('2d');
  img = offCtx.createImageData(N, N);

  layout();

  ctx.fillStyle = '#0a0b10';
  ctx.fillRect(0, 0, W, H);
}

// Choose the next morph targets: draw two preset names at random and accept
// the first draw in which they differ from each other, writing them into muT
// and nuT. After ten failed draws, fall back to gauss / asym.
function pickNextShapes() {
  const randomName = () => SHAPES[(Math.random() * SHAPES.length) | 0];

  let attemptsLeft = 10;
  while (attemptsLeft-- > 0) {
    const forMu = randomName();
    const forNu = randomName();
    if (forMu === forNu) continue;
    shape(forMu, muT);
    shape(forNu, nuT);
    return;
  }

  shape('gauss', muT);
  shape('asym', nuT);
}

// Render the coupling P as an N x N viridis heatmap and blit it, unsmoothed,
// onto the plot square of ctx.
//
// Orientation: mu[i] is drawn along the top strip (x axis) and nu[j] along the
// right strip (y axis), so P[i*N + j] must land at pixel column i, pixel row j.
// That way each image column sums to the mu bin above it and each image row
// to the nu bin beside it.
function paintHeatmap(ctx) {
  // Normalize by the per-frame max so the structure is always visible.
  let maxP = 0;
  for (let k = 0; k < N * N; k++) if (P[k] > maxP) maxP = P[k];
  if (maxP <= 0) maxP = 1;

  const data = img.data;
  const rgb = [0, 0, 0];
  for (let i = 0; i < N; i++) {      // mu bin -> pixel column (x)
    for (let j = 0; j < N; j++) {    // nu bin -> pixel row (y)
      const v = P[i * N + j] / maxP;
      // Mild gamma for contrast at the diagonal ridge.
      const t = Math.pow(v, 0.55);
      viridis(t, rgb);
      // ImageData is row-major RGBA: pixel (x, y) starts at (y * N + x) * 4.
      const k = (j * N + i) * 4;
      data[k] = rgb[0] | 0;
      data[k + 1] = rgb[1] | 0;
      data[k + 2] = rgb[2] | 0;
      data[k + 3] = 255;
    }
  }
  offCtx.putImageData(img, 0, 0);
  ctx.imageSmoothingEnabled = false;
  ctx.drawImage(offC, plotX, plotY, plotS, plotS);
}

// Draw mu in the strip above the plot: a background (brighter while the user
// is painting mu), the histogram as a filled area scaled to its own peak, a
// stroked outline over it, and a small text label.
function drawMuStrip(ctx) {
  const left = plotX;
  const top = plotY - gap - marginTop;
  const bottom = top + marginTop;

  const painting = userActive && lastPaintTarget === 'mu';
  ctx.fillStyle = painting
    ? 'rgba(40, 70, 110, 0.55)'
    : 'rgba(30, 50, 80, 0.35)';
  ctx.fillRect(left, top, plotS, marginTop);

  let peak = 0;
  for (let i = 0; i < N; i++) {
    if (mu[i] > peak) peak = mu[i];
  }
  if (peak <= 0) peak = 1;

  // Curve vertices, computed once and shared by the fill and the outline.
  // The tallest bin stops 4 px short of the strip's top edge.
  const xs = new Float64Array(N);
  const ys = new Float64Array(N);
  for (let i = 0; i < N; i++) {
    xs[i] = left + (i / (N - 1)) * plotS;
    ys[i] = bottom - (mu[i] / peak) * (marginTop - 4);
  }

  // filled area, closed along the strip's bottom edge
  ctx.beginPath();
  ctx.moveTo(left, bottom);
  for (let i = 0; i < N; i++) ctx.lineTo(xs[i], ys[i]);
  ctx.lineTo(left + plotS, bottom);
  ctx.closePath();
  ctx.fillStyle = 'rgba(110, 170, 230, 0.28)';
  ctx.fill();

  // outline
  ctx.beginPath();
  for (let i = 0; i < N; i++) {
    if (i === 0) {
      ctx.moveTo(xs[i], ys[i]);
    } else {
      ctx.lineTo(xs[i], ys[i]);
    }
  }
  ctx.strokeStyle = 'rgba(140, 195, 240, 0.85)';
  ctx.lineWidth = 1.25;
  ctx.stroke();

  // small label
  ctx.fillStyle = 'rgba(170, 200, 235, 0.75)';
  ctx.font = '11px ui-monospace, monospace';
  ctx.textBaseline = 'top';
  ctx.textAlign = 'left';
  ctx.fillText('mu  (paint)', left + 4, top + 2);
}

// Draw nu in the strip right of the plot: a background (brighter while the
// user is painting nu), the histogram as a filled area growing rightward from
// the strip's left edge and scaled to its own peak, a stroked outline, and a
// small text label.
function drawNuStrip(ctx) {
  const left = plotX + plotS + gap;
  const top = plotY;

  const painting = userActive && lastPaintTarget === 'nu';
  ctx.fillStyle = painting
    ? 'rgba(110, 40, 40, 0.55)'
    : 'rgba(80, 30, 30, 0.35)';
  ctx.fillRect(left, top, marginRight, plotS);

  let peak = 0;
  for (let j = 0; j < N; j++) {
    if (nu[j] > peak) peak = nu[j];
  }
  if (peak <= 0) peak = 1;

  // Curve vertices, computed once and shared by the fill and the outline.
  // The widest bin stops 4 px short of the strip's right edge.
  const xs = new Float64Array(N);
  const ys = new Float64Array(N);
  for (let j = 0; j < N; j++) {
    ys[j] = top + (j / (N - 1)) * plotS;
    xs[j] = left + (nu[j] / peak) * (marginRight - 4);
  }

  // filled area, closed along the strip's left edge
  ctx.beginPath();
  ctx.moveTo(left, top);
  for (let j = 0; j < N; j++) ctx.lineTo(xs[j], ys[j]);
  ctx.lineTo(left, top + plotS);
  ctx.closePath();
  ctx.fillStyle = 'rgba(230, 110, 110, 0.26)';
  ctx.fill();

  // outline
  ctx.beginPath();
  for (let j = 0; j < N; j++) {
    if (j === 0) {
      ctx.moveTo(xs[j], ys[j]);
    } else {
      ctx.lineTo(xs[j], ys[j]);
    }
  }
  ctx.strokeStyle = 'rgba(240, 150, 150, 0.85)';
  ctx.lineWidth = 1.25;
  ctx.stroke();

  // small label
  ctx.fillStyle = 'rgba(235, 180, 180, 0.75)';
  ctx.font = '11px ui-monospace, monospace';
  ctx.textBaseline = 'top';
  ctx.textAlign = 'left';
  ctx.fillText('nu  (paint)', left + 4, top + 2);
}

// Stroke a thin border around the plot square, inset by half a pixel on the
// top-left corner and one pixel shorter on each side.
function drawFrame(ctx) {
  Object.assign(ctx, {
    strokeStyle: 'rgba(180, 190, 210, 0.35)',
    lineWidth: 1,
  });
  const side = plotS - 1;
  ctx.strokeRect(plotX + 0.5, plotY + 0.5, side, side);
}

// Text readout on one baseline 22 px below the plot: iteration count, entropy
// and eps from the plot's left edge, plus a right-aligned hint for the eps
// drag gesture at the plot's right edge.
function drawHud(ctx) {
  const baseline = plotY + plotS + 22;

  Object.assign(ctx, {
    fillStyle: 'rgba(220, 225, 240, 0.85)',
    font: '12px ui-monospace, monospace',
    textBaseline: 'alphabetic',
    textAlign: 'left',
  });
  // [text, x offset from the plot's left edge]
  const readouts = [
    [`iter ${iter.toString().padStart(2, '0')}`, 0],
    [`H(P) ${entropy.toFixed(3)}`, 80],
    [`eps ${eps.toFixed(4)}`, 180],
  ];
  for (const [text, dx] of readouts) {
    ctx.fillText(text, plotX + dx, baseline);
  }

  // tiny eps slider hint on the right side of the plot
  Object.assign(ctx, {
    fillStyle: 'rgba(150, 160, 185, 0.6)',
    font: '10px ui-monospace, monospace',
    textAlign: 'right',
  });
  ctx.fillText('drag Y: eps', plotX + plotS, baseline);
}

// ----- Interaction --------------------------------------------------------

// Strip most recently painted during the current press. handleInput sets it
// on every paint and clears it to null whenever the button is up; the strip
// renderers read it to choose the highlighted background.
let lastPaintTarget = null;  // 'mu' | 'nu' | null
// Button state seen by the latest handleInput call. Written there only;
// nothing in the visible code reads it.
let prevMouseDown = false;
// Number of consecutive handleInput calls with the button held; reset to 0 on
// release. Written there only; nothing in the visible code reads it.
let dragCount = 0;           // bookkeeping for eps scrub during a drag

// Add a Gaussian bump of height `amount` to arr, centred on the (possibly
// fractional) bin `centerBin`. Only bins within three sigma of the centre,
// clipped to [0, N - 1], are touched. Does not renormalize.
function addBump(arr, centerBin, amount) {
  const reach = 3 * PAINT_SIGMA_BINS;
  const denom = 2 * PAINT_SIGMA_BINS * PAINT_SIGMA_BINS;
  const first = Math.max(0, Math.floor(centerBin - reach));
  const last = Math.min(N - 1, Math.ceil(centerBin + reach));
  for (let bin = first; bin <= last; bin++) {
    const offset = bin - centerBin;
    arr[bin] += amount * Math.exp(-(offset * offset) / denom);
  }
}

// Rectangle of the mu strip in canvas coordinates: as wide as the plot,
// marginTop tall, sitting one gap above the plot's top edge.
function muStripRect() {
  const x0 = plotX;
  const y0 = plotY - gap - marginTop;
  const w = plotS;
  const h = marginTop;
  return { x0, y0, w, h };
}

// Rectangle of the nu strip in canvas coordinates: as tall as the plot,
// marginRight wide, sitting one gap right of the plot's right edge.
function nuStripRect() {
  const x0 = plotX + plotS + gap;
  const y0 = plotY;
  const w = marginRight;
  const h = plotS;
  return { x0, y0, w, h };
}

// True when (px, py) lies inside rectangle r = { x0, y0, w, h } grown by
// `pad` on every side, edges included. A missing or falsy pad counts as 0.
function pointInRect(px, py, r, pad) {
  const margin = pad || 0;
  const insideX = px >= r.x0 - margin && px <= r.x0 + r.w + margin;
  const insideY = py >= r.y0 - margin && py <= r.y0 + r.h + margin;
  return insideX && insideY;
}

// Apply one frame of pointer input.
//
// While the button is held:
//   - mouseY inside [0, H] sets eps by linear interpolation from EPS_MAX at
//     the top of the canvas to EPS_MIN at the bottom (changes of 1e-5 or less
//     are ignored);
//   - a cursor over the mu or nu strip (with 4 px of slack) paints a Gaussian
//     bump into that histogram and renormalizes it.
// Either effect counts as interaction: it clears the idle timer, switches the
// phase to 'user' and restarts Sinkhorn from K. Without interaction the idle
// timer advances by dtMs, and userActive stays true until it reaches IDLE_MS.
function handleInput(input, dtMs) {
  const { mouseX: mx, mouseY: my } = input;
  const down = Boolean(input.mouseDown);

  let epsChanged = false;
  let painted = false;

  // Paint one bump at fractional position `frac` along a strip.
  const paint = (hist, frac, label) => {
    const bin = Math.max(0, Math.min(N - 1, frac * (N - 1)));
    addBump(hist, bin, PAINT_AMOUNT);
    normalize(hist);
    lastPaintTarget = label;
    painted = true;
  };

  if (down) {
    // eps scrub: only while pressed, and only for a cursor inside the canvas.
    if (my >= 0 && my <= H) {
      const tY = Math.min(1, Math.max(0, my / H));
      const scrubbed = EPS_MAX + (EPS_MIN - EPS_MAX) * tY;
      if (Math.abs(scrubbed - eps) > 1e-5) {
        eps = scrubbed;
        epsChanged = true;
      }
    }

    // Strip hit-test; the small pad lets edge clicks still register.
    const muR = muStripRect();
    const nuR = nuStripRect();
    if (pointInRect(mx, my, muR, 4)) {
      paint(mu, (mx - muR.x0) / muR.w, 'mu');
    } else if (pointInRect(mx, my, nuR, 4)) {
      paint(nu, (my - nuR.y0) / nuR.h, 'nu');
    }

    dragCount++;
  } else {
    lastPaintTarget = null;
    dragCount = 0;
  }

  if (painted || epsChanged) {
    idleAcc = 0;
    userActive = true;
    phase = 'user';
    // The marginals or the kernel just changed, so the old iterates no
    // longer belong to this problem: start again from K.
    restartSinkhorn();
  } else {
    idleAcc += dtMs;
    userActive = idleAcc < IDLE_MS;
  }

  prevMouseDown = down;
}

// ----- Main loop ----------------------------------------------------------

// Per-frame entry point: follow canvas resizes, clear the background, apply
// input, advance the phase machine by dt seconds, then redraw everything.
//
// Phases:
//   'user'  - iterate Sinkhorn (up to 80 steps) on what the user set up; once
//             the idle timer reaches IDLE_MS, begin a morph to a new pair.
//   'step'  - iterate Sinkhorn; move to 'hold' once 40 steps are done.
//   'hold'  - wait HOLD_MS, then begin a morph to a new pair.
//   'morph' - blend mu/nu from their snapshots to the targets over MORPH_MS
//             with a smoothstep, then restart Sinkhorn and return to 'step'.
function tick({ ctx, dt, width, height, input }) {
  const resized = width !== W || height !== H;
  if (resized) {
    W = width;
    H = height;
    layout();
  }

  // background
  ctx.fillStyle = '#0a0b10';
  ctx.fillRect(0, 0, W, H);

  const dtMs = dt * 1000;
  handleInput(input, dtMs);

  // Sampled after input handling so this frame's interaction already counts.
  const idle = idleAcc >= IDLE_MS;

  // Spend accumulated time on Sinkhorn steps, never going past maxIter.
  const runSteps = (maxIter) => {
    stepAcc += dtMs;
    while (stepAcc >= STEP_MS && iter < maxIter) {
      stepAcc -= STEP_MS;
      sinkhornStep();
      entropy = computeEntropy();
    }
  };

  // Snapshot the current marginals as the morph origin, pick targets, and
  // enter the 'morph' phase.
  const beginMorph = () => {
    muPrev.set(mu);
    nuPrev.set(nu);
    pickNextShapes();
    morphAcc = 0;
    phase = 'morph';
  };

  switch (phase) {
    case 'user':
      runSteps(80);
      if (idle) beginMorph();
      break;

    case 'step':
      runSteps(60);
      if (iter >= 40) {
        phase = 'hold';
        holdAcc = 0;
      }
      break;

    case 'hold':
      holdAcc += dtMs;
      if (holdAcc >= HOLD_MS) beginMorph();
      break;

    case 'morph': {
      morphAcc += dtMs;
      const t = Math.min(1, morphAcc / MORPH_MS);
      const toward = t * t * (3 - 2 * t);  // smoothstep
      const away = 1 - toward;
      for (let i = 0; i < N; i++) {
        mu[i] = away * muPrev[i] + toward * muT[i];
        nu[i] = away * nuPrev[i] + toward * nuT[i];
      }
      normalize(mu);
      normalize(nu);

      if (t >= 1) {
        restartSinkhorn();
        phase = 'step';
      }
      break;
    }

    default:
      break;
  }

  paintHeatmap(ctx);
  drawFrame(ctx);
  drawMuStrip(ctx);
  drawNuStrip(ctx);
  drawHud(ctx);
}

Comments (0)

Log in to comment.