Submitted by MohamedRashad t3_y14lvd in MachineLearning
MohamedRashad OP t1_irvolp8 wrote
Reply to comment by HoLeeFaak in [D] Reversing Image-to-text models to get the prompt by MohamedRashad
I thought about self-supervision for this task: feed the image whose prompt I want into an image-to-text model, then feed the resulting text to a diffusion model (DALL-E, Stable Diffusion) whose weights are frozen so they don't change.
The output image is then compared to the original image, and the loss is backpropagated to the image-to-text model so it learns. In my humble opinion, this approach has two problems:
- Training such a system won't be easy, and it will need a lot of resources I currently don't have.
- Even if I succeed, the resulting model won't generalize well.
And this is all assuming I manage to overcome the non-differentiable parts.
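Just to make the cycle concrete, here is a minimal toy sketch of the idea: image -> caption -> reconstructed image, with the generator frozen and only the captioner updated. Everything here is a hypothetical stand-in — both "models" are random linear maps, not a real image-to-text model or diffusion model, and the caption is kept continuous, which sidesteps exactly the non-differentiable discrete-text step being discussed.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IMG, D_TXT = 8, 4

W_captioner = rng.normal(size=(D_TXT, D_IMG)) * 0.1   # trainable image-to-text stand-in
W_generator = rng.normal(size=(D_IMG, D_TXT)) * 0.3   # frozen text-to-image stand-in

def cycle_loss(image):
    caption = W_captioner @ image   # continuous stand-in for a prompt
    recon = W_generator @ caption   # frozen "diffusion model" reconstructs
    return 0.5 * np.sum((recon - image) ** 2)

image = rng.normal(size=D_IMG)
loss_before = cycle_loss(image)

for _ in range(300):
    # Gradient of the reconstruction loss w.r.t. the captioner only;
    # the generator's weights stay frozen, as proposed above.
    err = W_generator @ (W_captioner @ image) - image
    grad = W_generator.T @ np.outer(err, image)
    W_captioner -= 0.02 * grad

loss_after = cycle_loss(image)
```

With real networks the same loop shape applies, but the caption would be discrete tokens, which is where the backprop breaks.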
HoLeeFaak t1_irvoxe5 wrote
What you propose is a cycle loss. It's valid, but the biggest problem is the non-differentiable parts, and that's a big problem I haven't found a solution to.
samb-t t1_irvsicm wrote
If you have enough resources to train an autoregressive model, then you could take advantage of the fact that these big text-to-image models are conditioned on CLIP embeddings, and instead train an autoregressive model to predict prompts conditioned on CLIP image embeddings. That way there are no non-differentiable parts to bypass, and the CLIP embeddings should be a pretty good descriptor of both the input image and the prompt.
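The data flow being suggested can be sketched in miniature: greedy autoregressive decoding of prompt tokens conditioned on an image embedding. The weights here are random placeholders and the embedding is random noise standing in for a CLIP image embedding; a real version would train the decoder on (CLIP image embedding, prompt) pairs.

```python
import numpy as np

rng = np.random.default_rng(1)
VOCAB, D_EMB, EOS = 16, 8, 0   # toy vocab size, embedding dim, end-of-sequence id

W_cond = rng.normal(size=(VOCAB, D_EMB))   # projects the image embedding into logits
W_tok = rng.normal(size=(VOCAB, VOCAB))    # logits contribution from the previous token

def decode(image_emb, max_len=10):
    """Greedy autoregressive decoding conditioned on the image embedding."""
    tokens = []
    prev = np.zeros(VOCAB)   # all-zeros stands in for a start-of-sequence token
    for _ in range(max_len):
        logits = W_cond @ image_emb + W_tok @ prev
        tok = int(np.argmax(logits))
        if tok == EOS:
            break
        tokens.append(tok)
        prev = np.eye(VOCAB)[tok]   # feed the chosen token back in
    return tokens

prompt_tokens = decode(rng.normal(size=D_EMB))   # stand-in for a CLIP image embedding
```

Since every step is a dense computation on the embedding, training this with teacher forcing is ordinary backprop — no discrete bottleneck between image and prompt.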
If you don't have enough resources, then (just thinking out loud; there's probably a better way, but it might give some ideas) you could again use a pretrained CLIP model: 1. Embed the input image with the CLIP image encoder. 2. Using the CLIP text encoder, optimise the input text to get an embedding close to the image embedding. The problem there is again that text is discrete, so you can't backprop. You could use Gumbel-softmax to approximate the discrete text values, though (annealing down how continuous it is). Alternatively, you could treat the embedding-distance loss as an energy function and use discrete MCMC, something like Gibbs-with-Gradients. But both of those options still probably aren't great; it's a horrible optimisation space.
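For reference, the Gumbel-softmax relaxation mentioned above looks like this in isolation: Gumbel noise is added to logits over tokens, and a temperature controls how close the resulting "soft token" is to a discrete one-hot. This is just the sampling trick on random toy logits, not the full CLIP text optimisation; annealing the temperature down is what recovers (approximately) discrete text.

```python
import numpy as np

rng = np.random.default_rng(2)

def gumbel_softmax(logits, tau):
    """Sample a soft one-hot vector; lower tau -> closer to discrete."""
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))  # Gumbel(0, 1) noise
    y = (logits + g) / tau
    y -= y.max()            # subtract max for numerical stability
    e = np.exp(y)
    return e / e.sum()

logits = rng.normal(size=8)       # toy logits over an 8-token vocabulary
soft = gumbel_softmax(logits, tau=1.0)    # high temperature: a spread-out soft token
hard = gumbel_softmax(logits, tau=1e-3)   # low temperature: nearly one-hot
```

Because the sample is a differentiable function of the logits, gradients from the CLIP embedding-distance loss can flow back through it, which is exactly what the discrete argmax would prevent.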