Submitted by pgao_aquarium t3_11l4xo0 in MachineLearning
KD_A t1_jbb5kx5 wrote
The section "Check if your model is overfitting" could be improved.
> The model is overfitting (high variance) when it has low error on the training set but high error on the test set.
A big gap between training and validation error does not imply that it is overfitting. In general, an absolute gap between training and validation errors does not tell you how validation error will change if a model is made more complex or simpler. To answer questions about overfitting and underfitting, one needs to train multiple models and compare their training and validation errors.
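To make "train multiple models and compare" concrete, here is a minimal sketch on synthetic data. The sine-wave data and the polynomial-degree sweep are assumptions chosen for illustration (degree stands in for "model complexity"). With a real NN you would sweep depth, width, regularization, etc. and retrain each model.

```python
import numpy as np


def train_val_errors(degrees, n_train=30, n_val=200, noise=0.3, seed=0):
    """Fit one polynomial model per degree and return (degree, train_mse, val_mse) rows.

    The data-generating function (a sine wave plus Gaussian noise) is a
    hypothetical stand-in for a real dataset.
    """
    rng = np.random.default_rng(seed)

    def f(x):
        return np.sin(3 * x)

    x_tr = rng.uniform(-1, 1, n_train)
    x_va = rng.uniform(-1, 1, n_val)
    y_tr = f(x_tr) + rng.normal(0, noise, n_train)
    y_va = f(x_va) + rng.normal(0, noise, n_val)

    rows = []
    for d in degrees:
        coefs = np.polyfit(x_tr, y_tr, d)  # least-squares fit of degree d
        tr = float(np.mean((np.polyval(coefs, x_tr) - y_tr) ** 2))
        va = float(np.mean((np.polyval(coefs, x_va) - y_va) ** 2))
        rows.append((d, tr, va))
    return rows


for d, tr, va in train_val_errors(range(1, 10)):
    print(f"degree={d}  train={tr:.3f}  val={va:.3f}  gap={va - tr:.3f}")
```

The quantity to read is how validation error moves from one degree to the next, not the size of any single model's gap.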
> Overfitting and underfitting is easy to detect by visualizing loss curves during training.
nit: this caption is phrased too liberally, as the graph only answers this question: given this model architecture, optimizer, and dataset, which model epoch/checkpoint should I select? It does not tell you about any other factors which modulate model complexity.
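What a single loss curve can support is a checkpoint choice with everything else held fixed. A minimal sketch, assuming you have logged one validation loss per epoch (the numbers below are made up):

```python
def best_checkpoint(val_losses):
    """Return (epoch_index, loss) of the checkpoint with the lowest validation loss.

    Architecture, optimizer, and data are implicitly held fixed: this answers
    "which epoch?" and nothing about how to change the model.
    """
    if not val_losses:
        raise ValueError("need at least one validation loss")
    best = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best, val_losses[best]


# Hypothetical per-epoch validation losses: they fall, then rise.
val_losses = [0.92, 0.71, 0.58, 0.51, 0.49, 0.50, 0.54, 0.60]
print(best_checkpoint(val_losses))  # (4, 0.49)
```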
> This often means that the training set is not representative of the domain it is supposed to run in.
I wouldn't call this a variance issue per se. If it were a variance issue, sampling more data from the training distribution should significantly lower validation error. If the training distribution is biased, sampling more of it will not help a whole lot.
That all being said, I share your passion for greater standardization of ML workflow. And I agree that there needs to be more work on diagnosing problems, and less "throwing stuff at the wall". To add something, I now typically run learning curves. They can cost quite a bit when training big NNs. But even a low-resolution curve can give a short term answer to an important question: how much should I expect this model to improve if I train it on n more observations? And assuming you have a decent sense of your model's capacity, this question is closely related to another common one: should I prioritize collecting more data, or should I make a modeling intervention? Learning curves have motivated big improvements in my experience.
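A low-resolution learning curve can be sketched like this: retrain on nested subsets of increasing size and record validation error at each size. This assumes a least-squares linear model on synthetic data as a stand-in for the real training pipeline; the subset sizes are arbitrary.

```python
import numpy as np


def learning_curve(sizes, n_features=20, n_val=2000, noise=1.0, seed=0):
    """Return (n, train_mse, val_mse) for a least-squares model trained on the first n rows.

    Synthetic linear data is a hypothetical stand-in; with a real model you would
    retrain the real pipeline on each subset instead.
    """
    rng = np.random.default_rng(seed)
    w_true = rng.normal(size=n_features)
    n_max = max(sizes)
    X = rng.normal(size=(n_max + n_val, n_features))
    y = X @ w_true + rng.normal(0, noise, n_max + n_val)
    X_tr_all, y_tr_all = X[:n_max], y[:n_max]  # nested training subsets come from here
    X_va, y_va = X[n_max:], y[n_max:]          # one fixed validation set

    curve = []
    for n in sizes:
        w, *_ = np.linalg.lstsq(X_tr_all[:n], y_tr_all[:n], rcond=None)
        tr = float(np.mean((X_tr_all[:n] @ w - y_tr_all[:n]) ** 2))
        va = float(np.mean((X_va @ w - y_va) ** 2))
        curve.append((n, tr, va))
    return curve


for n, tr, va in learning_curve([40, 80, 160, 320, 640]):
    print(f"n={n:4d}  train={tr:.3f}  val={va:.3f}")
```

If validation error is still dropping steeply at the largest n, more data is likely to help; if it has flattened, a modeling intervention is probably the better bet. Doubling n between points keeps the number of expensive retrains small.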
murrdpirate t1_jbd9755 wrote
>It does not tell you about any other factors which modulate model complexity.
Can you expand on that? My general understanding is that if I'm seeing significantly lower training losses than validation losses, then my model complexity is too high compared to the data (unless there's something wrong with the data).
KD_A t1_jbdi9x4 wrote
Notice that "significantly lower" can't actually be defined. There isn't a useful definition of overfitting which only requires observing a single model's train and test gap.^1 Here's a contrived but illustrative example: say model A has a training error rate of 0.10 and test error rate of 0.30. It's tempting to think "test error is 3x train error, we're overfitting". This may or may not be right; there absolutely could be a (more complex) model B with, e.g., training error rate 0.05, test error rate 0.27. Notice that the train-test gap increased going from A to B (0.20 to 0.22). But I don't care. Assuming these estimates are good, and all I care about is minimizing expected error rate, I'd confidently deploy model B over model A.
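The arithmetic in the A/B example, written out (the error rates are the hypothetical ones from the comment above):

```python
# Hypothetical error rates from the example.
models = {
    "A": {"train": 0.10, "test": 0.30},
    "B": {"train": 0.05, "test": 0.27},  # the more complex model
}

for name, m in models.items():
    gap = m["test"] - m["train"]
    ratio = m["test"] / m["train"]
    print(f"{name}: gap={gap:.2f}  test/train={ratio:.1f}x")

# Selection depends only on the estimated test error, not on the gap.
chosen = min(models, key=lambda k: models[k]["test"])
print("deploy:", chosen)
```

The gap grows from 0.20 to 0.22 and the test/train ratio from 3.0x to 5.4x, yet B has the lower test error, so B is the one to deploy.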
The useful definition of overfitting is that it refers to a region in function space where test error goes up and training error goes down (as model complexity goes up). Diagram (for underparametrized models). This definition tells us that the only good way to tell whether a model should be made more simple or more complex is to fit multiple models and compare them. This info is expensive to obtain for NNs, and obtaining it makes one look less clever. But it gives a reliable hint as to how a model should be iterated.^2 In the example above, if we really did observe model B's errors, then perhaps our next model should be even more complex.
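Under this definition, "overfitting" describes a step between models, so it can only be read off a sweep. A minimal sketch, assuming results are stored as (complexity, train_error, test_error) tuples sorted by complexity; the last two rows are invented to extend the A/B example. Per the double-descent caveat in the footnotes, a flagged step in an overparametrized model does not rule out test error falling again further along.

```python
def overfitting_steps(results):
    """Flag each step up in complexity where training error falls but test error rises.

    `results` is a list of (complexity, train_error, test_error) tuples sorted by
    increasing complexity; the tuple format is a hypothetical convention.
    """
    flagged = []
    for (c0, tr0, te0), (c1, tr1, te1) in zip(results, results[1:]):
        if tr1 < tr0 and te1 > te0:
            flagged.append((c0, c1))
    return flagged


# Hypothetical sweep: A -> B improves test error; going past B starts overfitting.
sweep = [
    (1, 0.10, 0.30),  # model A
    (2, 0.05, 0.27),  # model B
    (3, 0.03, 0.28),
    (4, 0.01, 0.33),
]
print(overfitting_steps(sweep))  # [(2, 3), (3, 4)]
```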
If you're asking more specifically about reading NN loss curves, I haven't seen any science which puts claims like #4 here to the test.^3 I'd also like to mention another common issue w/ reading NN loss curves: people usually don't take care in estimating training loss. The standard NN training loop results in overestimates, which will make the gap between training and validation appear bigger than it actually is. I happened to write about this problem today, here on CrossValidated.
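One mechanism behind the overestimate, shown on a toy problem: the usual loop averages minibatch losses computed while the weights are still changing, so early (worse) batches inflate the epoch's reported training loss. This is an illustrative sketch with plain numpy SGD on synthetic linear data, not necessarily the full argument in the CrossValidated post; in NNs, train-mode layers such as dropout can add to the effect.

```python
import numpy as np


def epoch_losses(n=1000, n_features=5, lr=0.05, batch_size=20, seed=0):
    """Run one SGD epoch on synthetic linear data and return two training-loss estimates.

    running_avg:  mean of minibatch losses recorded during the epoch (the usual
                  training-loop number; early batches use older, worse weights).
    end_of_epoch: MSE on the full training set, re-evaluated with the final weights.
    """
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, n_features))
    y = X @ rng.normal(size=n_features) + rng.normal(0, 0.1, n)
    w = np.zeros(n_features)

    batch_losses = []
    for start in range(0, n, batch_size):
        xb, yb = X[start:start + batch_size], y[start:start + batch_size]
        err = xb @ w - yb
        batch_losses.append(np.mean(err ** 2))  # loss logged before this step's update
        w -= lr * 2 * xb.T @ err / len(yb)      # gradient step on the batch MSE

    running_avg = float(np.mean(batch_losses))
    end_of_epoch = float(np.mean((X @ w - y) ** 2))
    return running_avg, end_of_epoch


running_avg, end_of_epoch = epoch_losses()
print(f"running average of batch losses: {running_avg:.3f}")
print(f"full-pass loss with final weights: {end_of_epoch:.3f}")
```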
Footnotes
1. Two exceptions to this: (1) you're already quite familiar w/ the type of task and data, so you can correlate high gaps with overfitting based on previous experience; (2) test error is higher than that of an intercept-only model, and training error is much lower.
2. Double descent complicates this workflow. For overparametrized models like NNs, one can be deceived into not going far enough when increasing model complexity. It can also be difficult to determine whether a certain intervention is actually increasing or decreasing complexity. This paper characterizes various forms of double descent.
3. My answer to #4 would be the same as what I wrote when criticizing the caption in my first comment. The provided answer, "reduce model capacity", is too vague. The answer should be: select the model checkpoint from halfway, simply b/c its test error is the lowest. The graph alone doesn't tell you anything about how the model should be iterated beyond that info. That's b/c each point on the curve is the loss after being trained for n iterations, conditional on all of the other factors which modulate the model's complexity. There could absolutely be a model w/ more depth, more width, etc. which performs better than the simpler model trained halfway.
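Exception (2) in the first footnote needs an intercept-only baseline. For a classification error rate, a natural reading is "always predict the most common training class"; that interpretation and the labels below are assumptions for illustration.

```python
import numpy as np


def intercept_only_error(y_train, y_test):
    """Test error rate of an intercept-only classifier: always predict the majority training class."""
    values, counts = np.unique(y_train, return_counts=True)
    majority = values[np.argmax(counts)]
    return float(np.mean(np.asarray(y_test) != majority))


def worse_than_baseline(model_test_error, y_train, y_test):
    """True if the model's test error rate exceeds the intercept-only baseline."""
    return model_test_error > intercept_only_error(y_train, y_test)


# Hypothetical labels: 70% of training labels are class 0.
y_train = [0] * 70 + [1] * 30
y_test = [0] * 35 + [1] * 15
print(intercept_only_error(y_train, y_test))      # 0.3
print(worse_than_baseline(0.35, y_train, y_test))  # True
```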
murrdpirate t1_jben5uy wrote
>Notice that "significantly lower" can't actually be defined.
True. I guess I would say that over-fitting is a spectrum, and that there's generally some amount of over-fitting happening (unless your training set happens to be significantly more challenging than your test set). So the bigger the gap between train and test, the more over-fitting.
>It's tempting to think "test error is 3x train error, we're overfitting". This may or may not be right; there absolutely could be a (more complex) model B with, e.g., training error rate 0.05, test error rate 0.27.
Maybe it's semantics, but in my view, I would say model B is indeed overfitting "more" than model A. But I don't think more overfitting guarantees worse test results, it just increases the likelihood of worse test results due to increased variance. I may still choose to deploy model B, but I would view it as a highly overfitting model that happened to perform well.
Appreciate the response. I also liked your CrossValidated post. I've wondered about that issue myself. Do you think data augmentation should also be disabled in that test?
KD_A t1_jbf175s wrote
> Do you think data augmentation should also be disabled in that test?
Yes. I've never actually experimented w/ stuff like image augmentation. But in most examples I looked up, augmentation is a training-only computation which may make training loss look higher than it actually is. In general the rule is just this: to get an unbiased estimate of training loss, apply the exact same code you're using to estimate validation loss to training data.
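A minimal sketch of that rule: compare the loss a training loop might log on augmented inputs with the loss from a shared evaluation function that applies no augmentation. The noise-based `augment` and the fixed weights are hypothetical stand-ins; in PyTorch, the shared path would typically also call `model.eval()` and run under `torch.no_grad()`.

```python
import numpy as np


def augment(x, rng):
    """Hypothetical training-only augmentation: additive Gaussian noise on the inputs."""
    return x + rng.normal(0, 0.5, size=x.shape)


def mse(w, X, y):
    return float(np.mean((X @ w - y) ** 2))


def evaluate(w, X, y):
    """Shared evaluation path used for BOTH validation and training data: no augmentation."""
    return mse(w, X, y)


rng = np.random.default_rng(0)
X_train = rng.normal(size=(500, 3))
w_fit = np.array([1.0, -2.0, 0.5])  # pretend these are trained weights
y_train = X_train @ w_fit + rng.normal(0, 0.1, 500)

loss_with_aug = mse(w_fit, augment(X_train, rng), y_train)  # what a training loop might log
loss_eval = evaluate(w_fit, X_train, y_train)               # same code as validation
print(f"logged (augmented): {loss_with_aug:.3f}  evaluated: {loss_eval:.3f}")
```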