| category | research |
|---|---|
| section | introduction |
| weight | 10 |
| title | jaxctrl: Differentiable Control Theory in JAX |
| status | draft |
| slide_summary | Fully differentiable Lyapunov/Riccati solvers, tensor eigenvalue methods, and hypergraph controllability analysis in JAX, filling gaps between SciPy control and modern autodiff ecosystems. |
| tags | |
Differentiable control theory in JAX. Lyapunov and Riccati solvers, controllability analysis, tensor eigenvalues, and hypergraph control — all JIT-compiled and autodiff-compatible.
Built on the Kidger stack: Equinox, Lineax, Optimistix, Diffrax.
```bash
pip install jaxctrl
```

For hypergraph control (requires hgx):
```bash
pip install jaxctrl[hypergraph]
```

Layer 0 — System identification (data-driven model discovery):

- `SINDyOptimizer`, `polynomial_library`, `fourier_library`
- `KoopmanEstimator` (Exact DMD)

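To make the Exact DMD entry concrete, here is a minimal sketch in plain JAX of what an Exact DMD fit computes. It is not `KoopmanEstimator`'s implementation: the function `exact_dmd` and its signature are hypothetical, and it assumes snapshot pairs are stored column-wise with `Y[:, k]` the successor of `X[:, k]`.

```python
import jax.numpy as jnp


def exact_dmd(X, Y, rank):
    """Exact DMD: fit a linear map with Y ≈ A X from snapshot pairs.

    X, Y: (n_states, n_snapshots) arrays where Y[:, k] follows X[:, k].
    Returns the eigenvalues and modes of the rank-truncated operator.
    """
    U, s, Vh = jnp.linalg.svd(X, full_matrices=False)
    U, s, V = U[:, :rank], s[:rank], Vh[:rank].conj().T
    # Projected operator A_tilde = U^* Y V S^{-1} (dividing by s scales columns)
    A_tilde = (U.conj().T @ Y @ V) / s
    eigvals, W = jnp.linalg.eig(A_tilde)
    # Exact DMD modes: Phi = Y V S^{-1} W
    modes = ((Y @ V) / s) @ W
    return eigvals, modes


# Toy data: one trajectory of a known linear map with eigenvalues 0.9 and 0.8
A_true = jnp.array([[0.9, 0.1], [0.0, 0.8]])
x = jnp.array([1.0, 1.0])
snapshots = [x]
for _ in range(10):
    x = A_true @ x
    snapshots.append(x)
S = jnp.stack(snapshots, axis=1)  # shape (2, 11)

eigvals, modes = exact_dmd(S[:, :-1], S[:, 1:], rank=2)
```

Because the toy snapshot matrix has full row rank, the projected operator is similar to `A_true`, so the recovered DMD eigenvalues should match 0.9 and 0.8 up to floating-point error.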
Layer 1 — Control primitives (missing from JAX, exist in SciPy):

- `solve_continuous_lyapunov`, `solve_discrete_lyapunov`
- `solve_continuous_are`, `solve_discrete_are`
- `lqr`, `dlqr`
- `controllability_gramian`, `observability_gramian`
- `is_controllable`, `is_observable`, `is_stabilizable`, `is_detectable`
- `simulate_lti`, `simulate_closed_loop` (Diffrax adaptive ODE or matrix-exponential fallback)

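The Layer 1 primitives can be illustrated with small dense solvers written directly in JAX. This sketch is not jaxctrl's implementation. The names `solve_lyapunov_kron`, `kalman_controllable`, `lqr_newton_kleinman`, and `simulate_closed_loop_expm` are hypothetical. The Lyapunov solve uses Kronecker vectorization, which costs O(n^6) and is only sensible for tiny systems. The Riccati equation is solved by Newton-Kleinman iteration, which assumes an initial gain `K0` that already stabilizes `A - B K0`. Because every step is a `jnp` linear solve, `jax.grad` differentiates straight through the unrolled iteration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def solve_lyapunov_kron(A, Q):
    """Solve A X + X A^T + Q = 0 by vectorization. Dense O(n^6): small n only."""
    n = A.shape[0]
    I = jnp.eye(n, dtype=A.dtype)
    # Row-major vec: vec(A X) = (A ⊗ I) vec(X) and vec(X A^T) = (I ⊗ A) vec(X)
    M = jnp.kron(A, I) + jnp.kron(I, A)
    return jnp.linalg.solve(M, -Q.reshape(-1)).reshape(n, n)


def kalman_controllable(A, B):
    """Kalman rank test: rank [B, AB, ..., A^{n-1} B] == n."""
    n = A.shape[0]
    blocks = [B]
    for _ in range(n - 1):
        blocks.append(A @ blocks[-1])
    return jnp.linalg.matrix_rank(jnp.concatenate(blocks, axis=1)) == n


def lqr_newton_kleinman(A, B, Q, R, K0, iters=20):
    """Continuous-time LQR by Newton-Kleinman iteration from a stabilizing K0."""
    K = K0
    for _ in range(iters):
        Acl = A - B @ K
        # Policy evaluation: Acl^T P + P Acl + Q + K^T R K = 0
        P = solve_lyapunov_kron(Acl.T, Q + K.T @ R @ K)
        # Policy improvement: K = R^{-1} B^T P
        K = jnp.linalg.solve(R, B.T @ P)
    return K, P


def simulate_closed_loop_expm(A, B, K, x0, ts):
    """Closed-loop rollout x(t) = expm(t (A - B K)) x0 with input u = -K x."""
    Acl = A - B @ K
    xs = jax.vmap(lambda t: expm(t * Acl) @ x0)(ts)
    return xs, -(xs @ K.T)


# Double integrator with Q = I, R = I
A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
B = jnp.array([[0.0], [1.0]])
Q, R = jnp.eye(2), jnp.eye(1)
K0 = jnp.array([[1.0, 1.0]])  # assumed stabilizing initial gain
x0 = jnp.array([2.0, 0.0])

print(kalman_controllable(A, B))
K, P = lqr_newton_kleinman(A, B, Q, R, K0)

# The optimal cost x0^T P x0 is differentiable in Q through the unrolled solves
cost = lambda Q: x0 @ lqr_newton_kleinman(A, B, Q, R, K0)[1] @ x0
dJ_dQ = jax.grad(cost)(Q)

ts = jnp.linspace(0.0, 10.0, 101)
xs, us = simulate_closed_loop_expm(A, B, K, x0, ts)
```

For this system the Riccati equation can be solved by hand, giving P = [[√3, 1], [1, √3]] and K = [1, √3], which the iteration should reproduce. Since the cost is stationary in K at the optimum, the gradient with respect to Q should also match the closed-loop Gramian Y solving (A - BK) Y + Y (A - BK)ᵀ + x0 x0ᵀ = 0.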
Layer 2 — Tensor control (new mathematics; to our knowledge, no prior implementation of these tensor control methods exists):

- `z_eigenvalues`, `h_eigenvalues`, `spectral_radius`
- `tensor_unfold`, `tensor_fold`, `einstein_product`, `tensor_contract`
- `mode_dot`, `hosvd`, `tucker_to_tensor`, `khatri_rao`
- `solve_arte`, `tensor_lyapunov`, `multilinear_lqr`

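Two of the Layer 2 building blocks, mode-n unfolding with the n-mode product and Z-eigenpairs of a symmetric tensor, fit in a short sketch. It is not jaxctrl's code. The names `unfold`, `fold`, `mode_dot`, and `z_eigenpair_sshopm` are hypothetical. The unfolding uses a row-major column order, which is an assumption since conventions differ between references. The eigenpair routine is the shifted symmetric higher-order power method (SS-HOPM) of Kolda and Mayo, restricted here to order-3 tensors.

```python
import jax
import jax.numpy as jnp


def unfold(T, mode):
    """Mode-n unfolding: move axis `mode` to the front, flatten the rest (row-major)."""
    return jnp.moveaxis(T, mode, 0).reshape(T.shape[mode], -1)


def fold(M, mode, shape):
    """Inverse of `unfold` for a tensor with the given full shape."""
    moved = (shape[mode],) + tuple(s for i, s in enumerate(shape) if i != mode)
    return jnp.moveaxis(M.reshape(moved), 0, mode)


def mode_dot(T, U, mode):
    """n-mode product T x_n U: contracts axis `mode` of T with the columns of U."""
    new_shape = T.shape[:mode] + (U.shape[0],) + T.shape[mode + 1:]
    return fold(U @ unfold(T, mode), mode, new_shape)


def z_eigenpair_sshopm(T, x0, alpha=1.0, iters=200):
    """Z-eigenpair (T x x = lam x, ||x|| = 1) of a symmetric order-3 tensor via SS-HOPM."""
    def step(x, _):
        # Shifted power step: y = T x x + alpha x, then renormalize
        y = jnp.einsum("ijk,j,k->i", T, x, x) + alpha * x
        return y / jnp.linalg.norm(y), None

    x, _ = jax.lax.scan(step, x0 / jnp.linalg.norm(x0), None, length=iters)
    lam = jnp.einsum("ijk,i,j,k->", T, x, x, x)  # Rayleigh-quotient analogue T x^3
    return lam, x


# n-mode product on a small dense tensor
T = jnp.arange(24.0).reshape(2, 3, 4)
U = jnp.arange(15.0).reshape(5, 3)
TU = mode_dot(T, U, 1)  # shape (2, 5, 4)

# Symmetric rank-1 tensor 2 v⊗v⊗v with unit v has Z-eigenpair (2, v)
v = jnp.array([0.6, 0.8, 0.0])
T_sym = 2.0 * jnp.einsum("i,j,k->ijk", v, v, v)
lam, x = z_eigenpair_sshopm(T_sym, jnp.ones(3))
```

The positive shift `alpha` is what distinguishes SS-HOPM from the plain power iteration; for a rank-1 test tensor like this one the iterate converges to `v` from any start with a nonzero component along it.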
Layer 3 — Hypergraph control (integrates with hgx):

- `adjacency_tensor`, `laplacian_tensor`
- `tensor_kalman_rank`, `minimum_driver_nodes`
- `control_energy`, `controllability_profile`
- `HypergraphControlSystem`

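For intuition about the tensors Layer 3 works with, here is a sketch of the adjacency and Laplacian tensors of a 3-uniform hypergraph. The function names are hypothetical, and the sketch assumes the Cooper-Dutle normalization (entries 1/(k-1)! for a k-uniform hypergraph) and L = D - A. Whether jaxctrl or hgx use the same convention is not stated here. The multilinear map `apply_tensor` is the drift term of polynomial hypergraph dynamics of the form ẋ = A x^{k-1} + Bu, which is also an assumption about the model.

```python
import itertools

import jax.numpy as jnp


def adjacency_tensor_3uniform(n, edges):
    """Adjacency tensor of a 3-uniform hypergraph on n vertices.

    Cooper-Dutle normalization: A[i, j, k] = 1 / (3 - 1)! = 1/2 for every
    ordering (i, j, k) of each hyperedge, so A is symmetric.
    """
    A = jnp.zeros((n, n, n))
    for edge in edges:
        for i, j, k in itertools.permutations(edge):
            A = A.at[i, j, k].set(0.5)
    return A


def laplacian_tensor(A):
    """Laplacian tensor L = D - A, with D diagonal holding vertex degrees."""
    n = A.shape[0]
    degrees = A.sum(axis=(1, 2))  # each incident hyperedge contributes 2 * 1/2 = 1
    idx = jnp.arange(n)
    D = jnp.zeros_like(A).at[idx, idx, idx].set(degrees)
    return D - A


def apply_tensor(T, x):
    """Multilinear map (T x^2)_i = sum_{j,k} T[i, j, k] x_j x_k."""
    return jnp.einsum("ijk,j,k->i", T, x, x)


# Two hyperedges sharing vertices 1 and 2
edges = [(0, 1, 2), (1, 2, 3)]
A = adjacency_tensor_3uniform(4, edges)
L = laplacian_tensor(A)
consensus = apply_tensor(L, jnp.ones(4))  # vanishes at the all-ones vector
```

With this normalization the row sums of `A` equal the vertex degrees, which is why the Laplacian annihilates the all-ones vector, mirroring the graph case.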
Quick start:

```python
import jax
import jax.numpy as jnp
import jaxctrl

# Double integrator: dx/dt = Ax + Bu
A = jnp.array([[0.0, 1.0], [0.0, 0.0]])
B = jnp.array([[0.0], [1.0]])
Q = jnp.eye(2)
R = jnp.eye(1)

# LQR controller (fully differentiable): gain K and Riccati solution X
K, X = jaxctrl.lqr(A, B, Q, R)

# Controllability analysis
print(jaxctrl.is_controllable(A, B))  # True

# Simulate closed-loop response (uses Diffrax if available)
x0 = jnp.array([2.0, 0.0])
ts, xs, us = jaxctrl.simulate_closed_loop(A, B, K, x0, T=10.0)

# Differentiate the optimal LQR cost J = x0^T X x0 w.r.t. Q
dJ_dQ = jax.grad(lambda Q: x0 @ jaxctrl.lqr(A, B, Q, R)[1] @ x0)(Q)
```

References:

- Kao & Hennequin (2020). "Automatic differentiation of Sylvester, Lyapunov, and algebraic Riccati equations." arXiv:2011.11430
- Chen & Surana (2021). "Controllability of hypergraphs." IEEE TNSE.
- Wang & Wei (2024). "Algebraic Riccati tensor equations." arXiv:2402.13491
- Dong et al. (2024). "Controllability and observability of temporal hypergraphs." arXiv:2408.12085
- Liu, Slotine & Barabási (2011). "Controllability of complex networks." Nature 473, 167–173.