Submitted by thomasahle t3_118gie9 in MachineLearning
Cross entropy on logits is a common simplification that fuses softmax + cross-entropy loss into something like:
def label_cross_entropy_on_logits(x, labels):
    return (-x.select(labels) + x.logsumexp(axis=1)).sum(axis=0)
where x.select(labels) = x[range(batch_size), labels].
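For anyone who wants to check the equivalence, here is a runnable PyTorch sketch (the explicit indexing stands in for x.select, and the sum reduction is my assumption):

import torch
import torch.nn.functional as F

def fused_ce(x, labels):
    # -x[i, labels[i]] + logsumexp(x[i, :]), summed over the batch
    picked = x[torch.arange(x.shape[0]), labels]
    return (-picked + x.logsumexp(dim=1)).sum(dim=0)

x = torch.randn(8, 5)                   # batch of 8, 5 classes
labels = torch.randint(0, 5, (8,))
print(fused_ce(x, labels))
print(F.cross_entropy(x, labels, reduction="sum"))  # should match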
I was thinking about how the logsumexp term looks like a regularization term, and wondered what would happen if I just replaced it with x.norm(axis=1) instead. It seemed to work just as well as the original, so I thought: why not just enforce unit norm?
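Concretely, that intermediate variant looked something like this (a PyTorch sketch; the function name and indexing are mine):

import torch

def label_ce_with_norm(x, labels):
    # same true-label term as before, but logsumexp swapped for the L2 norm of each logit row
    picked = x[torch.arange(x.shape[0]), labels]
    return (-picked + x.norm(dim=1)).sum(dim=0)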
I changed my code to
def label_cross_entropy_on_logits(x, labels):
    return -(x.select(labels) / x.norm(axis=1)).sum(axis=0)
and my training sped up dramatically, and my test loss decreased.
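In case the pseudocode is unclear, here is the same loss written out as runnable PyTorch (function name is mine):

import torch

def normalized_label_loss(x, labels):
    # negative "cosine-like" score: the true-label logit divided by its row's L2 norm
    picked = x[torch.arange(x.shape[0]), labels]
    return -(picked / x.norm(dim=1)).sum(dim=0)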
I'm sure this is a standard approach to categorical loss, but I haven't seen it before, and would love to get some references.
I found this old post: https://www.reddit.com/r/MachineLearning/comments/k6ff4w/unit_normalization_crossentropy_loss_outperforms/ which references logit normalization: https://arxiv.org/pdf/2205.09310.pdf. However, it seems those papers all apply layer normalization and then softmax+CE. What seems to work for me is simply replacing softmax+CE with normalization.
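For contrast, my reading of that LogitNorm-style approach still keeps softmax+CE, just on normalized logits. A sketch (tau is their temperature hyperparameter; the default value and exact form here are my assumption, not a verified reproduction of their code):

import torch
import torch.nn.functional as F

def logitnorm_style_loss(x, labels, tau=0.04):
    # normalize each logit row, then apply the usual softmax + cross entropy
    x_normed = x / (tau * x.norm(dim=1, keepdim=True) + 1e-7)
    return F.cross_entropy(x_normed, labels, reduction="sum")

In my version, by contrast, there is no softmax or cross entropy at all.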
cthorrez t1_j9iq35y wrote
> test loss decreased
What function are you using to evaluate test loss: cross entropy, or this norm function?