Flax BART training fails when evaluating
System Info
- transformers version: 4.22.0.dev0
- Platform: Linux-5.15.0-41-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.9.1
- Flax version (CPU?/GPU?/TPU?): 0.6.0 (gpu)
- Jax version: 0.3.17
- JaxLib version: 0.3.15
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
- Make a dir "./my_bart_model".
- Train a tokenizer (let's use a small, Dutch corpus). Note: the repo README uses tokenizer.save, but that only saves tokenizer.config and not the merges, so I think this is a second issue that should be fixed. Below I use save_model instead (a quick load check follows the snippet).
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("dbrd", "plain_text", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])
# Save files to disk
tokenizer.save_model("./my_bart_model")
- Create a BART config for it
from transformers import BartConfig
config = BartConfig.from_pretrained("facebook/bart-base", vocab_size=50265)
config.save_pretrained("./my_bart_model")
- Train the model with a quick evaluation (command from the root of the transformers lib)
python examples/flax/language-modeling/run_bart_dlm_flax.py --output_dir ./my_bart_model --config_name ./my_bart_model --tokenizer_name ./my_bart_model --dataset_name dbrd --dataset_config_name plain_text --max_seq_length 128 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --learning_rate 1e-4 --warmup_steps 100 --overwrite_output_dir --logging_steps 200 --save_steps 500 --eval_steps 200
This leads to the following error (also note the VisibleDeprecationWarning, although that might be unrelated to the triggered error):
transformers/examples/flax/language-modeling/run_bart_dlm_flax.py:288: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
Evaluating ...: 99%| 77/78 [00:08<00:00, 9.03it/s]
Training...: 14%| 200/1429 [00:54<05:32, 3.69it/s]
Epoch ... : 0%| 0/3 [00:54<?, ?it/s]
Traceback (most recent call last):
File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 964, in <module>
main()
File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 896, in main
model_inputs = data_collator(samples)
File "transformers/examples/flax/language-modeling/run_bart_dlm_flax.py", line 291, in __call__
batch["decoder_input_ids"] = shift_tokens_right(
File "/home/bram/.local/share/virtualenvs/bart-tTDq1jwG/lib/python3.8/site-packages/transformers/models/bart/modeling_flax_bart.py", line 228, in shift_tokens_right
shifted_input_ids[:, 1:] = input_ids[:, :-1]
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
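For what it's worth, the deprecation warning and the IndexError look related: the collator at line 288 builds each batch with np.array() over the raw example lists, and when the examples in a batch do not all have the same length (apparently the case for the final evaluation batch, 78 of 78), NumPy falls back to a 1-D object array instead of a 2-D integer array, which shift_tokens_right then cannot slice with [:, :-1]. A minimal illustration of that failure mode (my own snippet, not code from the script):
import numpy as np
# Same-length examples collate into a 2-D array; ragged examples do not.
even = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
ragged = np.array([[1, 2, 3, 4], [5, 6, 7]])  # VisibleDeprecationWarning (an error on NumPy >= 1.24)
print(even.ndim, ragged.ndim)                 # 2 vs 1
even[:, :-1]                                  # fine
ragged[:, :-1]                                # IndexError: too many indices for array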
Expected behavior
No errors and preferably no deprecation warnings.
Top GitHub Comments
Certainly, that should ideally be the case!
Are those UserWarnings being thrown in the parameter update step? It suggests to me a mismatch between the update and parameter dtypes!
More than happy to perform a code-review when finished!
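As a side note (my own sketch, not from the thread): if those warnings really do come from the update step, comparing the leaf dtypes of the parameter tree and the update tree is a quick way to confirm such a mismatch.
import jax
# Map every leaf of a PyTree (params, gradients or optax updates) to its dtype,
# so the two trees can be compared side by side.
def leaf_dtypes(tree):
    return jax.tree_util.tree_map(lambda x: x.dtype, tree)
# e.g. print(leaf_dtypes(state.params)) and print(leaf_dtypes(updates));
# float32 vs. bfloat16 leaves would be consistent with the warnings above.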
Also cc'ing @duongna21, who deserves all credit for implementing the Flax BART training example!
@duongna21 Thanks for chiming in! I am going to close this as I am not working on this directly any more. I assume that drop_last should indeed fix the issue.
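For anyone landing here later, a rough sketch of what the drop_last suggestion could look like on the evaluation side (my own sketch; the helper name and the script's actual batching code differ):
import numpy as np

def eval_batch_indices(num_samples, batch_size, drop_last=True):
    """Yield index arrays for evaluation batches, optionally dropping the final partial batch."""
    idx = np.arange(num_samples)
    full = num_samples // batch_size
    for i in range(full):
        yield idx[i * batch_size : (i + 1) * batch_size]
    if not drop_last and num_samples % batch_size:
        yield idx[full * batch_size:]

# With drop_last=True the short final batch never reaches the collator, which
# sidesteps the 1-D object-array problem during evaluation.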