2 Neural SINDy: Sparse ID with Grokked MLP Libraries

The Task

SINDy (Sparse Identification of Nonlinear Dynamics) discovers the ODE governing a time series by picking a sparse combination from a hand-crafted library of basis functions:

$$\dot{X} = \Theta(X)\cdot\xi, \qquad \Theta(X) = [1,\; x,\; x^{2},\; \sin x,\; \cos x,\; \ldots]$$

Standard SINDy needs a human to curate $\Theta$. Neural SINDy replaces each basis function with a small MLP that has grokked one operation (identity, sin, cos, add, mul), so the library becomes a learned component. The sparse regression then selects which learned operations best explain the observed dynamics.

The target is a damped harmonic oscillator with $k = 1.0$, $c = 0.1$:

$$\ddot{x} = -k\,x - c\,\dot{x} \qquad\Longleftrightarrow\qquad \dot{x} = v,\;\; \dot{v} = -k\,x - c\,v$$

600 time points, small Gaussian noise. Success means recovering those two equations with clean, interpretable coefficients.

Pipeline (Phases 1–3)

Phase 1.

Train 5 MLPs (one each for identity, sin, cos, add, and mul) well past the point of overfitting, using AdamW with heavy weight decay for 10k–20k epochs each. Training loss drops fast; validation loss drops much later (grokking). Each MLP ends as a clean algorithmic circuit for its target function.

Phase 2.

Simulate the oscillator; collect $[t, x, v, \dot{x}, \dot{v}]$.
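Phase 2 can be sketched in NumPy as follows. Several details are not given in the write-up, so the sketch assumes them: initial state $x_0 = 1$, $v_0 = 0$, a time span of $[0, 30]$, Gaussian noise with standard deviation $10^{-3}$ on the observed states, and derivatives estimated by central differences (`np.gradient`). The function name `simulate_oscillator` is hypothetical.

```python
import numpy as np


def simulate_oscillator(k=1.0, c=0.1, x0=1.0, v0=0.0, t_end=30.0, n=600,
                        noise=1e-3, seed=0):
    """Integrate x'' = -k x - c x' with classic RK4; return columns [t, x, v, xdot, vdot].

    x0, v0, t_end, noise and the finite-difference derivatives are assumptions.
    """
    t = np.linspace(0.0, t_end, n)
    dt = t[1] - t[0]

    def rhs(s):
        # s = (x, v)  ->  (xdot, vdot)
        return np.array([s[1], -k * s[0] - c * s[1]])

    states = np.empty((n, 2))
    states[0] = (x0, v0)
    for i in range(n - 1):
        s = states[i]
        d1 = rhs(s)
        d2 = rhs(s + 0.5 * dt * d1)
        d3 = rhs(s + 0.5 * dt * d2)
        d4 = rhs(s + dt * d3)
        states[i + 1] = s + dt / 6.0 * (d1 + 2.0 * d2 + 2.0 * d3 + d4)

    # Small Gaussian measurement noise on the observed states.
    rng = np.random.default_rng(seed)
    noisy = states + noise * rng.standard_normal(states.shape)
    x, v = noisy[:, 0], noisy[:, 1]

    # Derivatives estimated from the noisy series by central differences.
    xdot = np.gradient(x, t)
    vdot = np.gradient(v, t)
    return np.column_stack([t, x, v, xdot, vdot])


data = simulate_oscillator()
t, x, v, xdot, vdot = data.T
print(data.shape)  # (600, 5)
```

Finite differences amplify the state noise by a factor of order $1/\Delta t$ (about 14× here), so the derivative columns are noticeably noisier than $x$ and $v$.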

Phase 3.

Build $\Theta(X)$ by passing the state data through each grokked MLP (8 terms total: identity, sin, and cos on $x$ and $v$, plus add(x,v) and mul(x,v)), then solve for a sparse $\xi$.
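To make the library concrete, here is a sketch that builds the 8 columns with exact NumPy functions standing in for the grokked MLPs. This is a hypothetical substitution: a well-grokked MLP should approximate these functions, but its outputs will not be identical. The state data is an assumed decaying spiral that roughly resembles the oscillator's phase portrait, and `build_library` is a hypothetical helper name.

```python
import numpy as np

# Hypothetical stand-ins for the grokked MLPs: the exact functions a
# well-grokked network would approximate.
UNARY = {"id": lambda a: a, "sin": np.sin, "cos": np.cos}
BINARY = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def build_library(x, v):
    """Stack the 8 library columns: id/sin/cos of x and v, then add(x,v), mul(x,v)."""
    cols, names = [], []
    for op, fn in UNARY.items():
        for var_name, var in (("x", x), ("v", v)):
            cols.append(fn(var))
            names.append(f"{op}({var_name})")
    for op, fn in BINARY.items():
        cols.append(fn(x, v))
        names.append(f"{op}(x,v)")
    return np.column_stack(cols), names


# Decaying spiral, roughly like the oscillator's trajectory (assumed data).
t = np.linspace(0.0, 30.0, 600)
x = np.exp(-0.05 * t) * np.cos(t)
v = -np.exp(-0.05 * t) * np.sin(t)

theta, names = build_library(x, v)
print(theta.shape, np.linalg.matrix_rank(theta))  # (600, 8) 7
# The null direction: id(x) + id(v) - add(x,v) is identically zero.
print(np.abs(theta[:, 0] + theta[:, 1] - theta[:, 6]).max())  # 0.0
```

The rank comes out as 7, not 8: whatever trajectory is fed in, the add column is exactly the sum of the two identity columns.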

Run 1: STLSQ Baseline — Correct Answer, Ugly Presentation

Sequential Thresholded Least Squares (STLSQ) fits almost exactly (MSE $\approx 10^{-6}$), but the coefficients are spread across redundant terms:

$$\dot{x} \approx -0.333\cdot\mathrm{id}(x) + 0.667\cdot\mathrm{id}(v) + 0.332\cdot\mathrm{add}(x,v)$$
$$\dot{v} \approx -0.632\cdot\mathrm{id}(x) + 0.268\cdot\mathrm{id}(v) - 0.368\cdot\mathrm{add}(x,v)$$

Substituting $\mathrm{add}(x,v) = x + v$ and collecting terms recovers the truth ($\dot{x} = v$, $\dot{v} = -x - 0.1v$), but the individual coefficients are uninterpretable. The reason is an exact collinearity in the library: $\mathrm{add}(x,v) = \mathrm{id}(x) + \mathrm{id}(v)$, so $\Theta$ is rank-deficient. Shifting the coefficients of $(\mathrm{id}(x), \mathrm{id}(v), \mathrm{add}(x,v))$ by any multiple of $(+1, +1, -1)$ leaves the fit unchanged. STLSQ, even with ridge regularization, cannot tell these solutions apart, so the split it reports is set by the solver rather than by the physics. Dropping add from the library is the one-line fix; a more principled fix would move to a method that doesn't invert $\Theta$ at all.
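A minimal STLSQ sketch reproduces this behaviour. It runs on a noise-free toy trajectory, with exact functions standing in for the grokked MLPs and targets computed from the true right-hand side; all three choices are assumptions. The 0.05 threshold and the helper names are also illustrative. When $\Theta$ is rank-deficient, `np.linalg.lstsq` with a singular-value cutoff (`rcond`) returns the minimum-norm least-squares solution, so that is the split this solver picks.

```python
import numpy as np


def build_library(x, v, include_add=True):
    """Library columns; exact functions stand in for the grokked MLPs (assumption)."""
    cols = [x, v, np.sin(x), np.sin(v), np.cos(x), np.cos(v)]
    names = ["id(x)", "id(v)", "sin(x)", "sin(v)", "cos(x)", "cos(v)"]
    if include_add:
        cols.append(x + v)
        names.append("add(x,v)")
    cols.append(x * v)
    names.append("mul(x,v)")
    return np.column_stack(cols), names


def stlsq(theta, y, threshold=0.05, n_iter=10):
    """Sequential Thresholded Least Squares for a single target column y."""
    # rcond=1e-10 drops the (numerically) zero singular value, so a
    # rank-deficient theta yields the minimum-norm least-squares solution.
    xi = np.linalg.lstsq(theta, y, rcond=1e-10)[0]
    for _ in range(n_iter):
        small = np.abs(xi) < threshold
        xi[small] = 0.0                       # prune weak terms
        keep = ~small
        if not keep.any():
            break
        # Refit only the surviving terms.
        xi[keep] = np.linalg.lstsq(theta[:, keep], y, rcond=1e-10)[0]
    return xi


# Noise-free toy states; targets from x' = v, v' = -x - 0.1 v.
t = np.linspace(0.0, 30.0, 600)
x = np.exp(-0.05 * t) * np.cos(t)
v = -np.exp(-0.05 * t) * np.sin(t)
targets = {"xdot": v, "vdot": -x - 0.1 * v}

for include_add in (True, False):
    theta, names = build_library(x, v, include_add)
    for label, y in targets.items():
        xi = stlsq(theta, y)
        print(include_add, label,
              {n: round(float(c), 3) for n, c in zip(names, xi) if c != 0.0})
```

With add in the library, the minimum-norm split is $\dot{x} = -\tfrac{1}{3}\,\mathrm{id}(x) + \tfrac{2}{3}\,\mathrm{id}(v) + \tfrac{1}{3}\,\mathrm{add}(x,v)$ and $\dot{v} \approx -0.633\,\mathrm{id}(x) + 0.267\,\mathrm{id}(v) - 0.367\,\mathrm{add}(x,v)$. These agree with the Run 1 coefficients to within about 0.001, which is consistent with that solver also landing on the minimum-norm split. Without add, the same code returns $\dot{x} = v$ and $\dot{v} = -x - 0.1v$.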

Hypothesis. A differentiable Gumbel-Softmax router over the library would pick terms categorically rather than solving a linear system, sidestepping rank deficiency.

Run 2 (Exp 1): Gumbel-Softmax, Softmax Soup


A state-dependent router (Linear → ReLU → Linear → ReLU → Linear) with Gumbel-Softmax sampling and temperature annealing $\tau: 5.0 \to 0.05$. I expected near-one-hot activations and got a “softmax soup” instead.

| Derivative | Max activation (term) |
|---|---|
| $\dot{x}$ | 43.8% (identity(v)) |
| $\dot{v}$ | 19.2% (identity(x)) |

Final MSE: $\sim 5\cdot10^{-3}$.

For $\dot{x}$, the router found the right dominant term ($\mathrm{id}(v)$ with coefficient $+1.0000$) but never committed to it. Three failures stacked:

No sparsity pressure.

Nothing in the loss pushes the logits apart. At high $\tau$ the router computes a weighted average across all MLPs and tunes per-MLP coefficients to fit the data. When $\tau$ drops at epoch 2000, the underlying distribution still hasn't concentrated, so the soft-vs-hard mismatch shows up as a loss discontinuity.

Damping lost.

The damping term $-0.1\,v$ in $\dot{v}$ is $10\times$ smaller than the restoring force $-x$. Without sparsity pressure, that signal is drowned out across the 7 other terms.

Collinearity in disguise.

add(x,v) is still selected 10–14% of the time with meaningful coefficients. Rank deficiency didn’t disappear — it became probability-mass splitting.

Net: MSE roughly $5000\times$ worse than STLSQ ($\sim 5\cdot10^{-3}$ vs. $10^{-6}$). Gumbel-Softmax without entropy regularization is just a fancy softmax.
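A small NumPy sketch shows why sampling alone does not force commitment. It assumes an 8-term library and an exponential anneal from $\tau = 5.0$ to $0.05$; the write-up does not give the exact schedule shape, and the 3000-epoch horizon is arbitrary. With flat logits, each low-temperature sample is close to one-hot, yet the average over samples is still uniform. The commitment has to come from the logits, and only a loss term can push them apart.

```python
import numpy as np


def sample_gumbel(shape, rng):
    """Standard Gumbel noise via -log(-log U), with U kept away from 0."""
    u = rng.uniform(np.finfo(float).tiny, 1.0, size=shape)
    return -np.log(-np.log(u))


def gumbel_softmax(logits, tau, rng):
    """Relaxed categorical sample(s): softmax((logits + Gumbel noise) / tau)."""
    z = (logits + sample_gumbel(logits.shape, rng)) / tau
    z = z - z.max(axis=-1, keepdims=True)     # numerical stability
    w = np.exp(z)
    return w / w.sum(axis=-1, keepdims=True)


def tau_schedule(epoch, n_epochs, tau_start=5.0, tau_end=0.05):
    """Exponential anneal (assumed shape) from tau_start to tau_end."""
    frac = epoch / (n_epochs - 1)
    return tau_start * (tau_end / tau_start) ** frac


rng = np.random.default_rng(0)
flat_logits = np.zeros((20000, 8))            # 20k draws from a router that never committed
samples = gumbel_softmax(flat_logits, tau_schedule(2999, 3000), rng)
print(samples.max(axis=1).mean())             # a typical sample is nearly one-hot...
print(samples.mean(axis=0))                   # ...but every term averages about 1/8
```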

Run 3 (Exp 2): Scalar Router + Entropy — Wrong Basis


Three changes from Exp 1:

1. State-independent router. A single learnable logit vector per derivative, broadcast across the batch. This matches the SINDy assumption that the same terms govern the dynamics everywhere in state space. Param count drops from 9,760 to 32.

2. Entropy penalty with the correct sign and a scheduled weight: $w = 0.05\cdot\max(0, 1 - \tau/\tau_{\text{start}})$. Zero at the start (pure exploration); $\approx 0.05$ at the end (strong pressure toward one-hot).

3. Same Gumbel-Softmax sampling and anneal schedule.
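The Exp 2 objective can be sketched under a few assumptions. The router is just a `(2, 8)` array of logits (one row per derivative, shared across the batch), entropy is measured in nats on `softmax(logits)`, and the reconstruction term is passed in as a number because the rest of the model is out of scope. Helper names are hypothetical. The sign matters: the entropy is added to the loss, so minimizing the loss lowers it.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def entropy(logits):
    """Shannon entropy (nats) of softmax(logits), one value per row."""
    logp = log_softmax(logits)
    return -np.sum(np.exp(logp) * logp, axis=-1)


def entropy_weight(tau, tau_start=5.0, w_max=0.05):
    """w = w_max * max(0, 1 - tau / tau_start): 0 at the start, ~w_max at the end."""
    return w_max * max(0.0, 1.0 - tau / tau_start)


def router_loss(recon_mse, logits, tau):
    """Reconstruction plus scheduled entropy; minimizing it pushes rows toward one-hot."""
    return recon_mse + entropy_weight(tau) * entropy(logits).sum()


# State-independent router: one learnable logit vector per derivative,
# broadcast over the batch instead of computed from each state.
router_logits = np.zeros((2, 8))
print(entropy_weight(5.0), entropy_weight(0.05))   # 0.0, then about 0.0495
print(entropy(router_logits))                      # ln 8 (about 2.079) per derivative
```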

Result: clean one-hot commitment.

| Derivative | Selected term | Activation |
|---|---|---|
| $\dot{x}$ | $+1.0265\cdot\sin(v)$ | 99.7% |
| $\dot{v}$ | $-1.0138\cdot\sin(x)$ | 99.2% |

Wrong basis. The true equations are linear in $x$ and $v$, so $\mathrm{id}(x)$ and $\mathrm{id}(v)$ should have won. The router picked sin instead because $\sin(z) \approx z$ to within about 16% for $|z| \lesssim 1$ (the amplitudes the oscillator visits); at the edge, $\sin 1 \approx 0.841$. The inflated coefficient ($1.0265$ vs. the true $1.0$) absorbs the Taylor-series shortfall. The entropy penalty broke the near-tie arbitrarily.
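A quick numerical check of the near-tie follows. The first line measures how far $\sin z$ falls below $z$ at $|z| = 1$. The loop then fits $x \approx c\,\sin(x)$ by least squares on a pure sinusoid $x = A\cos t$ for a few assumed amplitudes $A$, which stand in for the oscillator's decaying amplitude and are not the actual run. The fitted $c$ exceeds 1 and grows with $A$, the same coefficient inflation seen in Run 3. The exact value depends on the amplitude distribution, so these numbers are not meant to reproduce 1.0265.

```python
import numpy as np

# Relative shortfall of sin(z) versus z at |z| = 1.
gap_at_edge = 1.0 - np.sin(1.0)
print(round(float(gap_at_edge), 3))   # 0.159

# Least-squares c in x ~ c * sin(x) over one period of x = A cos t.
t = np.linspace(0.0, 2.0 * np.pi, 1000, endpoint=False)
coeffs = {}
for A in (0.25, 0.5, 1.0):
    x = A * np.cos(t)
    s = np.sin(x)
    coeffs[A] = float(s @ x / (s @ s))   # normal equation for a single column
print(coeffs)
```

The fitted coefficients come out at roughly 1.008, 1.032, and 1.134 for the three amplitudes. The larger the swing, the more the sin basis has to be scaled up to impersonate the identity.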

Damping is still gone: a one-hot selection structurally cannot represent a two-term equation like $\dot{v} = -1\cdot x - 0.1\cdot v$.

Two collinearities, not one.

Exp 1 hit the exact collinearity $\mathrm{add}(x,v) = \mathrm{id}(x) + \mathrm{id}(v)$. Exp 2 sidesteps that but hits an approximate collinearity, $\sin(z) \approx z$. The library design is the real bottleneck; both optimizers are doing their jobs.

Run 4 (Exp 3): Top-$k$ + Complexity Prior — Dominant Term Recovered


Two changes:

Top-$k$ routing ($k = 2$).

Each derivative selects two basis functions instead of one. This is implemented as $k$ sequential Gumbel-Softmax draws with earlier selections masked out (sampling without replacement via iterative masked argmax, with STE gradients). Each slot has its own learnable coefficient, so $\dot{v} = -1\cdot x - 0.1\cdot v$ is now representable.
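The forward pass of top-$k$ routing can be sketched in NumPy as $k$ sequential Gumbel-Softmax draws, each with every earlier winner masked out. This is an illustration, not the original code. NumPy has no autograd, so the straight-through estimator appears only as a comment, and `NEG = -1e9` is a finite stand-in for $-\infty$ that avoids `0 * inf` NaNs. The function names are hypothetical.

```python
import numpy as np

NEG = -1e9  # finite stand-in for -inf: masked terms get exactly zero weight


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def sample_gumbel(shape, rng):
    u = rng.uniform(np.finfo(float).tiny, 1.0, size=shape)
    return -np.log(-np.log(u))


def top_k_select(logits, k, tau, rng=None):
    """Pick k distinct terms by iterative masked argmax (sampling without replacement).

    Returns the chosen indices plus, per slot, the soft weights and the hard
    one-hot vectors. With rng=None the Gumbel noise is dropped (deterministic).
    """
    mask = np.zeros_like(logits, dtype=float)
    picks, soft, hard = [], [], []
    for _ in range(k):
        noise = 0.0 if rng is None else sample_gumbel(logits.shape, rng)
        y_soft = softmax((logits + mask + noise) / tau)
        j = int(np.argmax(y_soft))
        y_hard = np.zeros_like(y_soft)
        y_hard[j] = 1.0
        # With autodiff, the slot output would be the straight-through estimate
        # y_hard - stop_gradient(y_soft) + y_soft: hard forward, soft backward.
        picks.append(j)
        soft.append(y_soft)
        hard.append(y_hard)
        mask[j] = NEG  # later slots can no longer choose term j
    return picks, np.array(soft), np.array(hard)


rng = np.random.default_rng(0)
picks, soft, hard = top_k_select(np.zeros(8), k=2, tau=0.05, rng=rng)
print(picks)  # two distinct library indices
```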

Complexity prior ($\alpha = 1.0$).

Occam’s razor baked into the logits: $\mathrm{identity} \prec \sin/\cos \prec$ binary MLPs. The prior is added directly to the logits, so it biases selection without distorting the learned coefficients.
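A sketch of the complexity prior follows. It assumes integer complexity ranks scaled by $\alpha = 1.0$: 0 for identity, 1 for sin/cos, and 2 for the binary MLPs. The write-up gives the ordering but not the exact values. Because the prior only shifts the selection logits, the per-slot coefficients remain separate parameters and are left untouched. The example uses a tie between $\mathrm{id}(v)$ and $\sin(v)$, the kind of near-tie Run 3 broke the wrong way.

```python
import numpy as np

NAMES = ["id(x)", "id(v)", "sin(x)", "sin(v)", "cos(x)", "cos(v)", "add(x,v)", "mul(x,v)"]
# Assumed complexity ranks: identity < sin/cos < binary MLPs.
COMPLEXITY = np.array([0, 0, 1, 1, 1, 1, 2, 2], dtype=float)


def selection_probs(learned_logits, alpha=1.0):
    """Router probabilities after adding the Occam prior -alpha * complexity to the logits."""
    z = learned_logits - alpha * COMPLEXITY
    z = z - z.max()
    p = np.exp(z)
    return p / p.sum()


# Learned logits that cannot separate id(v) from sin(v).
learned = np.array([0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0])
p = selection_probs(learned)
print(NAMES[int(np.argmax(p))])  # id(v)
print(p[1] / p[3])               # e^alpha: identity is favored by a factor of e
```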

Result:

$$\dot{x} = +1.0000\cdot\mathrm{id}(v)$$
$$\dot{v} = -0.9853\cdot\mathrm{id}(x) - 0.0267\cdot\mathrm{id}(v) - 0.0151\cdot\sin(v)$$

| Term | Discovered | True | Error |
|---|---|---|---|
| $\mathrm{id}(v)$ in $\dot{x}$ | $+1.0000$ | $+1.0$ | 0.00% |
| $\mathrm{id}(x)$ in $\dot{v}$ | $-0.9853$ | $-1.0$ | 1.5% |
| $\mathrm{id}(v)$ in $\dot{v}$ | $-0.0267$ | $-0.1$ | 73% |

Slot 1 locks cleanly onto the correct basis in both derivatives. Slot 2 identifies $\mathrm{id}(v)$ as the right damping term but fails to commit: it is selected only 27% of the time, with a coefficient about 1/4 of the true magnitude.

Run 5 (Exp 4): Conditional Slot Entropy — Diagnosis


Root cause found.

Exp 3’s entropy penalty operated on the base distribution $\mathrm{softmax}(\mathrm{logits} + \mathrm{prior})$, before any masking. Minimizing it drove the base distribution one-hot on the dominant term. Once slot 1 committed and masked its winner to $-\infty$, slot 2 faced a near-flat conditional distribution with no gradient pointing it anywhere useful.

Fix.

Iterate slot-by-slot, mirroring the forward pass:

$$H_j = -\sum_i p_{j,i}\,\log p_{j,i}, \qquad p_j = \mathrm{softmax}(\mathrm{logits} + \mathrm{mask}_j), \qquad \mathrm{mask}_{j+1} = \mathrm{mask}_j \text{ with } -\infty \text{ at } \arg\max_i p_{j,i}$$

Each slot’s conditional distribution is penalized independently. Slot 2 now receives a direct gradient to concentrate on whatever term best explains the residual.
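The per-slot entropy $H_j$ can be sketched in NumPy, mirroring the masking of the forward pass. Two simplifications are assumed: masking follows the plain argmax of each conditional distribution (Gumbel noise omitted), and `NEG = -1e9` stands in for $-\infty$. The example reproduces the Exp 3 situation in miniature. The base distribution is nearly one-hot, so its entropy is near zero and carries no signal for slot 2, while slot 2's conditional distribution over the remaining seven terms is flat, with entropy $\ln 7$.

```python
import numpy as np

NEG = -1e9  # finite stand-in for -inf so masked terms contribute exactly 0 to H


def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def conditional_slot_entropies(logits, k):
    """Return [H_1, ..., H_k]: entropy of each slot's masked distribution."""
    mask = np.zeros_like(logits, dtype=float)
    entropies = []
    for _ in range(k):
        logp = log_softmax(logits + mask)
        p = np.exp(logp)                     # masked entries: exp(~-1e9) == 0.0
        entropies.append(float(-(p * logp).sum()))
        mask[int(np.argmax(logp))] = NEG     # mask_{j+1}: remove slot j's winner
    return entropies


# Base distribution already one-hot on term 0; the other seven are tied.
logits = np.array([10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
h_slot1, h_slot2 = conditional_slot_entropies(logits, k=2)
print(h_slot1)               # small: the base entropy is nearly zero
print(h_slot2, np.log(7))    # slot 2 sees a flat distribution over 7 terms
```

Penalizing the sum of the $H_j$ makes slot 2's own distribution the target. A penalty on the base distribution cannot do that once its entropy is already near zero.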

Result: slot 1 unchanged (100% commitment, same coefficients). Slot 2’s $\mathrm{id}(v)$ activation improves from 27% to 34%, and the damping coefficient from $-0.027$ to $-0.034$. Val MSE is $2.42\cdot10^{-4}$, still $\sim 24\times$ above the $10^{-5}$ target.

Why slot 2 still won’t commit.

Three compounding factors:

1. Weak residual signal. After slot 1 absorbs $\sim 99\%$ of the $\dot{v}$ variance, the residual slot 2 must explain is $\sim 10\times$ smaller than the primary signal. The reconstruction-loss gradient reaching slot 2 is proportionally weak.

2. STE gradient starvation at low $\tau$. At $\tau = 0.05$, the Straight-Through Estimator back-propagates through soft weights that are essentially zero for non-winners, so slot 2’s losing logits barely move.

3. Entropy budget split equally. The entropy weight applies to both slots, but slot 1 is already committed and doesn’t need the pressure, so that share of the budget is wasted.
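Factor 2 follows from the softmax Jacobian $\partial p_i/\partial z_j = (p_i\delta_{ij} - p_i p_j)/\tau$. Any loss gradient reaching logit $z_i$ through the soft weights equals $p_i\,(g_i - \sum_k p_k g_k)/\tau$, so it is scaled by that term's own probability. The sketch compares $\tau = 5$ with $\tau = 0.05$ on assumed logits for slot 2: a leader, a close runner-up, and six losing terms.

```python
import numpy as np


def softmax_jacobian(logits, tau):
    """Return (J, p) with J[i, j] = d p_i / d logit_j for p = softmax(logits / tau)."""
    z = logits / tau
    p = np.exp(z - z.max())
    p /= p.sum()
    return (np.diag(p) - np.outer(p, p)) / tau, p


logits = np.array([1.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])  # assumed slot-2 logits
loser = 2
for tau in (5.0, 0.05):
    J, p = softmax_jacobian(logits, tau)
    # Column `loser`: how every soft weight moves when the losing logit moves.
    # Its total magnitude is 2 * p[loser] * (1 - p[loser]) / tau.
    print(tau, p[loser], np.abs(J[:, loser]).sum())
```

At $\tau = 5$ the column sums to about 0.04. At $\tau = 0.05$ it drops below $10^{-7}$, even though the $1/\tau$ factor grows a hundredfold, because $p_i$ collapses exponentially.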

Summary of Runs

| Run | Method | Val MSE | Max act. | Correct? |
|---|---|---|---|---|
| 1 | STLSQ | $10^{-6}$ | n/a | structure: yes; coeffs: spread |
| 2 | Gumbel, MLP router | $\sim 5\cdot10^{-3}$ | 44% | soft mixture, damping gone |
| 3 | Gumbel, scalar + entropy | $\sim 1\cdot10^{-3}$ | 99.7% | wrong basis (sin vs. id) |
| 4 | + top-$k$ + complexity prior | $\sim 3\cdot10^{-4}$ | 100% | right basis; damping at 27% |
| 5 | + conditional slot entropy | $2.4\cdot10^{-4}$ | 100% | damping at 34% |

What I Learned

Each failure revealed something different about the library-vs-optimizer interplay:

1. STLSQ fails on exact collinearity because it inverts a rank-deficient $\Theta$; it distributes coefficients rather than picking terms.

2. Gumbel-Softmax without an entropy penalty stays soft. Categorical sampling doesn’t imply categorical commitment; the loss has to ask for it.

3. One-hot selection with entropy cannot represent multi-term equations and picks arbitrarily among approximately collinear terms ($\sin \approx \mathrm{id}$ for small amplitudes).

4. Top-$k$ + complexity prior fixes both, but the joint loss starves slot 2’s gradient when the residual signal is $10\times$ smaller than the primary term.

The library has two collinearities (one exact, one approximate), and every method “failed” by exposing one of them. The real fix lives upstream of the router: drop add from the library, train slot 2 on an explicit residual $\dot{v} - \xi_1\cdot\mathrm{basis}_1(X)$, or freeze slot 1 after commitment and fine-tune slot 2 alone. Dropping add removes the exact collinearity. The other two give the weak damping term a full-strength gradient instead of a fraction of the joint loss.
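A minimal sketch of the explicit-residual option uses closed-form least squares in place of the router, which is a simplification. It fits slot 1's coefficient on $\mathrm{id}(x)$, freezes it, then fits slot 2's $\mathrm{id}(v)$ coefficient against $\dot{v} - \xi_1 x$ alone. The initial state $(1, 0)$, the time span $[0, 30]$, and the $10^{-3}$ noise level on $\dot{v}$ are assumptions. Because $x$ and $v$ are nearly orthogonal along the damped trajectory, freezing $\xi_1$ biases $\xi_2$ only slightly.

```python
import numpy as np


def simulate(k=1.0, c=0.1, x0=1.0, v0=0.0, t_end=30.0, n=600):
    """Classic RK4 for x' = v, v' = -k x - c v; returns the x and v series."""
    dt = t_end / (n - 1)

    def f(s):
        return np.array([s[1], -k * s[0] - c * s[1]])

    s = np.empty((n, 2))
    s[0] = (x0, v0)
    for i in range(n - 1):
        d1 = f(s[i])
        d2 = f(s[i] + 0.5 * dt * d1)
        d3 = f(s[i] + 0.5 * dt * d2)
        d4 = f(s[i] + dt * d3)
        s[i + 1] = s[i] + dt / 6.0 * (d1 + 2.0 * d2 + 2.0 * d3 + d4)
    return s[:, 0], s[:, 1]


rng = np.random.default_rng(0)
x, v = simulate()
vdot = -x - 0.1 * v + 1e-3 * rng.standard_normal(x.shape)  # noisy target

# Slot 1: dominant term id(x), fitted on its own and then frozen.
xi1 = (x @ vdot) / (x @ x)

# Slot 2: id(v) fitted against the explicit residual only, so the weak
# damping term is the entire target rather than a small share of it.
residual = vdot - xi1 * x
xi2 = (v @ residual) / (v @ v)
print(round(float(xi1), 3), round(float(xi2), 3))
```

In this setup $\xi_1$ lands near $-0.99$ and $\xi_2$ near $-0.099$. The damping coefficient comes out close to its true value once it no longer competes with the restoring force for gradient.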

The template, again: when a method fails, the failure mode is usually more informative than the failure magnitude. STLSQ’s MSE of $10^{-6}$ looked like success and was hiding a rank-deficient library; the router’s MSE of $5\cdot10^{-3}$ looked like failure and pointed directly at the missing sparsity objective. Reading the residuals (per term, not in aggregate) is where the signal lives.