QuadmasterXLII


Your problem is the U-Net backbone, not the loss function. Assuming you're married to a batch size of 4, the final convolution to get to 4 x 200 x 500 x 500, the cross-entropy, and the backpropagation should only take maybe 10 GB, so cram your architecture into the remaining 30 GB.

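```python
import torch

# Stand-in for the backbone's output: batch of 4, 128 channels, 500 x 500
x = torch.randn([4, 128, 500, 500]).cuda()
# Final 3x3 conv to 200 classes; with no padding the output is 4 x 200 x 498 x 498
z = torch.nn.Conv2d(128, 200, 3).cuda()
# Random per-pixel class labels matching the 498 x 498 output
q = torch.randint(0, 200, (4, 498, 498)).cuda()
# Forward through the head, per-pixel cross-entropy, then backprop
torch.nn.CrossEntropyLoss()(z(x), q).backward()
```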

This, for example, takes 7.5 GB.
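For a rough sense of where that memory comes from, here is a back-of-envelope count of the biggest tensors at the shapes in this comment. It's a sketch assuming dense float32 activations, int64 labels, and decimal gigabytes (1 GB = 1e9 bytes); `tensor_gb` is just a hypothetical helper name, not a PyTorch function.

```python
from math import prod


def tensor_gb(shape, bytes_per_elem=4):
    """Size of a dense tensor in decimal gigabytes (1 GB = 1e9 bytes).

    bytes_per_elem defaults to 4 (float32); use 8 for int64 class labels.
    """
    return prod(shape) * bytes_per_elem / 1e9


# Shapes from the comment: batch 4, 128 backbone channels, 200 classes, 500 x 500 pixels
sizes = {
    "backbone output (4, 128, 500, 500)": tensor_gb((4, 128, 500, 500)),
    "logits (4, 200, 500, 500)": tensor_gb((4, 200, 500, 500)),
    "int64 labels (4, 500, 500)": tensor_gb((4, 500, 500), bytes_per_elem=8),
}

for name, gb in sizes.items():
    print(f"{name}: {gb:.3f} GB")
```

Each logit-sized buffer is about 0.8 GB (the snippet's unpadded 498 x 498 output is under 1% smaller), so the head costs a small multiple of that. The actual peak, 7.5 GB in the run above, also depends on which intermediates autograd keeps for the backward pass, the caching allocator, and cuDNN workspace, so treat this as an estimate and measure on your own GPU.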
