
T5 - Flax - Decreasing performance on pretraining

See original GitHub issue

Environment info

  • transformers version: 4.9.0.dev0
  • Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Flax version (CPU?/GPU?/TPU?): 0.3.4 (tpu)
  • Jax version: 0.2.17
  • JaxLib version: 0.1.68

Who can help

@patrickvonplaten

Script

Using a slightly modified version of run_t5_mlm_flax.py that also supports streaming of large datasets.
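
For reference, here is a minimal sketch of what such a streaming modification might look like, assuming the Hugging Face datasets streaming API; the corpus and tokenizer names are placeholders, not the ones actually used in this run:

```python
# Hypothetical sketch of the streaming change to the data loading in
# run_t5_mlm_flax.py; dataset and tokenizer names are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")  # placeholder tokenizer

# streaming=True yields an IterableDataset, so the large corpus is read
# lazily during training instead of being downloaded and tokenised up front.
raw_dataset = load_dataset(
    "oscar", "unshuffled_deduplicated_no",  # placeholder Norwegian corpus
    split="train",
    streaming=True,
)

def tokenize_fn(examples):
    return tokenizer(examples["text"], return_attention_mask=False)

tokenized = raw_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

# Batches are then drawn from an iterator instead of by random index lookup.
train_iter = iter(tokenized)
```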

Information

I am posting this as a bug report since the behaviour is counter-intuitive. I am not sure whether this is a bug with JAX/T5, or whether it is behaviour that should actually be expected from T5.

We are training T5-base (v1.1) on a large, cleaned 250GB Norwegian dataset, for 1M steps, which should equal roughly two complete epochs. With lr=8e-3, bs=32, seq_length=512 and Adafactor, we are seeing a steady decay in loss:

[figure: pretraining loss curve]

The figure above shows the first 250k steps. We needed to restart here, so I have not patched the event files together, but the final loss after 1M steps ends at 1.349. Eval accuracy is also increasing.
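
For reference, the optimiser setup implied by these hyperparameters would look roughly like the following in optax (a sketch only; the warmup length is an assumption and the actual script may differ):

```python
# Rough optax equivalent of the reported run (lr=8e-3, Adafactor, 1M steps).
import optax

num_train_steps = 1_000_000
warmup_steps = 10_000  # assumed value, not stated in the issue

# Linear warmup followed by linear decay to zero, as in the Flax MLM examples.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=8e-3, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=8e-3, end_value=0.0,
    transition_steps=num_train_steps - warmup_steps)
lr_schedule = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

# Adafactor, as reported for this run, driven by the schedule above.
optimizer = optax.adafactor(learning_rate=lr_schedule)
```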

The weird thing is that the final checkpoint has really terrible performance on our downstream task!

Looking into this issue, we evaluated multiple pre-training checkpoints by fine-tuning each of them for 60k steps on a task of translating between two Norwegian dialects.

[figure: downstream performance of models fine-tuned from successive pretraining checkpoints]

The red and blue dots are two models trained before and after the T5 optimisation submitted by @patrickvonplaten.

The tendency here is very clear: after roughly 200k steps the model suddenly starts to perform worse on the downstream task, even though the loss is still decreasing and the eval accuracy of the pretrained model is still improving. The deterioration happens before one epoch of the pretraining dataset, and although it looks like over-fitting, we find that extremely unlikely.

We have more experience with BERT-like models, where performance on downstream tasks always improves as long as MLM accuracy is improving. Is this expected behaviour of T5?

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 15 (10 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Nov 2, 2021

I don’t think that adding or not adding the EOS token makes a difference - so I highly doubt that is the reason for your observations… BTW, here is the original preprocessing code for pretraining: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L1864
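
(For context, the span-corruption objective that this preprocessing implements replaces randomly chosen spans with sentinel tokens; a simplified illustration of the resulting input/target format, not the actual preprocessing code:)

```python
# Simplified illustration of T5 span corruption: roughly 15% of tokens are
# dropped in contiguous spans, each span is replaced by a sentinel in the
# input, and the target lists each sentinel followed by the dropped tokens,
# terminated by a final sentinel.
text    = "the quick brown fox jumps over the lazy dog"
inputs  = "the quick <extra_id_0> jumps over <extra_id_1> dog"
targets = "<extra_id_0> brown fox <extra_id_1> the lazy <extra_id_2>"
```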

0 reactions
github-actions[bot] commented, Nov 26, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

