
cnapun t1_jage50a wrote

I'm not an expert on this topic, but I've discussed it with coworkers. I do believe you should be able to backprop through sampling, mathematically at least. My suspicion is that you'll run into the same problem as you have with RNNs, where backpropping through many steps leads to high variance in gradients. I'd search for some papers that have explored this; I assume they exist.
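
For concreteness, something like the straight-through Gumbel-softmax trick is the usual way to let gradients pass through a sampled token. Here's a minimal sketch of a single sampling step (my own toy example, assuming PyTorch; the "reward" vector is just a hypothetical stand-in for whatever downstream loss you have):

```python
# Minimal sketch: backprop through one sampling step with straight-through Gumbel-softmax.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10
logits = torch.randn(1, vocab_size, requires_grad=True)

# hard=True returns a one-hot sample in the forward pass, but routes the
# backward pass through the soft Gumbel-softmax probabilities (straight-through).
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)

# Hypothetical downstream differentiable loss on the sampled token.
reward = torch.randn(vocab_size)
loss = -(sample * reward).sum()
loss.backward()

print(logits.grad)  # non-zero: gradients flow back through the sampling step
```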

5

SaltyStackSmasher OP t1_jagisv7 wrote

Thanks for the response. My main concern with beam sampling and backprop is that the context for the 2nd token includes the 1st token. In the RNN case, I believe this wouldn't necessarily matter, since only the hidden state is propagated forward. In a transformer, though, we have to redo the entire forward pass from the 2nd token onwards, and these subsequent forward passes don't share anything, so I'm a bit confused about exactly how the gradients will flow.
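
Here's a rough sketch of the setup I'm picturing (my own toy example, not real beam search: a stand-in linear "backbone" instead of a causal transformer, straight-through Gumbel-softmax for sampling, and the sampled one-hot multiplied into the embedding table so each re-run forward pass stays connected to the logits that produced the previous token):

```python
# Toy sketch: gradients flowing across autoregressive steps when each sampled
# token is re-embedded and the forward pass is re-run over the whole prefix.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab, d_model, n_steps = 20, 32, 3

embed = nn.Embedding(vocab, d_model)
backbone = nn.Linear(d_model, vocab)  # hypothetical stand-in for a transformer decoder

def logits_from_prefix(prefix_embeds):
    # prefix_embeds: (T, d_model); a real model would run causal self-attention here
    return backbone(prefix_embeds.mean(dim=0))

prefix = embed(torch.tensor([1, 2, 3]))  # prompt embeddings, shape (3, d_model)

for _ in range(n_steps):
    logits = logits_from_prefix(prefix)                      # full re-run over the prefix
    one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)   # straight-through sample
    # differentiable "lookup": one-hot sample times the embedding table, so the
    # next forward pass depends differentiably on the logits of this step
    next_embed = one_hot @ embed.weight
    prefix = torch.cat([prefix, next_embed.unsqueeze(0)], dim=0)

# dummy loss on the final step only; backward still reaches earlier steps'
# logits because their samples were consumed by the later forward passes
loss = logits_from_prefix(prefix).pow(2).mean()
loss.backward()
print(embed.weight.grad.norm(), backbone.weight.grad.norm())
```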

Please let me know if I wasn't clear in explaining my problem. Thanks again for your response :)

2

cnapun t1_jai24sf wrote

What I was trying to say was that this sampling approach (in a transformer) seems like it would have similar issues to an RNN, in that your computational graph will be repeated N times, where N is the rollout size. This makes me suspect that you'll get a lot of noise in your gradient estimates if N is large (also, iirc, Gumbel-softmax gradients are biased, which might cause some more issues if you chain them).
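
If you want to sanity-check the noise claim, a toy experiment along these lines (my own sketch, with a hypothetical one-matrix "model") would let you compare the empirical gradient variance across rollout lengths:

```python
# Toy sketch: estimate gradient variance across straight-through Gumbel-softmax
# rollouts of different lengths, using a single shared "transition" parameter.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 10
W = torch.randn(vocab, vocab)  # shared transition logits, fixed across rollouts

def rollout_grad(n_steps):
    w = W.clone().requires_grad_(True)
    state = F.one_hot(torch.tensor(0), vocab).float()
    for _ in range(n_steps):
        logits = state @ w
        state = F.gumbel_softmax(logits, tau=1.0, hard=True)  # chained ST samples
    loss = (state * torch.arange(vocab, dtype=torch.float)).sum()  # dummy reward
    loss.backward()
    return w.grad.flatten()

for n in (2, 8, 32):
    grads = torch.stack([rollout_grad(n) for _ in range(200)])
    print(n, grads.var(dim=0).mean().item())  # mean per-parameter gradient variance
```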

1