Submitted by beautyofdeduction t3_10uuslf in deeplearning
Hey y'all, the transformer model I'm training has:
- a Keras param count of 22 million: 6 encoder blocks, each with 8 attention heads of head_size 64
- a sequence length of 6250
- a batch size of 1
It consistently OOMs on any GPU with less than 40 GB of VRAM (an RTX A6000, for example). I've tried this on both Google Colab and Lambda Labs.
My back-of-the-envelope estimate: 22M params plus activations, the largest of which is the 6250 × 6250 attention matrix (~39M floats). That comes to roughly 61M floats, i.e. somewhere around 250-500 MB depending on precision. I cannot wrap my head around what's causing the VRAM OOM!
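In code, that estimate looks like this (a rough sketch assuming float32 and counting only a single copy of that largest attention matrix):

```python
# Naive estimate: parameters + the single biggest activation, in float32.
seq_len = 6250
n_params = 22e6                        # Keras-reported param count

attn_map = seq_len * seq_len           # ~39M floats for one attention matrix
total_floats = n_params + attn_map     # ~61M floats
print(f"~{total_floats / 1e6:.0f}M floats "
      f"~= {total_floats * 4 / 1e6:.0f} MB at 4 bytes/float")
# -> ~244 MB (about 0.5 GB even at 8 bytes/float)
```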
I must be missing something, but I can't see what. Please help me out: how much VRAM usage did you see with your transformers? Any thoughts are appreciated!
BellyDancerUrgot t1_j7ec7oj wrote
~83 GB I think, not 500 MB
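(Not necessarily the commenter's exact math, but a sketch of where a number on that scale comes from: the 6250 × 6250 attention matrix exists per head and per layer, and at least a couple of copies per block are typically kept for the backward pass. Assuming float32 and two retained seq_len × seq_len tensors per block:)

```python
# Rough attention-activation memory, float32 (4 bytes), batch size 1.
seq_len, n_heads, n_layers = 6250, 8, 6
bytes_per_float = 4

one_map = seq_len * seq_len            # ~39M floats per head
per_layer = one_map * n_heads          # 8 heads per encoder block
# Assume each block retains at least two seq_len x seq_len tensors for
# backprop (raw scores and softmax probabilities); frameworks often
# keep more (dropout masks, gradients of each).
kept_per_layer = 2 * per_layer
total_bytes = kept_per_layer * n_layers * bytes_per_float
print(f"attention activations alone: ~{total_bytes / 1e9:.0f} GB")
# -> ~15 GB before gradients, optimizer state, and other activations
```

Once you also count gradients of those tensors and the rest of the per-layer activations during backprop, totals in the tens of GB are plausible, which is why this only fits on 40 GB+ cards at that sequence length.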