
Flax BART training fails when evaluating


System Info

  • transformers version: 4.22.0.dev0
  • Platform: Linux-5.15.0-41-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.9.1
  • Flax version (CPU?/GPU?/TPU?): 0.6.0 (gpu)
  • Jax version: 0.3.17
  • JaxLib version: 0.3.15

Who can help?

@sgugger @patil-suraj

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

  1. Make a directory "./my_bart_model".
  2. Train a tokenizer (let’s use a small, Dutch corpus). Note: the repo README uses tokenizer.save, but that only saves the tokenizer config and not the merges, so I think this is a second issue that should be fixed. Below I use save_model instead.
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

# load dataset
dataset = load_dataset("dbrd", "plain_text", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save_model("./my_bart_model")
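
As a quick sanity check that both vocab.json and merges.txt were written (see the note about tokenizer.save above), the trained tokenizer can be loaded back through transformers. This is optional and not part of the repro itself:

from transformers import BartTokenizerFast

# Loads the vocab.json and merges.txt that save_model() wrote above
tokenizer = BartTokenizerFast.from_pretrained("./my_bart_model")
print(tokenizer("Dit is een test.").input_ids)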
  3. Create a BART config for it
from transformers import BartConfig
config = BartConfig.from_pretrained("facebook/bart-base", vocab_size=50265)
config.save_pretrained("./my_bart_model")
  4. Train the model with a quick evaluation (command run from the root of the transformers repository)
python examples/flax/language-modeling/run_bart_dlm_flax.py --output_dir ./my_bart_model --config_name ./my_bart_model --tokenizer_name ./my_bart_model --dataset_name dbrd --dataset_config_name plain_text --max_seq_length 128 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --learning_rate 1e-4 --warmup_steps 100 --overwrite_output_dir --logging_steps 200 --save_steps 500 --eval_steps 200

This leads to the following error (note also the VisibleDeprecationWarning, although that might be unrelated to the triggered error):

transformers/examples/flax/language-modeling/run_bart_dlm_flax.py:288: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}

Evaluating ...:  99%|██████████| 77/78 [00:08<00:00,  9.03it/s]
Training...:  14%|█▍        | 200/1429 [00:54<05:32,  3.69it/s]
Epoch ... :   0%|          | 0/3 [00:54<?, ?it/s]
Traceback (most recent call last):
  File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 964, in <module>
    main()
  File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 896, in main
    model_inputs = data_collator(samples)
  File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 291, in __call__
    batch["decoder_input_ids"] = shift_tokens_right(
  File "/home/bram/.local/share/virtualenvs/bart-tTDq1jwG/lib/python3.8/site-packages/transformers/models/bart/modeling_flax_bart.py", line 228, in shift_tokens_right
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
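
The IndexError itself seems to come from a ragged final batch: when the examples in a batch do not all have the same length, np.array builds a 1-D object array instead of a 2-D int array (hence the VisibleDeprecationWarning), and shift_tokens_right then fails on the 2-D indexing. A standalone NumPy sketch of that behaviour, not the actual collator code:

import numpy as np

# Two "examples" of different lengths, as in an incomplete final batch.
# Without dtype=object, NumPy < 1.24 only warns and still builds this 1-D object array.
batch = np.array([[5, 6, 7, 8], [5, 6, 7]], dtype=object)
print(batch.shape)  # (2,) -- one dimension, not (2, seq_len)

try:
    shifted = np.zeros_like(batch)
    shifted[:, 1:] = batch[:, :-1]  # same indexing as shift_tokens_right
except IndexError as e:
    print(e)  # too many indices for array: array is 1-dimensional, but 2 were indexed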

Expected behavior

No errors and preferably no deprecation warnings.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
sanchit-gandhi commented, Sep 6, 2022

From reading the code, all blocks should be of size sequence length and the small remainder dropped:

Certainly, that should ideally be the case!
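
(For context, the grouping step in the Flax LM examples follows the usual concatenate-and-chunk pattern, roughly like the sketch below; this shows the general idea, not the exact code in run_bart_dlm_flax.py.)

def group_texts(examples, block_size=128):
    # Concatenate all tokenized sequences per key, then cut the result into
    # fixed-size blocks, dropping the small remainder so every block has equal length.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }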

By the way, also getting these UserWarnings:

Are those UserWarnings being thrown in the parameter update step? It suggests to me a mismatch between the update and parameter dtypes!

I am working on transposing fairseq’s data implementation to PyTorch and adding a full training example in transformers.

More than happy to perform a code-review when finished!

Also cc’ing @duongna21 who must take all credit for implementing the Flax BART training example!

0 reactions
BramVanroy commented, Sep 21, 2022

@duongna21 Thanks for chiming in! I am going to close this as I am not working on this directly any more. I assume that drop_last should indeed fix the issue.
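
For anyone hitting the same error in the meantime: the workaround is to make the batch-split helper drop the incomplete final batch so the collator never receives a ragged batch. A sketch of that idea, modelled on the generate_batch_splits-style helpers used by the Flax LM examples (the name and the drop_last flag here are illustrative, not the script's actual signature):

import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last: bool = True):
    # Split sample indices into batches. With drop_last=True the incomplete final
    # batch is discarded, so every batch contains exactly batch_size samples.
    num_samples = len(samples_idx)
    if drop_last:
        remainder = num_samples % batch_size
        if remainder != 0:
            samples_idx = samples_idx[:-remainder]
        num_batches = num_samples // batch_size
        return np.split(samples_idx, num_batches) if num_batches > 0 else []
    # Otherwise keep the last, smaller batch as-is.
    return [samples_idx[i : i + batch_size] for i in range(0, num_samples, batch_size)]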
