Submitted by xl0 t3_ynheaf in MachineLearning

Hi! I made the ❤️ Lovely Tensors: https://github.com/xl0/lovely-tensors library.

It lets you visualize and summarize PyTorch tensors for human consumption:

Tensor summary

Or display images (assuming it's RGB[A] data):

RGB Image

Or plot a histogram:

Histogram and stats

Or view channels (either colour channels or ConvNet activations):

Channels

I just released version 0.1.0, which covers most things I had in mind for tensors. I plan on adding visualizations for nn.Modules and other features in the near future. Would appreciate your feedback on both small bugs/features and the overall usability.
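To give a flavour of the kind of per-tensor summary described above (shape, dtype, basic stats, NaN count), here is a minimal sketch in plain NumPy. The function name and output format are illustrative only, not lovely-tensors' actual API:

```python
import numpy as np

def summarize(a):
    # Illustrative one-line summary: shape, dtype, mean/std, NaN count.
    # A real library would also report device, infs, zeros, etc.
    return (f"array[{'x'.join(map(str, a.shape))}] {a.dtype} "
            f"mu={a.mean():.3f} sigma={a.std():.3f} "
            f"nan={int(np.isnan(a).sum())}")

x = np.array([[1.0, 2.0], [3.0, float('nan')]])
print(summarize(x))  # e.g. "array[2x2] float64 mu=nan sigma=nan nan=1"
```

The same idea extends naturally to histograms and channel views once you have the raw values in hand.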


Comments


patrickkidger t1_ivb4la8 wrote

This is really nice!
...I might shamelessly steal this idea for my JAX work. :D


cloneofsimo t1_ivca8c8 wrote

Wow, this seems super useful! It would be cool if this were one of torch's native methods, like pandas' df.describe :P


xl0 OP t1_ivd0rca wrote

Haha, thank you! You are not the first person to mention JAX, so I guess I'll do a JAX version next. :)

I have a rough idea of what it is, and as I understand it, it's more about transforming functions. Do you have ideas about anything JAX-specific that should be included in the ndarray summary?


patrickkidger t1_ivg6fsg wrote

There's no JAX-specific information worth including, I don't think. A JAX array basically holds the same information as a PyTorch tensor, i.e. shape/dtype/device/nans/infs/an array of values.

The implementation would need to respect how JAX works, though. JAX works by substituting arrays with duck-typed "tracer" objects, passing them in to your Python function, recording everything that happens to them, and then compiling the resulting computation graph. (=no Python interpreter during execution, and the possibility of op fusion, which often means improved performance. It's also what makes its function transformations possible, like autodiff, autoparallelism, autobatching etc.)
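The tracing idea above can be illustrated with a toy example in plain Python. This is not JAX's actual implementation, just a sketch of how a duck-typed tracer records operations instead of computing values:

```python
class Tracer:
    """Toy tracer: stands in for an array and records ops applied to it."""
    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _record(self, op, other):
        # Create a fresh symbolic output and log the operation.
        out = Tracer(f"v{len(self.graph)}", self.graph)
        self.graph.append((op, self.name, other, out.name))
        return out

    def __mul__(self, other):
        return self._record("mul", other)

    def __add__(self, other):
        return self._record("add", other)

def trace(fn):
    # Run fn once with a tracer instead of real data; return the recorded graph.
    graph = []
    fn(Tracer("x", graph))
    return graph

# Tracing f(x) = x * 2 + 1 records two ops; no values are ever computed:
graph = trace(lambda x: x * 2 + 1)
print(graph)  # [('mul', 'x', 2, 'v0'), ('add', 'v0', 1, 'v1')]
```

JAX's real tracers are far more sophisticated (shapes, dtypes, nested transformations), but the core trick is the same: at trace time you only see metadata, never values.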

This means that you don't usually have the value of the array when you evaluate your Python function -- just some metadata like its shape and dtype. Instead you'd have to create an op that delays doing the printing until runtime, i.e. not doing it during trace time.

...which sounds like a lot, but is probably very easy. Just wrap jax.debug.print.
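For concreteness, a minimal sketch of that approach (assuming JAX is installed): jax.debug.print stages the print as an op in the graph, so it runs when the compiled function executes rather than at trace time:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Staged as a runtime op; prints actual values, not tracers.
    jax.debug.print("mean = {m}", m=jnp.mean(x))
    return x * 2

y = f(jnp.arange(4.0))
```

A summary library would just replace the format string with its stats rendering.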
