Submitted by SaltyStackSmasher t3_11euzja in MachineLearning
SaltyStackSmasher OP t1_jagisv7 wrote
Reply to comment by cnapun in [D] backprop through beam sampling ? by SaltyStackSmasher
Thanks for the response. My main concern with beam sampling and backprop is that the context for the 2nd token includes the 1st token. In the RNN case I believe this wouldn't necessarily matter, since only the hidden state is propagated forward. In a transformer, we have to completely redo the forward pass from the 2nd token onwards, and those subsequent forward passes don't share anything, so I'm a bit confused about how exactly the gradients would flow.
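To make the structure I'm worried about concrete, here's a toy PyTorch sketch of the plain (non-differentiable) sampling loop; the names and the `model(ids)` interface are placeholders, not from any real codebase:

```python
import torch

# Toy sketch of the sampling loop in question. `model` is any causal decoder
# with the (hypothetical) interface logits = model(ids), shape [B, T, V].
def hard_rollout(model, prompt_ids, num_new_tokens):
    ids = prompt_ids                                  # [B, T0]
    for _ in range(num_new_tokens):
        logits = model(ids)                           # full forward pass over the whole context
        probs = logits[:, -1].softmax(dim=-1)         # distribution over the next token
        next_id = torch.multinomial(probs, 1)         # hard sample: an integer index, [B, 1]
        ids = torch.cat([ids, next_id], dim=-1)       # step t+1 re-reads everything, incl. token t
    return ids  # a loss on these ids has no gradient path back through the sampling steps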
Please let me know if I wasn't clear in explaining my problem. Thanks again for your response :)
cnapun t1_jai24sf wrote
What I was trying to say is that this sampling approach in a transformer seems to run into the same issues as an RNN: your computational graph gets repeated N times, where N is the rollout size. That makes me suspect you'll get a lot of noise in your gradient estimates if N is large (also, iirc Gumbel-softmax gradients are biased, which might cause further issues when you chain them across steps).
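Roughly what I mean, as a sketch (assuming you relax the sampled token with Gumbel-softmax and feed the soft embedding back in; the `model(embeds)` interface and the names are placeholders, not a real API):

```python
import torch
import torch.nn.functional as F

# Rough sketch of a differentiable rollout via Gumbel-softmax relaxation.
# Assumes a (hypothetical) causal decoder that accepts input embeddings
# directly: logits = model(embeds), shape [B, T, V]; embed_matrix is [V, D].
def soft_rollout(model, embed_matrix, prompt_embeds, num_new_tokens, tau=1.0):
    embeds = prompt_embeds                            # [B, T0, D]
    for _ in range(num_new_tokens):
        logits = model(embeds)                        # forward pass over the whole (growing) context
        y = F.gumbel_softmax(logits[:, -1], tau=tau, hard=False)  # relaxed one-hot over vocab, [B, V]
        next_embed = (y @ embed_matrix).unsqueeze(1)  # soft token embedding keeps the graph connected
        embeds = torch.cat([embeds, next_embed], dim=1)
    # The graph for step t contains the graphs of all earlier steps, so the
    # decoder is effectively unrolled N times at backward time; the relaxed
    # gradients are also biased, which compounds as they're chained.
    return embeds
```

(Setting `hard=True` gives the straight-through variant that keeps a discrete token in the forward pass, but the gradient estimate is still biased.)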