2 Neural SINDy: Sparse ID with Grokked MLP Libraries
The Task
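SINDy (Sparse Identification of Nonlinear Dynamics) discovers the ODE governing a time series by picking a sparse combination from a hand-crafted library of basis functions:

$$\dot{X} = \Theta(X)\,\Xi$$

where $\Theta(X)$ is the library matrix (one column per basis function) and $\Xi$ is the sparse coefficient matrix.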
Standard SINDy needs a human to curate the library $\Theta$. Neural SINDy replaces each basis function with a small MLP that has grokked one operation (identity, sin, cos, add, or mul), so the library becomes a learned component. The sparse regression then selects which learned operations best explain the observed dynamics.
The target is a damped harmonic oscillator with , :
600 time points, small Gaussian noise. Success means recovering those two equations with clean, interpretable coefficients.
Pipeline (Phases 1–3)
Phase 1.
Train 5 MLPs past overfitting on basic operations — AdamW with heavy weight decay for 10k–20k epochs each. Training loss drops fast; validation loss drops much later (grokking). Each MLP ends as a clean algorithmic circuit for its target function.
Phase 2.
Simulate the oscillator; collect the states $(x, v)$ and their time derivatives $(\dot{x}, \dot{v})$.
Phase 3.
Build the library matrix $\Theta$ by passing the state data through each grokked MLP (8 terms total: identity, sin, cos on $x$ and $v$, plus add(x,v) and mul(x,v)), then solve for a sparse coefficient matrix $\Xi$.
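Phases 2 and 3 can be sketched in a few lines. This is a minimal illustration, not the original pipeline: the parameter values (`OMEGA`, `ZETA`), the time span, and the noise level are hypothetical, the trajectory is sampled from the closed-form underdamped solution rather than a numerical integrator, and exact Python functions stand in for the five grokked MLPs.

```python
import math
import random

# Hypothetical parameters: the actual values are not assumed here.
OMEGA, ZETA = 1.0, 0.1            # natural frequency, damping ratio
N_POINTS, T_END, NOISE = 600, 30.0, 1e-3

def simulate(seed=0):
    """Sample the closed-form underdamped solution with x(0)=1, v(0)=0."""
    rng = random.Random(seed)
    gamma = ZETA * OMEGA
    omega_d = OMEGA * math.sqrt(1.0 - ZETA ** 2)
    states = []
    for i in range(N_POINTS):
        t = T_END * i / (N_POINTS - 1)
        decay = math.exp(-gamma * t)
        x = decay * (math.cos(omega_d * t) + gamma / omega_d * math.sin(omega_d * t))
        v = -decay * (OMEGA ** 2 / omega_d) * math.sin(omega_d * t)
        # Small Gaussian measurement noise on both state variables.
        states.append((x + rng.gauss(0, NOISE), v + rng.gauss(0, NOISE)))
    return states

# Exact functions stand in for the grokked MLPs (an assumption of this sketch).
LIBRARY = [
    ("identity(x)", lambda x, v: x),
    ("identity(v)", lambda x, v: v),
    ("sin(x)", lambda x, v: math.sin(x)),
    ("sin(v)", lambda x, v: math.sin(v)),
    ("cos(x)", lambda x, v: math.cos(x)),
    ("cos(v)", lambda x, v: math.cos(v)),
    ("add(x,v)", lambda x, v: x + v),
    ("mul(x,v)", lambda x, v: x * v),
]

def build_theta(states):
    """One row per time point, one column per library term."""
    return [[f(x, v) for _, f in LIBRARY] for x, v in states]

states = simulate()
theta = build_theta(states)
print(len(theta), "rows x", len(theta[0]), "library terms")
```

Swapping each lambda for a trained network's forward pass would give the learned library; the shape of `theta` stays the same.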
Run 1: STLSQ Baseline — Correct Answer, Ugly Presentation
Sequential Thresholded Least Squares fits exactly — MSE — but the coefficients are distributed across redundant terms:
Summing the redundant terms simplifies cleanly to the true equations, but the individual coefficients are uninterpretable. The reason is an exact collinearity in the library: add(x,v) = identity(x) + identity(v), so $\Theta$ is rank-deficient. STLSQ, even with ridge regularization, cannot distinguish the redundant directions and splits coefficients arbitrarily. Dropping add from the library is the one-line fix; a more principled fix would move to a method that doesn’t solve a linear system at all.
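The non-uniqueness is easy to check numerically. The sketch below uses a hypothetical target equation (the coefficients `-1.0` and `-0.2` are illustrative only) and shows that any amount `alpha` can be moved onto add(x,v) and removed from the two identity terms without changing a single prediction, which is why a least-squares solver has no basis for preferring the clean coefficient vector.

```python
import math

def predict(coeffs, x, v):
    """Evaluate a model over the collinear terms identity(x), identity(v), add(x,v)."""
    c_x, c_v, c_add = coeffs
    return c_x * x + c_v * v + c_add * (x + v)

# Hypothetical clean target: dv/dt = -1.0*x - 0.2*v (illustrative values).
clean = (-1.0, -0.2, 0.0)

def shifted(alpha):
    """Move alpha onto add(x,v) and subtract it from both identity terms."""
    return (clean[0] - alpha, clean[1] - alpha, alpha)

# Arbitrary sample points; any (x, v) pairs would do.
samples = [(math.cos(0.1 * i), -math.sin(0.1 * i)) for i in range(50)]

for alpha in (0.0, 0.37, -2.5):
    gap = max(abs(predict(shifted(alpha), x, v) - predict(clean, x, v))
              for x, v in samples)
    print(f"alpha={alpha:+.2f}  coeffs={shifted(alpha)}  max prediction gap={gap:.1e}")
```

Every `alpha` gives the same fit up to floating-point rounding, so the MSE alone cannot reveal which coefficient vector was returned.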
Hypothesis. A differentiable Gumbel-Softmax router over the library would pick terms categorically rather than solving a linear system, sidestepping rank deficiency.
Run 2 (Exp 1): Gumbel-Softmax, Softmax Soup
![[Uncaptioned image]](figures/sindy_exp1_softmax_soup.png)
A state-dependent router (Linear → ReLU → Linear → ReLU → Linear) with Gumbel-Softmax sampling and temperature annealing. Expected near-one-hot activations; got a “softmax soup” instead.
| max activation | final MSE | |
| (identity(v)) | ||
| (identity(x)) |
The router found the right dominant term ( with coefficient ) but never committed. Three failures stacked:
No sparsity pressure.
Nothing in the loss pushes the logits apart. With a high temperature the router computes a weighted average across all MLPs and tunes per-MLP coefficients to fit the data. When the temperature drops at epoch 2000, the underlying distribution has never concentrated, and the soft-vs-hard mismatch shows up as a loss discontinuity.
Damping lost.
The damping term in $\dot{v}$ is smaller than the restoring force. Without sparsity pressure, that signal is drowned out across the 7 other terms.
Collinearity in disguise.
add(x,v) is still selected 10–14% of the time with meaningful coefficients. Rank deficiency didn’t disappear — it became probability-mass splitting.
Net: MSE worse than STLSQ. Gumbel-Softmax without entropy regularization is just a fancy softmax.
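The soft-vs-hard mismatch can be reproduced with a toy router. The sketch below is an assumption-laden illustration (the logits are hypothetical and deliberately almost flat, mimicking a loss with no sparsity pressure): lowering the temperature makes each individual sample look one-hot, but the identity of the winner keeps changing from draw to draw because the underlying distribution never concentrated.

```python
import math
import random

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_softmax(logits, tau, rng):
    """Soft sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for z in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
        noisy.append((z - math.log(-math.log(u))) / tau)
    return softmax(noisy)

def summarize(logits, tau, n_draws=2000, seed=0):
    """Return (mean largest weight, share of draws won by the most frequent term)."""
    rng = random.Random(seed)
    wins = [0] * len(logits)
    mean_max = 0.0
    for _ in range(n_draws):
        weights = gumbel_softmax(logits, tau, rng)
        top = max(range(len(weights)), key=weights.__getitem__)
        wins[top] += 1
        mean_max += weights[top] / n_draws
    return mean_max, max(wins) / n_draws

# Hypothetical, nearly flat logits over the 8 library terms.
flat_logits = [0.5, 0.3, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0]
for tau in (5.0, 0.1):
    mean_max, top_share = summarize(flat_logits, tau)
    print(f"tau={tau}: mean max weight={mean_max:.2f}, "
          f"most frequent winner share={top_share:.2f}")
```

At the high temperature every sample is a broad mixture; at the low temperature each sample is sharp, yet no single term wins most draws. Only pushing the logits themselves apart changes that.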
Run 3 (Exp 2): Scalar Router + Entropy — Wrong Basis
![[Uncaptioned image]](figures/sindy_exp2_scalar_router.png)
Three changes from Exp 1:
1. State-independent router. A single learnable logit vector per derivative, broadcast across the batch. Matches the SINDy assumption that one term governs the dynamics everywhere in state space. The parameter count drops accordingly.
2. Entropy penalty with correct sign and a scheduled weight: zero at the start (pure exploration), at full strength by the end (strong pressure toward one-hot).
3. Same Gumbel-Softmax + anneal schedule.
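The effect of the scheduled entropy term on a scalar logit vector can be sketched in isolation. Everything specific here is an assumption made for illustration: the linear ramp in `entropy_weight`, the learning rate, the starting logits, and the absence of any reconstruction loss. The gradient uses the closed form $\partial H / \partial z_j = -p_j(\log p_j + H)$ for $p = \mathrm{softmax}(z)$.

```python
import math

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(ps):
    return -sum(p * math.log(p) for p in ps if p > 0.0)

def entropy_grad(logits):
    """Gradient of H(softmax(z)) with respect to z."""
    ps = softmax(logits)
    h = entropy(ps)
    return [-p * (math.log(p) + h) if p > 0.0 else 0.0 for p in ps]

def entropy_weight(epoch, n_epochs, lam_max=1.0):
    """Assumed schedule: linear ramp from 0 to lam_max."""
    return lam_max * epoch / (n_epochs - 1)

# Hypothetical near-tie between term 0 and term 2 (think identity(x) vs sin(x)).
logits = [0.10, 0.0, 0.12, 0.0, 0.0, 0.0, 0.0, 0.0]
n_epochs, lr = 2000, 0.5
h_start = entropy(softmax(logits))

for epoch in range(n_epochs):
    lam = entropy_weight(epoch, n_epochs)
    grad = entropy_grad(logits)
    # Minimizing +lam * H (the "correct sign") sharpens the distribution.
    logits = [z - lr * lam * g for z, g in zip(logits, grad)]

probs = softmax(logits)
winner = probs.index(max(probs))
print(f"entropy {h_start:.3f} -> {entropy(probs):.3f}, winner index {winner}")
```

With no data term in the loop, the penalty simply amplifies whichever logit started ahead, which is how a near-tie can be resolved in favor of the wrong term.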
Result: clean one-hot commitment.
| selected term | activation | |
Wrong basis. The true equations are linear in $x$ and $v$, so identity(x) and identity(v) should have won. The router picked sin instead because $\sin(x) \approx x$ over the small amplitudes the oscillator visits. The inflated coefficient absorbs the Taylor-series shortfall ($\sin(x) = x - x^3/6 + \dots$ undershoots $x$ in magnitude). The entropy penalty broke the near-tie arbitrarily.
Damping is still gone: one-hot structurally cannot represent a two-term equation like $\dot{v}$, which needs both a restoring term in $x$ and a damping term in $v$.
Two collinearities, not one.
Exp 1 hit the exact collinearity add(x,v) = identity(x) + identity(v). Exp 2 sidesteps that but hits an approximate collinearity, $\sin(x) \approx x$. The library design is the real bottleneck; both optimizers are doing their jobs.
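How close to collinear the two columns are can be measured directly. The sketch below assumes a hypothetical amplitude range (`AMPLITUDE = 0.5`, not a value from the experiments) and computes the cosine similarity between the identity and sin columns, plus the scale factor that best maps sin(x) onto x.

```python
import math

AMPLITUDE = 0.5  # hypothetical bound on |x|; the real range is not assumed here

xs = [AMPLITUDE * (2 * i / 599 - 1) for i in range(600)]  # grid on [-A, A]
sins = [math.sin(x) for x in xs]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

# Cosine similarity between the identity(x) and sin(x) library columns.
cosine = dot(xs, sins) / math.sqrt(dot(xs, xs) * dot(sins, sins))

# Least-squares scale c minimizing sum (x - c*sin(x))^2.
scale = dot(xs, sins) / dot(sins, sins)

print(f"cosine similarity = {cosine:.5f}, best-fit scale on sin(x) = {scale:.4f}")
```

The scale comes out slightly above 1, which is the same mechanism as the inflated coefficient: sin(x) is a little smaller than x in magnitude, so its coefficient has to grow to compensate.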
Run 4 (Exp 3): Top-$k$ + Complexity Prior, Dominant Term Recovered
![[Uncaptioned image]](figures/sindy_exp3_topk_prior.png)
Two changes:
Top-$k$ routing ($k = 2$).
Each derivative selects two basis functions instead of one. Implemented as independent Gumbel-Softmax draws with prior selections masked out (sampling without replacement via iterative masked argmax with STE gradients). Each slot has its own learnable coefficient. Now the two-term $\dot{v}$ equation is representable.
Complexity prior ().
Occam’s razor baked into the logits: simpler terms are favored, with the binary MLPs (add, mul) ranked as the most complex. Added directly to the logits so it biases selection without distorting learned coefficients.
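The selection step alone (without the STE gradient path or the per-slot coefficients) might look like the following sketch. The prior values in `PRIOR_BY_OP` and the learned logits are hypothetical, chosen only to show the mechanism: fresh Gumbel noise per slot, argmax, mask the winner, repeat.

```python
import math
import random

TERMS = ["identity(x)", "identity(v)", "sin(x)", "sin(v)",
         "cos(x)", "cos(v)", "add(x,v)", "mul(x,v)"]

# Hypothetical prior values: identity is cheapest, binary MLPs most expensive.
PRIOR_BY_OP = {"identity": 0.0, "sin": -0.5, "cos": -0.5, "add": -1.0, "mul": -1.0}
PRIOR = [PRIOR_BY_OP[name.split("(")[0]] for name in TERMS]

def gumbel(rng):
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))

def select_top_k(learned_logits, k, rng):
    """Pick k distinct terms: perturb, take the argmax, mask it, repeat."""
    # The prior is added to the logits only; coefficients are untouched.
    logits = [z + p for z, p in zip(learned_logits, PRIOR)]
    chosen = []
    for _ in range(k):
        noisy = [z + gumbel(rng) for z in logits]
        winner = max(range(len(noisy)), key=noisy.__getitem__)
        chosen.append(winner)
        logits[winner] = -math.inf  # sampling without replacement
    return chosen

rng = random.Random(0)
# Learned logits that tie identity(x) with sin(x); only the prior separates them.
learned = [2.0, 0.5, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0]
first_picks = [select_top_k(learned, 2, rng)[0] for _ in range(2000)]
share_identity = first_picks.count(0) / len(first_picks)
share_sin = first_picks.count(2) / len(first_picks)
print(f"slot 1 picks identity(x) {share_identity:.0%} of the time, "
      f"sin(x) {share_sin:.0%}")
```

Because the prior enters as a logit offset, it tilts near-ties toward the simpler term but can still be overruled when the data strongly favor a more complex one.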
Result:
| term | discovered | true | error |
| in | |||
| in | |||
| in |
Slot 1 locks cleanly on the correct basis in both derivatives. Slot 2 identifies identity(v) as the right damping term but fails to commit: it is selected only 27% of the time, with a coefficient that misses the true magnitude.
Run 5 (Exp 4): Conditional Slot Entropy — Diagnosis
![[Uncaptioned image]](figures/sindy_exp4_conditional_entropy.png)
Root cause found.
Exp 3’s entropy penalty operated on the base distribution, before any masking. Minimizing that drove the base distribution one-hot on the dominant term. Once slot 1 committed and masked its winner to $-\infty$, slot 2 faced a near-flat conditional distribution with no gradient pointing it anywhere useful.
Fix.
Iterate slot-by-slot, mirroring the forward pass:
Each slot’s conditional distribution is penalized independently. Slot 2 now receives a direct gradient to concentrate on whatever term best explains the residual.
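A minimal sketch of the slot-by-slot penalty, under stated assumptions: the logits are hypothetical, and a deterministic argmax stands in for the sampled winner of the forward pass. It also shows the diagnosis numerically, since the base distribution can have almost zero entropy while slot 2's conditional distribution is still close to uniform.

```python
import math

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def entropy(ps):
    return -sum(p * math.log(p) for p in ps if p > 0.0)

def slot_entropies(logits, k):
    """Entropy of each slot's conditional distribution, masking earlier winners."""
    masked = list(logits)
    out = []
    for _ in range(k):
        probs = softmax(masked)
        out.append(entropy(probs))
        # Argmax stands in for the sampled winner (an assumption of this sketch).
        winner = max(range(len(probs)), key=probs.__getitem__)
        masked[winner] = -math.inf
    return out

# Hypothetical logits: committed on term 0, almost flat everywhere else.
logits = [8.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0]

base_entropy = entropy(softmax(logits))          # what the Exp 3 penalty saw
h_slot1, h_slot2 = slot_entropies(logits, k=2)   # what the per-slot penalty sees
penalty = h_slot1 + h_slot2

print(f"base entropy   = {base_entropy:.3f}")
print(f"slot 1 entropy = {h_slot1:.3f}")
print(f"slot 2 entropy = {h_slot2:.3f}  (uniform over 7 terms would be {math.log(7):.3f})")
```

Penalizing `base_entropy` alone is already satisfied here, so it exerts no further pressure; the summed `penalty` is dominated by slot 2, which is where the pressure is needed.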
Result: slot 1 unchanged (100% commitment, coefficients unchanged). Slot 2’s activation improves from 27% to 34%, with a corresponding change in the damping coefficient. Val MSE is still above the target.
Why slot 2 still won’t commit.
Three compounding factors:
1. Weak residual signal. After slot 1 absorbs most of the variance, the residual slot 2 must explain is smaller than the primary signal. Reconstruction-loss gradient to slot 2 is proportionally weak.
2. STE gradient starvation at low temperature. Once the temperature has annealed to its low end, the Straight-Through Estimator back-propagates through soft weights that are essentially zero for non-winners. Slot 2’s losing logits barely move.
3. Entropy budget split equally. The entropy weight applies to both slots, but slot 1 is already committed and doesn’t need pressure. Budget wasted.
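The starvation in item 2 can be read off the Jacobian of the tempered softmax. As a sketch, write $z$ for a slot's logits, $g$ for the Gumbel noise, $\tau$ for the temperature, and $y$ for the soft weights (symbols chosen here for illustration, not taken from the original notation):

```latex
y_i = \frac{\exp\big((z_i + g_i)/\tau\big)}{\sum_m \exp\big((z_m + g_m)/\tau\big)},
\qquad
\frac{\partial y_i}{\partial z_j} = \frac{1}{\tau}\, y_i \left(\delta_{ij} - y_j\right).
```

For a non-winning term $j$, $y_j$ decays exponentially as $\tau \to 0$, faster than the $1/\tau$ prefactor grows. Every entry in column $j$ of the Jacobian carries a factor of $y_j$, so the whole column vanishes and the straight-through gradient has almost nothing to pass back to the losing logit $z_j$.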
Summary of Runs
| Run | Method | Val MSE | Max act. | Correct? |
| 1 | STLSQ | | n/a | structure: yes; coeffs: spread |
| 2 | Gumbel, MLP router | | | soft mixture, damping gone |
| 3 | Gumbel, scalar + entropy | | | wrong basis (sin vs id) |
| 4 | + top-$k$ + complexity prior | | | right basis; damping at |
| 5 | + conditional slot entropy | | | damping at |
What I Learned
Each failure revealed something different about the library-vs-optimizer interplay:
1. STLSQ fails on exact collinearity because it inverts a rank-deficient library matrix; it distributes coefficients rather than picking.
2. Gumbel-Softmax without an entropy penalty stays soft. Categorical sampling doesn’t imply categorical commitment; the loss has to ask for it.
3. One-hot with entropy cannot represent multi-term equations and picks arbitrarily among approximately collinear terms ($\sin(x) \approx x$ for small amplitudes).
4. Top-$k$ + complexity prior fixes both, but the joint loss starves slot 2’s gradient when the residual signal is smaller than the primary term.
The library has two collinearities, one exact and one approximate, and every method “failed” by exposing one of them. The real fix lives upstream of the router: drop add from the library, train slot 2 on an explicit residual (the target minus slot 1’s contribution), or freeze slot 1 after commitment and fine-tune slot 2 alone. All three give the weak damping term a full-strength gradient instead of a fraction of the joint loss.
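The residual idea can be sketched with two one-dimensional least-squares fits. This is an illustration under assumptions, not the proposed training loop: the oscillator parameters are hypothetical, the data are noise-free and come from the closed-form solution, and slot 1 and slot 2 are hard-wired to identity(x) and identity(v).

```python
import math

OMEGA, ZETA = 1.0, 0.1  # hypothetical oscillator parameters

def oscillator_samples(n=600, t_end=30.0):
    """Closed-form underdamped trajectory with exact dv/dt from the ODE."""
    gamma = ZETA * OMEGA
    omega_d = OMEGA * math.sqrt(1.0 - ZETA ** 2)
    rows = []
    for i in range(n):
        t = t_end * i / (n - 1)
        decay = math.exp(-gamma * t)
        x = decay * (math.cos(omega_d * t) + gamma / omega_d * math.sin(omega_d * t))
        v = -decay * (OMEGA ** 2 / omega_d) * math.sin(omega_d * t)
        dv = -OMEGA ** 2 * x - 2.0 * gamma * v
        rows.append((x, v, dv))
    return rows

def fit_single(feature, target):
    """1-D least squares: the c minimizing sum (target - c*feature)^2."""
    return sum(f * t for f, t in zip(feature, target)) / sum(f * f for f in feature)

xs, vs, dvs = zip(*oscillator_samples())

# Slot 1: fit the dominant restoring term, then freeze it.
c_restoring = fit_single(xs, dvs)

# Slot 2: fit the damping term against the residual only.
residual = [dv - c_restoring * x for x, dv in zip(xs, dvs)]
c_damping = fit_single(vs, residual)

print(f"restoring coeff = {c_restoring:.3f} (true {-OMEGA ** 2:.3f})")
print(f"damping coeff   = {c_damping:.3f} (true {-2 * ZETA * OMEGA:.3f})")
```

The sequential fit is only approximate because x and v are not exactly orthogonal over a finite window, but the damping coefficient is now estimated from a target in which it is the largest remaining effect rather than a small correction.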
The template, again: when a method fails, the failure mode is usually more informative than the failure magnitude. STLSQ’s near-zero MSE looked like success and was hiding a rank-deficient library; the router’s worse MSE looked like failure and pointed directly at the missing sparsity objective. Reading the residuals per term, not in aggregate, is where the signal lives.