Submitted by jakiwjakiw t3_xwam5y in MachineLearning

I wrote an introduction to Diffusion Models in JAX for a recent workshop. It guides you through implementing a diffusion model from scratch and training it on some toy datasets that any laptop can handle :)

Introduction to Score-based generative models

There are also some excursions into the theory behind diffusion models and into their generalization properties (or when they memorize their training data). I appreciate all kinds of feedback!


Comments


velcher t1_ir5und6 wrote

Thanks for the post!

In the section: Marginals of the time-changed OU-process

> The empirical measure \hat\mu = \frac{1}{J} \sum_{j=1}^J \delta_{x_j}

  • What is \delta_{x_j}? It doesn't seem to be mentioned before. Is this just the difference between the J samples x^j?

Serverside t1_ir606f1 wrote

Nice code! What would it take to make this conditional? E.g., for your circle example, produce a conditional distribution of points based on the arc length (or some other label)?


Small_Stand_8716 t1_ir69lha wrote

Great post, thank you! Unrelated, but is there a reason you used Flax over Haiku? I've been meaning to learn JAX for a while and go beyond PyTorch, but I'm unsure whether, for deep learning, I should start with Flax or Haiku.


jakiwjakiw OP t1_ir6k6b9 wrote

The deltas stand for point measures: \delta_x has mass 1 at the point x and 0 everywhere else.
Therefore \hat\mu is the measure that has mass 1/J at each x_j and 0 everywhere else.

These point measures are also called "Dirac deltas", and the notation is common in my field, but it is rather confusing to use it without explanation. Thanks a lot for pointing that out, I will update the post!
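To make the measure-theoretic notation concrete: integrating a function against \hat\mu is just the sample mean over the points x_j, and sampling from \hat\mu is picking one of the x_j uniformly. A minimal sketch in JAX (the data points here are hypothetical, not from the post):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: J = 4 sample points x_j on the real line.
xs = jnp.array([0.0, 1.0, 2.0, 3.0])
J = xs.shape[0]

def integrate_against_empirical_measure(f, xs):
    """Integral of f against mu_hat = (1/J) * sum_j delta_{x_j}.
    Each delta_{x_j} contributes f(x_j) with weight 1/J, so this
    is just the sample mean of f over the points x_j."""
    return jnp.mean(jax.vmap(f)(xs))

# Example: integrating f(x) = x^2 against mu_hat gives
# (0 + 1 + 4 + 9) / 4 = 3.5.
val = integrate_against_empirical_measure(lambda x: x**2, xs)

# Sampling from mu_hat is drawing one of the x_j uniformly at random.
key = jax.random.PRNGKey(0)
sample = jax.random.choice(key, xs)
```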


jakiwjakiw OP t1_ir6lny2 wrote

I mainly used Flax since most of the code I was reading was also using Flax, but I couldn't objectively compare the two since I haven't used Haiku yet. So I can't help, sorry!


jakiwjakiw OP t1_ir6m6wg wrote

That's a pretty interesting question! Actually, I don't know. As far as I know there are currently multiple ways to do conditional generation with SGMs, depending on your requirements. I like the work https://arxiv.org/abs/2111.13606 in that regard. But it's also something I wanted to explore in a bit more depth.

I'd be very happy about any input on this!


Serverside t1_ir71pde wrote

Yeah, I've read the paper you linked, but I haven't really delved into implementing conditional SGM code myself (I've done work with conditional generative models like GANs, VAEs, etc.). I'm also interested in lower-dimensional data than images, so your code looked like a good starting point.

After some more reading, I'll give adding conditional capabilities to your code a shot.


Sea_Discussion_459 t1_ir98mjt wrote

Thanks for your introduction!!

But the code link returns a 404 for me.


Small_Stand_8716 t1_irajgt0 wrote

Haha, yes, I've stumbled upon Equinox many times before during my JAX research. It is far more appealing to me, syntax-wise, than Haiku or Flax and has the ease of use of PyTorch. Its only downside, in my opinion, is that it doesn't have as robust an ecosystem, e.g., pre-trained vision models. Excellent package nevertheless!


nnexx_ t1_irbf3qt wrote

Yes, Equinox is amazing. Also an obligatory mention of Diffrax, which implements ODE, SDE, and CDE solvers on top of Equinox. Diffrax's documentation even has a working continuous-time diffusion model code example!


Small_Stand_8716 t1_ire7h35 wrote

Thank you, eqxvision actually fits the bill for me. For the most part I need pre-trained backbones (usually just RegNet, EfficientNetV2, or ViT) and code the rest myself (YOLO, DETR, and so forth). Thanks again!
