patrickkidger t1_j8fde35 wrote
Reply to [D] Have their been any attempts to create a programming language specifically for machine learning? by throwaway957280
On static shape checking: have a look at jaxtyping, which offers compile-time shape checks for JAX/PyTorch/etc.
(Why "JAX"typing? Because it originally only supported JAX. But it now supports other frameworks too! In particular I now recommend jaxtyping over my older "TorchTyping" project, which is pretty undesirably hacky.)
In terms of fitting this kind of stuff into a proper language: that'd be lovely. I completely agree that the extent to which we have retrofitted Python is pretty crazy!
patrickkidger t1_ivg6fsg wrote
Reply to comment by xl0 in [P] Lovely Tensors library by xl0
There's no JAX-specific information worth including, I don't think. A JAX array basically holds the same information as a PyTorch tensor, i.e. shape/dtype/device/nans/infs/an array of values.
The implementation would need to respect how JAX works, though. JAX works by substituting arrays for duck-typed "tracer" objects, passing them into your Python function, recording everything that happens to them, and then compiling the resulting computation graph. (= no Python interpreter during execution, and the possibility of op fusion, which often means improved performance. It's also what makes JAX's function transformations possible: autodiff, autoparallelism, autobatching, etc.)
This means that you don't usually have the value of the array when you evaluate your Python function -- just some metadata like its shape and dtype. Instead you'd have to create an op that delays doing the printing until runtime, i.e. not doing it during trace time.
...which sounds like a lot, but is probably very easy: just wrap jax.debug.print.
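A minimal sketch of the trace-time/runtime distinction described above (the function and values here are just for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # During tracing, x is a tracer with shape/dtype but no values.
    # jax.debug.print stages a print op into the computation graph,
    # so the actual values are printed at runtime, not trace time.
    jax.debug.print("x = {}", x)
    return x * 2

out = f(jnp.arange(3))
```

An ordinary Python `print(x)` inside `f` would instead show the tracer object, since it runs only once, during tracing.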
patrickkidger t1_ivb4la8 wrote
Reply to [P] Lovely Tensors library by xl0
This is really nice!
...I might shamelessly steal this idea for my JAX work. :D
patrickkidger t1_ivb3t7e wrote
Reply to comment by Gaussianperson in [D] Physics-inspired Deep Learning Models by ShadowKnightPro
See the conclusion of my thesis (linked above ;) )
TL;DR: everything around neural PDEs, stable training of neural SDEs, applications of neural ODEs to ~all of science~, adaptive/implicit/rough numerical SDE solvers (although that one's very specialised), current work connecting NDEs with state space models (S4D, MEGA, etc.), ...etc. etc.!
patrickkidger t1_iv6iau0 wrote
Reply to comment by ShadowKnightPro in [D] Physics-inspired Deep Learning Models by ShadowKnightPro
Yep, absolutely.
patrickkidger t1_iv5vb05 wrote
Neural differential equations! The continuous-time limit of a lot of deep learning models can be thought of as a differential equation with a neural network as its vector field.
A survey is On Neural Differential Equations.
Also +1 for /u/betelgeuse3e08's recommendations, which are primarily neural ODEs encoding particular kinds of physical structure; c.f. Section 2.2.2 of the above.
You can find a lot of code examples of neural ODEs/SDEs/etc. in JAX in the Diffrax documentation.
This topic is kind of my thing :) DM me if you end up going down this route, I can try to point you at the open problems.
patrickkidger t1_ireuf55 wrote
Reply to comment by Small_Stand_8716 in [R] Introduction to Diffusion Models in JAX by jakiwjakiw
Nice, I'm glad it helps!
(CC /u/PaganPasta as this is their project.)
patrickkidger t1_irake85 wrote
Reply to comment by Small_Stand_8716 in [R] Introduction to Diffusion Models in JAX by jakiwjakiw
Have you seen eqxvision? It's still relatively nascent but already has a fair amount.
patrickkidger t1_ir8dz9i wrote
Reply to comment by Small_Stand_8716 in [R] Introduction to Diffusion Models in JAX by jakiwjakiw
Obligatory advertisement for Equinox as a third alternative.
(And IMO much easier to use, although I am biased.)
patrickkidger t1_j8fdtx2 wrote
Reply to comment by 0x00A0C0 in [D] Have their been any attempts to create a programming language specifically for machine learning? by throwaway957280
Heads-up that my newer jaxtyping project now exists.
Despite the name, it supports both PyTorch and JAX; it is also substantially less hacky than TorchTyping. As such I recommend jaxtyping over TorchTyping regardless of your framework.
(jaxtyping is now widely used internally.)