Submitted by xl0 in MachineLearning
patrickkidger wrote:
Reply to comment by xl0 in "[P] Lovely Tensors library"
There's no JAX-specific information worth including, I don't think. A JAX array basically holds the same information as a PyTorch tensor, i.e. shape/dtype/device/nans/infs/an array of values.
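Concretely, something like this reads off that same metadata from a JAX array (a minimal sketch, outside of jit; all of these are standard jax.numpy calls):

    import jax.numpy as jnp

    x = jnp.array([1.0, float("nan"), float("inf"), 4.0])

    print(x.shape)             # (4,)
    print(x.dtype)             # float32
    print(jnp.isnan(x).any())  # True
    print(jnp.isinf(x).any())  # True
    # (device info is likewise available, e.g. via jax.devices())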
The implementation would need to respect how JAX works, though. JAX works by substituting duck-typed "tracer" objects for your arrays, passing them into your Python function, recording everything that happens to them, and then compiling the resulting computation graph. (= no Python interpreter during execution, and the possibility of op fusion, which often means improved performance. It's also what makes its function transformations possible, like autodiff, autoparallelism, autobatching, etc.)
This means that you don't usually have the value of the array when your Python function is evaluated -- just some metadata like its shape and dtype (see the sketch below). Instead you'd have to create an op that delays the printing until runtime, rather than doing it at trace time.
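A minimal illustration of trace time vs. run time (the function name f is arbitrary): a plain print inside a jitted function fires once, during tracing, and only sees the tracer's shape and dtype.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # x is a tracer here: shape and dtype are known, the values are not.
        print(x)  # prints something like Traced<ShapedArray(float32[3])>...
        return x * 2

    f(jnp.ones(3))  # the print above fires at trace time, not at run time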
Creating that delayed op sounds like a lot, but is probably very easy: just wrap jax.debug.print.
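A minimal sketch of that wrapping (the show helper is a hypothetical name, not lovely-jax's actual API; it assumes a JAX version that has jax.debug.print):

    import jax
    import jax.numpy as jnp

    def show(x):
        # Shape and dtype are static, known at trace time, so they can be
        # baked into the format string; the nan/inf checks are traced values,
        # so jax.debug.print defers formatting them until they exist.
        jax.debug.print(
            "array" + str(tuple(x.shape)) + " " + str(x.dtype)
            + " nans={nans} infs={infs}",
            nans=jnp.isnan(x).any(),
            infs=jnp.isinf(x).any(),
        )

    @jax.jit
    def f(x):
        show(x)  # prints at run time, even under jit
        return x * 2

    f(jnp.array([1.0, float("nan"), float("inf")]))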
xl0 (OP) wrote:
I started working on it. I'll make sure the repr works inside jit and in parallel execution before moving on to other things.
https://github.com/xl0/lovely-jax
Please let me know if you have any thoughts, I'm very new to JAX.