RaeudigerRaffi t1_jah39t7 wrote

You are right, Gumbel-Softmax is a possibility with which you can backprop. But given that he is trying to do beam sampling and backprop through it, at some point you need to argmax on your Gumbel-Softmax vector in order to actually pick the token (assuming there is no way to work with the vector representations down the line; correct me if I am wrong), and then this becomes non-differentiable.
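A minimal sketch of the issue, assuming the standard `torch.nn.functional.gumbel_softmax`: the relaxed sample carries gradients, but the argmax that actually picks the token returns a plain integer index with no `grad_fn`, which cuts the graph:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)

# Differentiable: a soft, relaxed sample over the vocabulary.
soft_sample = F.gumbel_softmax(logits, tau=1.0, hard=False)
soft_sample.sum().backward()
print(logits.grad)  # populated: gradients flow through the relaxation

logits.grad = None

# Non-differentiable: argmax returns an integer index with no grad_fn,
# so a loss computed from the picked token cannot reach the logits.
token_id = soft_sample.argmax()
print(token_id.requires_grad)  # False: the graph is cut here
```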

RaeudigerRaffi t1_jahpbod wrote

To add to this: I thought a bit about it, and technically in PyTorch this should be possible with some trickery using custom autograd functions. You can probably sample with Gumbel-Softmax and return the argmax. In the custom backward you can just skip the argmax part and backprop as if the Gumbel-Softmax output had been returned instead of the argmax on the Gumbel-Softmax.
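A minimal sketch of that trickery, using a hypothetical `StraightThroughArgmax` autograd function: the forward pass returns the hard one-hot pick, and the backward pass ignores the argmax and passes the gradient through as if the soft Gumbel-Softmax sample had been returned:

```python
import torch
import torch.nn.functional as F

class StraightThroughArgmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_sample):
        # Hard one-hot pick of the token (non-differentiable on its own).
        index = soft_sample.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(soft_sample).scatter_(-1, index, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Skip the argmax in the backward pass: backprop as if the
        # Gumbel-Softmax output itself had been returned.
        return grad_output

logits = torch.randn(5, requires_grad=True)
soft = F.gumbel_softmax(logits, tau=1.0, hard=False)
hard = StraightThroughArgmax.apply(soft)  # one-hot in the forward pass
hard.sum().backward()
print(logits.grad)  # gradients reach the logits despite the argmax
```

Worth noting that `F.gumbel_softmax(logits, hard=True)` already builds in essentially this straight-through estimator via `(y_hard - y_soft).detach() + y_soft`, so the built-in may cover this case without a custom Function.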
