
QuadmasterXLII t1_ja7wog6 wrote

... does it fit with batch size 1?


QuadmasterXLII t1_ja7yo0f wrote

Your problem is the U-Net backbone, not the loss function. Assuming that you're married to a batch size of 4, the final convolution up to 4 x 200 x 500 x 500, the cross-entropy, and the backpropagation should only take maybe 10 GB, so cram your architecture into the remaining 30 GB:

import torch

# 4 x 128 x 500 x 500 feature map, standing in for the backbone's output
x = torch.randn([4, 128, 500, 500]).cuda()
# final 3x3 conv up to 200 classes (no padding, so the output is 498 x 498)
z = torch.nn.Conv2d(128, 200, 3).cuda()
# integer class targets matching the conv output's spatial size
q = torch.randint(0, 200, (4, 498, 498)).cuda()
torch.nn.CrossEntropyLoss()(z(x), q).backward()

for example, takes 7.5 GB.
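
If you want to verify the headroom on your own card, one way (a sketch, using torch.cuda.max_memory_allocated to read the peak tensor allocation) is to wrap the same snippet:

import torch

torch.cuda.reset_peak_memory_stats()  # start the peak counter at zero

x = torch.randn([4, 128, 500, 500]).cuda()
z = torch.nn.Conv2d(128, 200, 3).cuda()
q = torch.randint(0, 200, (4, 498, 498)).cuda()
torch.nn.CrossEntropyLoss()(z(x), q).backward()

# peak memory held by tensors across the forward + backward pass, in GB
print(torch.cuda.max_memory_allocated() / 1e9)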


Scared_Employer6992 OP t1_ja7xpjt wrote

I haven't tried bs=1, but I'd rather not use it: my net has a lot of BN layers, and I usually get bad results from BatchNorm at small batch sizes.


badabummbadabing t1_ja7yxbg wrote

Don't use batch normalization. Lots of U-Nets use e.g. instance normalization instead. A batch size of 1 should be completely fine (but you will need to retune the learning rate after changing it). Check the 'no new U-Net' (nnU-Net) paper by Fabian Isensee et al. for the definitive resource on what matters in U-Nets.
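
For instance, a minimal sketch of that swap in the kind of double-conv block most U-Net implementations use (conv_block here is illustrative, not from any particular repo):

import torch.nn as nn

# InstanceNorm2d normalizes each sample and channel independently, so,
# unlike BatchNorm2d, its statistics are just as good at batch size 1.
def conv_block(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.InstanceNorm2d(out_ch, affine=True),  # was nn.BatchNorm2d(out_ch)
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.InstanceNorm2d(out_ch, affine=True),  # was nn.BatchNorm2d(out_ch)
        nn.ReLU(inplace=True),
    )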
