15

SVD: Best Rank-k Approximation

drag X to scrub k · click to swap image

Eckart-Young: the closest rank-$k$ matrix to $A$ in Frobenius norm is $A_k = \sum_{i=1}^{k} \sigma_i u_i v_i^\top$, where $A = U \Sigma V^\top$ is the singular value decomposition with $\sigma_1 \ge \sigma_2 \ge \dots \ge \sigma_n \ge 0$. Equivalently, truncating the SVD at $k$ terms minimizes $\|A - B\|_F$ over all $B$ of rank at most $k$, and the residual energy is exactly $\|A - A_k\|_F^2 = \sum_{i>k} \sigma_i^2$. Here $A$ is a $64 \times 64$ grayscale image. Left: the current reconstruction $A_k$. Right: the singular spectrum $\sigma_i$, with the first $k$ bars filled and the rest dimmed. As $k$ ramps up, watch the image sharpen and the spectrum fill in; the tail you leave behind is exactly what you've thrown away. Inspired by Gabriel Peyré's linear-algebra-as-image visualizations.

idle
475 lines · vanilla
view source
// SVD as best rank-k approximation of an image.
// Eckart-Young: A_k = sum_{i=1..k} sigma_i u_i v_i^T minimizes ||A - B||_F
// over all matrices B of rank at most k. Moving the pointer horizontally over the
// canvas scrubs k (tick() maps mouseX to k; no button state is checked here);
// click to swap the source image. After IDLE_SEC (~3s) without pointer motion the
// auto-sweep takes over. Inspired by Gabriel Peyre.

const N = 64;             // image is N x N
const KMAX = 48;          // auto-sweep ramps k from 1 up to this value (<= N)
const RAMP_SEC = 10;      // seconds for k: 1 -> KMAX
const HOLD_SEC = 2;       // seconds to hold at KMAX before snapping back to k = 1
const IDLE_SEC = 3;       // seconds of no mouse motion before auto-sweep resumes

// Active image and its SVD; activateImage() swaps all four together.
// U and V are column-major: column j occupies indices [j*N, j*N + N).
let A;                    // Float64Array length N*N, original image (row-major) for the active image
let U;                    // N x N, columns are left singular vectors (active)
let V;                    // N x N, columns are right singular vectors (active)
let S;                    // singular values, length N, sorted desc (active)
let imgBuf;               // OffscreenCanvas (N x N) holding the rendered reconstruction
let imgCtx;               // 2D context of imgBuf
let imgData;              // ImageData backing imgBuf
let recon;                // Float64Array length N*N, current A_k (row-major)

let W = 0, H = 0;         // canvas width/height as last passed to init()/tick()
let tElapsed = 0;         // seconds accumulated in auto-sweep mode (not advanced while scrubbing)
let currentK = 1;         // k displayed this frame
let lastK = -1;           // k of the cached reconstruction; -1 forces a rebuild
let lastErr = 0;          // ||A - A_lastK||_F, cached alongside lastK

// Interaction state
let imageIdx = 0;             // which procedural image is active (index into IMAGE_BUILDERS)
const NUM_IMAGES = 4;         // number of procedural images; keep equal to IMAGE_BUILDERS.length
                              // (IMAGE_BUILDERS is a later const and cannot be referenced here)
const svdCache = [];          // per-image cached { A, S, U, V }, filled lazily by ensureSVDFor()
let lastMouseX = -1;          // pointer position seen on the previous tick
let lastMouseY = -1;
let idleTimer = 0;            // seconds since last mouse motion
let scrubActive = false;      // true while user is driving k via mouseX
let lastMouseInside = false;  // whether the pointer was inside the canvas on the previous tick

// ---------- procedural images ----------
// Each builder returns Float64Array(N*N) in [0,1], row-major.

function buildImageM() {
  // Stylized capital "M" + soft diagonal stripe + radial vignette.
  // Returns a fresh Float64Array(N*N) of intensities in [0,1], row-major (y*N + x).
  const img = new Float64Array(N * N);
  const mid = N / 2;

  // Glyph bounding box in pixels; inside it (u, w) are normalized to [0,1].
  const left = 12, right = N - 12, top = 12, bot = N - 12;

  // Stroke test in normalized glyph coordinates: two vertical stems at the
  // edges plus two diagonals that meet at (u = 0.5, w = 0.7).
  const onStroke = (u, w) => {
    const thick = 0.10;
    if (u < 0.18 || u > 0.82) return true;
    const tl = (u - 0.18) / (0.50 - 0.18);
    if (tl >= 0 && tl <= 1 && Math.abs(w - tl * 0.70) < thick) return true;
    const tr = (u - 0.50) / (0.82 - 0.50);
    return tr >= 0 && tr <= 1 && Math.abs(w - (0.70 - tr * 0.70)) < thick;
  };

  for (let y = 0; y < N; y++) {
    const dy = (y - mid) / N;
    const rowInBox = y >= top && y <= bot;
    const w = (y - top) / (bot - top);
    for (let x = 0; x < N; x++) {
      const dx = (x - mid) / N;
      // Diagonal sine stripe minus a radial vignette.
      let v = 0.18 + 0.10 * Math.sin((x + y) * 0.55) - 0.10 * (dx * dx + dy * dy);
      if (rowInBox && x >= left && x <= right && onStroke((x - left) / (right - left), w)) {
        v = 0.92;
      }
      img[y * N + x] = v < 0 ? 0 : (v > 1 ? 1 : v);
    }
  }
  return img;
}

function buildImageCircleOnStripes() {
  // Horizontal stripes (rich low-rank structure) + a filled disk on top.
  // Off the disk the value is a row term (sine in y) plus a column term
  // (cosine in x), i.e. a separable background of rank <= 2.
  // Returns a fresh Float64Array(N*N) in [0,1], row-major.
  const img = new Float64Array(N * N);
  const cx = N / 2 - 0.5;
  const cy = N / 2 - 0.5;
  const r2 = (N * 0.30) * (N * 0.30);

  const colTerm = Array.from({ length: N }, (_, x) => 0.06 * Math.cos(x * 0.18));

  for (let y = 0; y < N; y++) {
    // periodic stripe in y: one sigma per harmonic
    const stripe = 0.32 + 0.28 * Math.sin(y * 0.55);
    const dy = y - cy;
    for (let x = 0; x < N; x++) {
      const dx = x - cx;
      const v = dx * dx + dy * dy <= r2 ? 0.93 : stripe + colTerm[x];
      img[y * N + x] = v < 0 ? 0 : (v > 1 ? 1 : v);
    }
  }
  return img;
}

function buildImageOmega() {
  // Stylized capital Omega: a ring open at the bottom with a short horizontal
  // foot at each end of the opening, over a soft vertical gradient plus a faint
  // diagonal ripple. Returns a fresh Float64Array(N*N) in [0,1], row-major.
  const img = new Float64Array(N * N);
  const cx = N / 2 - 0.5;
  const cy = N / 2 + 2;

  // Ring: radius R, half-thickness T, with a gap of +/- openHalf radians around
  // straight down (y grows downward, so "down" is atan2 angle +pi/2).
  const R = N * 0.30;
  const T = 3.2;
  const openHalf = 0.45;

  // Feet: (2*footHalfW) x (2*footHalfH) bars centred on the two ends of the gap.
  const footY = cy + R * Math.cos(openHalf);
  const footLX = cx - R * Math.sin(openHalf);
  const footRX = cx + R * Math.sin(openHalf);
  const footHalfW = 5;
  const footHalfH = 2;

  const onRing = (dx, dy) => {
    if (!(Math.abs(Math.sqrt(dx * dx + dy * dy) - R) < T)) return false;
    const fromDown = Math.abs(Math.atan2(dy, dx) - Math.PI / 2);
    return Math.min(fromDown, Math.abs(fromDown - 2 * Math.PI)) > openHalf;
  };

  for (let y = 0; y < N; y++) {
    const base = 0.20 + 0.18 * (y / N);
    const inFootRow = Math.abs(y - footY) < footHalfH;
    for (let x = 0; x < N; x++) {
      const onFoot =
        inFootRow && (Math.abs(x - footLX) < footHalfW || Math.abs(x - footRX) < footHalfW);
      const v = (onRing(x - cx, y - cy) || onFoot)
        ? 0.94
        : base + 0.05 * Math.sin(x * 0.30 + y * 0.20);
      img[y * N + x] = v < 0 ? 0 : (v > 1 ? 1 : v);
    }
  }
  return img;
}

function buildImageCheckerBite() {
  // 8x8 checkerboard with a circular "bite" taken out of the upper right.
  // Inside the bite the value ramps from 0.45 (centre) to 0.55 (rim), which
  // gives the spectrum a tail. Returns a fresh Float64Array(N*N) in [0,1].
  const img = new Float64Array(N * N);
  const cellsX = 8;
  const cellsY = 8;
  const biteX = N * 0.72;
  const biteY = N * 0.30;
  const biteR = N * 0.18;

  for (let y = 0; y < N; y++) {
    const cellRow = Math.floor((y / N) * cellsY);
    const dy = y - biteY;
    for (let x = 0; x < N; x++) {
      const cellCol = Math.floor((x / N) * cellsX);
      const dx = x - biteX;
      const dist2 = dx * dx + dy * dy;
      let v;
      if (dist2 <= biteR * biteR) {
        v = 0.45 + 0.10 * (Math.sqrt(dist2) / biteR);
      } else {
        v = (cellCol + cellRow) & 1 ? 0.82 : 0.18;
      }
      img[y * N + x] = v < 0 ? 0 : (v > 1 ? 1 : v);
    }
  }
  return img;
}

// Procedural image factories, selected by imageIdx. Each returns a fresh
// Float64Array(N*N) of intensities in [0,1].
const IMAGE_BUILDERS = [
  buildImageM,               // 0: capital "M" over a diagonal stripe
  buildImageCircleOnStripes, // 1: disk over horizontal stripes
  buildImageOmega,           // 2: capital Omega over a gradient
  buildImageCheckerBite,     // 3: checkerboard with a circular bite
];

// ---------- linear algebra ----------
// Cyclic Jacobi eigendecomposition of an n x n symmetric matrix stored as a
// row-major Float64Array. Each step applies a plane rotation that zeroes one
// off-diagonal pair (p, q); full sweeps over all pairs repeat until a sweep finds
// only negligible off-diagonal entries or MAX_SWEEPS is reached.
//
// Min - symmetric matrix, length n*n; not modified (a Float64Array copy is rotated).
// n   - matrix dimension.
// Returns { eigs: Float64Array(n), vecs: Float64Array(n*n) } where vecs is column-major:
// column j (indices j*n .. j*n+n-1) is the eigenvector for eigs[j].
// Eigenpairs are returned in diagonal order, NOT sorted.
function jacobiSym(Min, n) {
  // Working copy
  const M = new Float64Array(Min);
  // Eigenvector matrix Q starts as identity, column-major.
  const Q = new Float64Array(n * n);
  for (let i = 0; i < n; i++) Q[i * n + i] = 1;

  const MAX_SWEEPS = 60;
  const EPS = 1e-12;

  for (let sw = 0; sw < MAX_SWEEPS; sw++) {
    // Sum of squares of the off-diagonal entries rotated away during this sweep;
    // used as the convergence test at the end of the sweep.
    let off = 0;
    for (let p = 0; p < n - 1; p++) {
      for (let q = p + 1; q < n; q++) {
        const apq = M[p * n + q];
        // Already (numerically) zero: no rotation needed for this pair.
        if (Math.abs(apq) < 1e-14) continue;
        off += apq * apq;
        const app = M[p * n + p];
        const aqq = M[q * n + q];
        // theta = cot(2*phi). t = tan(phi) is the smaller-magnitude root of
        // t^2 + 2*theta*t - 1 = 0, so |phi| <= pi/4 (the numerically stable choice).
        const theta = (aqq - app) / (2 * apq);
        const t = (theta >= 0 ? 1 : -1) / (Math.abs(theta) + Math.sqrt(1 + theta * theta));
        const c = 1 / Math.sqrt(1 + t * t);
        const s = t * c;

        // Closed-form update of the rotated 2x2 diagonal block; (p,q) becomes 0.
        M[p * n + p] = app - t * apq;
        M[q * n + q] = aqq + t * apq;
        M[p * n + q] = 0;
        M[q * n + p] = 0;

        // Rotate columns p and q of every other row r, writing both triangles
        // so M stays exactly symmetric.
        for (let r = 0; r < n; r++) {
          if (r !== p && r !== q) {
            const arp = M[r * n + p];
            const arq = M[r * n + q];
            const nrp = c * arp - s * arq;
            const nrq = s * arp + c * arq;
            M[r * n + p] = nrp;
            M[p * n + r] = nrp;
            M[r * n + q] = nrq;
            M[q * n + r] = nrq;
          }
        }
        // Update eigenvector matrix: Q := Q * G(p,q,c,s)
        // Q is column-major: column j starts at j*n.
        const colP = p * n, colQ = q * n;
        for (let r = 0; r < n; r++) {
          const qrp = Q[colP + r];
          const qrq = Q[colQ + r];
          Q[colP + r] = c * qrp - s * qrq;
          Q[colQ + r] = s * qrp + c * qrq;
        }
      }
    }
    // Converged: this sweep only had tiny off-diagonal entries to remove.
    // (If MAX_SWEEPS runs out first, the current diagonal is returned as-is.)
    if (off < EPS) break;
  }

  // The eigenvalues are the remaining diagonal of the rotated matrix.
  const eigs = new Float64Array(n);
  for (let i = 0; i < n; i++) eigs[i] = M[i * n + i];
  return { eigs, vecs: Q };
}

// Gram matrix A^T A of an n x n row-major matrix: entry (i, j) is the dot
// product of columns i and j. Only the upper triangle is accumulated and then
// mirrored, so the result is exactly symmetric. Returns a row-major Float64Array(n*n).
function gramian(Aimg, n) {
  const gram = new Float64Array(n * n);
  for (let colA = 0; colA < n; colA++) {
    for (let colB = colA; colB < n; colB++) {
      let dot = 0;
      for (let row = 0; row < n; row++) {
        const base = row * n;
        dot += Aimg[base + colA] * Aimg[base + colB];
      }
      gram[colA * n + colB] = dot;
      gram[colB * n + colA] = dot;
    }
  }
  return gram;
}

function computeSVD(Aimg) {
  // SVD of the square N x N image via the symmetric eigenproblem
  //   Aimg^T Aimg = V diag(sigma^2) V^T,   u_j = Aimg v_j / sigma_j.
  // Returns { S, U, V }: S (length N) sorted descending; U and V are N x N and
  // column-major (column j at indices j*N .. j*N+N-1). A U column is left zero
  // unless its sigma exceeds 1e-10. Forming Aimg^T Aimg squares the condition
  // number, so the smallest singular values are resolved only approximately.
  const { eigs, vecs } = jacobiSym(gramian(Aimg, N), N);

  // Eigen-indices ordered by descending eigenvalue.
  const order = Array.from({ length: N }, (_, i) => i);
  order.sort((a, b) => eigs[b] - eigs[a]);

  const sigmas = new Float64Array(N);
  const leftVecs = new Float64Array(N * N);
  const rightVecs = new Float64Array(N * N);

  order.forEach((src, j) => {
    // Round-off can push a zero eigenvalue slightly negative; clamp before sqrt.
    const sigma = Math.sqrt(Math.max(0, eigs[src]));
    sigmas[j] = sigma;

    const dst = j * N;
    rightVecs.set(vecs.subarray(src * N, src * N + N), dst);

    // Numerically null direction: its U column stays zero (and is never
    // reached by reconstruction at these magnitudes).
    if (!(sigma > 1e-10)) return;
    for (let row = 0; row < N; row++) {
      let dot = 0;
      for (let col = 0; col < N; col++) dot += Aimg[row * N + col] * rightVecs[dst + col];
      leftVecs[dst + row] = dot / sigma;
    }
  });

  return { S: sigmas, U: leftVecs, V: rightVecs };
}

function ensureSVDFor(i) {
  // Memoized: build image i and decompose it the first time it is requested,
  // then return the same { A, S, U, V } record from svdCache on later calls.
  const cached = svdCache[i];
  if (cached) return cached;

  const image = IMAGE_BUILDERS[i]();
  const decomposition = computeSVD(image);
  const entry = { A: image, S: decomposition.S, U: decomposition.U, V: decomposition.V };
  svdCache[i] = entry;
  return entry;
}

function activateImage(i) {
  // Point the module-level A/S/U/V at image i's (lazily computed) data and
  // invalidate the cached reconstruction so the next tick rebuilds it.
  ({ A, S, U, V } = ensureSVDFor(i));
  lastK = -1;
}

// Rebuild `recon` (N x N, row-major) as the rank-k truncation
//   A_k = sum_{i=0..k-1} S[i] * u_i * v_i^T,
// where u_i / v_i are column i of the column-major U / V. Terms with
// S[i] <= 1e-12 and zero entries of u_i are skipped (negligible or no contribution).
function reconstruct(k) {
  recon.fill(0);
  for (let term = 0; term < k; term++) {
    const sigma = S[term];
    if (sigma <= 1e-12) continue;
    const base = term * N; // start of column `term` in both U and V
    for (let row = 0; row < N; row++) {
      const u = U[base + row];
      if (u === 0) continue;
      const scale = sigma * u;
      const rowStart = row * N;
      for (let col = 0; col < N; col++) {
        recon[rowStart + col] += scale * V[base + col];
      }
    }
  }
}

function frobeniusError(k) {
  // Frobenius norm of the truncation residual A - A_k. By Eckart-Young,
  // ||A - A_k||_F^2 equals the sum of the discarded squared singular values,
  // i.e. S[i]^2 for 0-based i >= k, so no reconstruction is needed.
  let tailEnergy = 0;
  let i = k;
  while (i < N) {
    tailEnergy += S[i] * S[i];
    i += 1;
  }
  return Math.sqrt(tailEnergy);
}

// ---------- rendering ----------
function writeReconToImageData() {
  // Quantize recon (clamped to [0,1]) to opaque 8-bit grayscale RGBA in
  // imgData, then upload it to the offscreen buffer at (0, 0).
  const px = imgData.data;
  const count = N * N;
  for (let i = 0; i < count; i++) {
    const raw = recon[i];
    const level = Math.round((raw < 0 ? 0 : raw > 1 ? 1 : raw) * 255);
    const o = 4 * i;
    px.fill(level, o, o + 3); // R, G, B
    px[o + 3] = 255;          // alpha
  }
  imgCtx.putImageData(imgData, 0, 0);
}

// Draw the N x N reconstruction (imgBuf) as a centered square inside the panel
// rectangle (x, y, w, h), upscaled with smoothing disabled and framed by a 1px
// hairline. The square leaves a 9px margin on the panel's shorter side.
function drawImagePanel(ctx, x, y, w, h) {
  // Dark backdrop for the whole panel.
  ctx.fillStyle = '#0c0d10';
  ctx.fillRect(x, y, w, h);

  // Snap the square to whole pixels so the upscale starts on a pixel boundary
  // and the +0.5 hairline below actually lands on pixel centres.
  const side = Math.floor(Math.min(w, h) - 18);
  // Panel too small (e.g. a very short canvas): don't pass a negative size to drawImage.
  if (!(side > 0)) return;
  const ox = Math.floor(x + (w - side) / 2);
  const oy = Math.floor(y + (h - side) / 2);

  ctx.imageSmoothingEnabled = false; // keep the N x N pixels hard-edged when scaled up
  ctx.drawImage(imgBuf, ox, oy, side, side);

  // Hairline border just inside the image square.
  ctx.strokeStyle = '#1c1e24';
  ctx.lineWidth = 1;
  ctx.strokeRect(ox + 0.5, oy + 0.5, side - 1, side - 1);
}

// Bar chart of the singular values S[0..N-1] inside the panel (x, y, w, h).
// The first k bars (the terms kept in A_k) are teal and the rest dim gray.
// A dashed vertical line marks the boundary at index k. The y axis is linear
// from 0 to max(S), with five labelled gridlines.
function drawSpectrumPanel(ctx, x, y, w, h, k) {
  ctx.fillStyle = '#0c0d10';
  ctx.fillRect(x, y, w, h);

  const padL = 36, padR = 14, padT = 28, padB = 28;
  const innerW = w - padL - padR;
  const innerH = h - padT - padB;
  // Too small for a plot area (e.g. a very short canvas): keep just the backdrop
  // instead of drawing bars and labels with negative extents.
  if (!(innerW > 0 && innerH > 0)) return;
  const ox = x + padL;
  const oy = y + padT;

  // y-axis: log-ish scale would be nicer, but linear keeps the visual honest.
  let smax = 0;
  for (let i = 0; i < N; i++) if (S[i] > smax) smax = S[i];
  if (smax <= 0) smax = 1; // all-zero spectrum: avoid dividing by zero

  // gridlines at 0, 1/4, ..., 4/4 of the plot height
  ctx.strokeStyle = '#191b22';
  ctx.lineWidth = 1;
  ctx.beginPath();
  for (let g = 0; g <= 4; g++) {
    const gy = oy + (innerH * g) / 4;
    ctx.moveTo(ox, gy + 0.5);
    ctx.lineTo(ox + innerW, gy + 0.5);
  }
  ctx.stroke();

  // bars: one slot per singular value, with a 1px gap when the slot is wide enough
  const NBARS = N;
  const slot = innerW / NBARS;
  const barW = Math.max(1, slot - 1);
  for (let i = 0; i < NBARS; i++) {
    const bh = (S[i] / smax) * innerH;
    const bx = ox + i * slot;
    const by = oy + innerH - bh;
    if (i < k) {
      ctx.fillStyle = '#6fb3b8';      // muted teal: included
    } else {
      ctx.fillStyle = '#2a2e36';      // dim gray: excluded
    }
    ctx.fillRect(bx, by, barW, bh);
  }

  // axis labels
  ctx.fillStyle = '#7a808a';
  ctx.font = '10px ui-monospace, monospace';
  ctx.fillText('sigma_i', ox - 30, oy - 8);
  // 'i' sits in the right padding, past the N tick label (which ends at the
  // plot's right edge), so the two labels never overlap.
  ctx.fillText('i', ox + innerW + 4, oy + innerH + 14);
  // y tick labels matching the gridlines
  for (let g = 0; g <= 4; g++) {
    const val = smax * (1 - g / 4);
    const gy = oy + (innerH * g) / 4;
    ctx.fillText(val.toFixed(1), ox - 32, gy + 3);
  }
  // x ticks at 1 and N; the current k is marked by the dashed line below
  ctx.fillStyle = '#9aa0aa';
  ctx.fillText('1', ox - 2, oy + innerH + 14);
  ctx.fillText(String(N), ox + innerW - 12, oy + innerH + 14);

  // marker line at boundary k (between included/excluded)
  const xK = ox + k * slot;
  ctx.strokeStyle = 'rgba(155, 200, 200, 0.55)';
  ctx.setLineDash([3, 3]);
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(xK + 0.5, oy);
  ctx.lineTo(xK + 0.5, oy + innerH);
  ctx.stroke();
  ctx.setLineDash([]);
}

function drawHUD(ctx, x, y, w, k, err, scrubbing) {
  // Header strip: title at top-left, a large right-aligned "k = .. / N" readout
  // (highlighted while scrubbing), and the residual norm plus the current mode
  // on a second line under the title.
  const KFONT = 'bold 22px ui-monospace, monospace';
  const SMALLFONT = '12px ui-monospace, monospace';
  const GAP = 4;         // space between "k = .." and "/ N"
  const RIGHT_PAD = 12;  // distance of the readout from the right edge

  ctx.fillStyle = '#e8ecf4';
  ctx.font = '14px ui-sans-serif, system-ui';
  ctx.fillText('SVD: best rank-k approximation', x + 8, y + 18);

  // Measure both parts in their own fonts so the pair can be right-aligned.
  const kLabel = `k = ${k}`;
  const nLabel = `/ ${N}`;
  ctx.font = KFONT;
  const kWidth = ctx.measureText(kLabel).width;
  ctx.font = SMALLFONT;
  const nWidth = ctx.measureText(nLabel).width;
  const kx = x + w - RIGHT_PAD - (kWidth + GAP + nWidth);

  ctx.font = KFONT;
  ctx.fillStyle = scrubbing ? '#9be4ea' : '#cfd6e4';
  ctx.fillText(kLabel, kx, y + 28);
  ctx.font = SMALLFONT;
  ctx.fillStyle = '#7a808a';
  ctx.fillText(nLabel, kx + kWidth + GAP, y + 28);

  const mode = scrubbing ? 'scrub' : 'auto';
  ctx.fillStyle = '#9aa3b8';
  ctx.font = SMALLFONT;
  ctx.fillText(`||A - A_k||_F = ${err.toFixed(3)}  ·  ${mode}`, x + 8, y + 38);
}

// ---------- lifecycle ----------
// One-time (re)initialization, called with the 2D context and canvas size
// (presumably by the host runtime; the caller is not visible here). It builds
// image 0 and its SVD, allocates the N x N offscreen buffer, renders the rank-1
// reconstruction, resets all interaction/animation state so the demo starts in
// auto-sweep mode, and clears the canvas.
function init({ ctx, width, height }) {
  W = width;
  H = height;

  // Reset the image index together with the rest of the interaction state.
  // Image 0 is activated below, and a stale imageIdx from an earlier run would
  // make the next click jump to the wrong image.
  imageIdx = 0;
  // Build image #0 and its SVD eagerly so we have something to draw immediately.
  activateImage(0);

  recon = new Float64Array(N * N);
  imgBuf = new OffscreenCanvas(N, N);
  imgCtx = imgBuf.getContext('2d');
  imgData = imgCtx.createImageData(N, N);

  reconstruct(1);
  writeReconToImageData();
  lastK = 1;
  lastErr = frobeniusError(1);
  currentK = 1;
  tElapsed = 0;
  idleTimer = IDLE_SEC; // start in auto-sweep until the user touches it
  scrubActive = false;
  lastMouseX = -1;
  lastMouseY = -1;
  lastMouseInside = false;

  ctx.fillStyle = '#08090c';
  ctx.fillRect(0, 0, W, H);
}

// Per-frame hook: track the pointer and clicks, pick k (pointer scrub or the
// timed auto-sweep), rebuild the reconstruction only when k changes, then
// redraw the whole frame.
function tick({ ctx, dt, width, height, input }) {
  if (width !== W || height !== H) { W = width; H = height; }

  const { mx, inside } = updatePointerState(input, dt);
  handleImageClicks(input, inside);

  // mouseX drives k while scrubbing. Otherwise the auto-sweep clock does, and it
  // only advances in that mode, so the sweep resumes where it left off.
  const k = scrubActive && inside ? kFromPointer(mx) : kFromAutoSweep(dt);
  currentK = k;

  if (k !== lastK) {
    reconstruct(k);
    writeReconToImageData();
    lastErr = frobeniusError(k);
    lastK = k;
  }

  drawFrame(ctx, scrubActive && inside);
}

// Pointer bookkeeping for one frame. Any motion inside the canvas (including
// entering it) switches to scrub mode and resets the idle timer. IDLE_SEC
// seconds without motion hand control back to the auto-sweep.
// Returns the pointer x and whether the pointer is inside the canvas.
function updatePointerState(input, dt) {
  const mx = input.mouseX;
  const my = input.mouseY;
  const inside = mx >= 0 && my >= 0 && mx < W && my < H;

  const MOVE_EPS = 0.5; // ignore movements of half a unit or less
  const moved =
    inside &&
    (!lastMouseInside ||
      Math.abs(mx - lastMouseX) > MOVE_EPS ||
      Math.abs(my - lastMouseY) > MOVE_EPS);

  if (moved) {
    idleTimer = 0;
    scrubActive = true;
  } else {
    idleTimer += dt;
    if (idleTimer >= IDLE_SEC) scrubActive = false;
  }
  lastMouseX = mx;
  lastMouseY = my;
  lastMouseInside = inside;
  return { mx, inside };
}

// Advance to the next procedural image once per pending click. The index wraps
// on IMAGE_BUILDERS.length, so it always stays in sync with the builder list.
// A click over the canvas also counts as interaction and (re)enters scrub mode.
function handleImageClicks(input, inside) {
  const clicks = input.consumeClicks ? input.consumeClicks() : 0;
  if (!(clicks > 0)) return;
  imageIdx = (imageIdx + clicks) % IMAGE_BUILDERS.length;
  activateImage(imageIdx); // SVD cached after first compute
  if (inside) {
    idleTimer = 0;
    scrubActive = true;
  }
}

// Map pointer x in [0, W) linearly onto k in [1, N].
function kFromPointer(mx) {
  const k = Math.floor((mx / Math.max(1, W)) * N) + 1;
  return Math.min(N, Math.max(1, k));
}

// Advance the auto-sweep clock by dt and return k. Each cycle ramps k from 1
// toward KMAX over RAMP_SEC seconds, then holds KMAX for HOLD_SEC seconds.
function kFromAutoSweep(dt) {
  tElapsed += dt;
  const phaseT = tElapsed % (RAMP_SEC + HOLD_SEC);
  if (!(phaseT < RAMP_SEC)) return KMAX;
  const k = 1 + Math.floor((phaseT / RAMP_SEC) * (KMAX - 1) + 0.0001);
  return Math.min(k, KMAX);
}

// Clear the canvas and lay out the image and spectrum panels below a 48px
// header. The panels sit side by side, or stacked when the canvas is narrower
// than 520px. The HUD is drawn last, on top.
function drawFrame(ctx, scrubbing) {
  ctx.fillStyle = '#08090c';
  ctx.fillRect(0, 0, W, H);

  const headerH = 48;
  const gutter = 10;
  const panelY = headerH;
  const panelH = H - headerH - gutter;

  if (W < 520) {
    const topH = Math.floor((panelH - gutter) / 2);
    drawImagePanel(ctx, gutter, panelY, W - 2 * gutter, topH);
    drawSpectrumPanel(ctx, gutter, panelY + topH + gutter, W - 2 * gutter, panelH - topH - gutter, currentK);
  } else {
    const panelW = Math.floor((W - 3 * gutter) / 2);
    drawImagePanel(ctx, gutter, panelY, panelW, panelH);
    drawSpectrumPanel(ctx, gutter + panelW + gutter, panelY, W - panelW - 3 * gutter, panelH, currentK);
  }

  drawHUD(ctx, 0, 0, W, currentK, lastErr, scrubbing);
}

Comments (0)

Log in to comment.