T5 - Flax - Decreasing performance on pretraining
Environment info
- `transformers` version: 4.9.0.dev0
- Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29
- Python version: 3.8.10
- Flax version (CPU?/GPU?/TPU?): 0.3.4 (tpu)
- Jax version: 0.2.17
- JaxLib version: 0.68
Who can help
Script
Using a slightly modified version of run_t5_mlm_flax.py that also supports streaming large datasets.
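For reference, a minimal sketch of what the streaming change can look like with 🤗 Datasets; the corpus path, text column and tokenizer checkpoint below are placeholders, not the exact values from our run:

```python
# Minimal sketch: stream the corpus instead of loading all 250GB into memory.
# Corpus path and checkpoint name are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

# streaming=True returns an IterableDataset, so examples are read lazily from disk.
raw_dataset = load_dataset("text", data_files={"train": "corpus/*.txt"}, streaming=True)

def tokenize(batch):
    return tokenizer(batch["text"], return_attention_mask=False)

tokenized = raw_dataset["train"].map(tokenize, batched=True, remove_columns=["text"])

# Shuffling uses a fixed-size buffer, since the full dataset never fits in memory.
shuffled = tokenized.shuffle(buffer_size=10_000, seed=42)
```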
Information
I am posting this as a bug report since the behaviour is counter-intuitive. I am not sure whether this is a bug in JAX/T5, or whether it is behaviour that should actually be expected from T5.
We are training T5-base (v1.1) on a large, cleaned 250GB Norwegian dataset. We are training for 1M steps, which should equal roughly two complete epochs. With lr=8e-3, bs=32, seq_length=512 and Adafactor, we are seeing a steady decrease in loss:
The image above shows the first 250k steps. We needed to restart the run here, so I have not patched the event files together, but the final loss after 1M steps ends at 1.349. Eval accuracy is also increasing.
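For context, a sketch of how the Adafactor setup described above (lr=8e-3) is roughly built with optax in run_t5_mlm_flax.py; the warmup length here is an illustrative assumption, not a value stated in this issue:

```python
# Sketch of the Adafactor + linear-schedule setup (lr=8e-3), roughly as wired up
# with optax in run_t5_mlm_flax.py. The warmup length is an assumption for illustration.
import optax

num_train_steps = 1_000_000
warmup_steps = 10_000  # assumption, not stated in the issue

# Linear warmup to the peak learning rate, then linear decay back to zero.
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=8e-3, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=8e-3, end_value=0.0, transition_steps=num_train_steps - warmup_steps
)
lr_schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

optimizer = optax.adafactor(learning_rate=lr_schedule)
```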
The weird thing is that the final checkpoint has really terrible performance on our downstream task!
Looking into this issue, we evaluated multiple pretraining checkpoints by finetuning each of them for 60k steps on a task of translating between two Norwegian dialects.
The red and blue dots correspond to two models trained before and after the T5 optimisation submitted by @patrickvonplaten.
The tendency here is very clear: after roughly 200k steps the model suddenly starts to perform worse on the downstream task, even though the pretraining loss keeps decreasing and the eval accuracy of the pretrained model keeps improving. The deterioration happens before one epoch of the pretraining dataset, and although it looks like overfitting, we find that extremely unlikely.
We have more experience with BERT-like models, where performance on downstream tasks always improves as long as MLM accuracy is improving. Is this expected behaviour for T5?
I don’t think that adding or not adding the EOS token makes a difference, so I highly doubt that is the reason for your observations. BTW, here is the original preprocessing code for pretraining: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L1864
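For reference, a minimal check of what appending the EOS token amounts to with the Transformers T5 tokenizer; the checkpoint name is a placeholder, and this is only a sketch of the tokenizer behaviour, not of the T5 preprocessing pipeline linked above:

```python
# Minimal sketch: the T5 tokenizer already appends </s> (EOS) by default,
# which is equivalent to adding it manually. Checkpoint name is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")

text = "Dette er en setning."

# Default behaviour: special tokens (for T5, just the trailing EOS) are added automatically.
with_eos = tokenizer(text)["input_ids"]

# Manually appending EOS after tokenizing without special tokens gives the same ids.
manual = tokenizer(text, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]

assert with_eos == manual
```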