Flax core dump error on Cloud TPU when running run_clm_flax.py
Hi, I'm having a weird problem trying to train a gpt-neo model from scratch on a v3-8 Cloud TPU. It is similar to the closed issue here. I'm getting:
https://symbolize.stripped_domain/r/?trace=7fb5dbf8a3f4,7fb5dbfe020f,7f&map=
*** SIGTERM received by PID 64823 (TID 64823) on cpu 26 from PID 63364; stack trace: *** | 0/1 [00:00<?, ?ba/s]
PC: @ 0x7fb5dbf8a3f4 (unknown) do_futex_wait.constprop.0
@ 0x7fb52fa377ed 976 (unknown)
@ 0x7fb5dbfe0210 440138896 (unknown) | 0/1 [00:00<?, ?ba/s]
@ 0x80 (unknown) (unknown) | 0/1 [00:00<?, ?ba/s]
https://symbolize.stripped_domain/r/?trace=7fb5dbf8a3f4,7fb52fa377ec,7fb5dbfe020f,7f&map=44c8b163be936ec2996e56972aa94d48:7fb521e7d000-7fb52fd90330
E1122 14:13:36.933620 64823 coredump_hook.cc:255] RAW: Remote crash gathering disabled for SIGTERM. | 0/1 [00:00<?, ?ba/s]
E1122 14:13:36.960024 64823 process_state.cc:776] RAW: Raising signal 15 with default behavior
This happens at random points while the dataset is being preprocessed/loaded.
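The trace above says PID 64823 received SIGTERM from PID 63364, i.e. another process sent the signal. One plausible sender during preprocessing is a parent process shutting down its worker processes. This is offered as an assumption, not a diagnosis. In Python's standard library, Process.terminate() is delivered as SIGTERM on POSIX systems, and multiprocessing.Pool calls terminate() when it is used as a context manager and the with block exits. The minimal sketch below only demonstrates the first of these facts. The names idle_worker and terminate_and_report are hypothetical.

```python
import multiprocessing as mp
import signal
import time


def idle_worker() -> None:
    # Hypothetical stand-in for a preprocessing worker that is still alive
    # when the parent decides to shut it down.
    time.sleep(60)


def terminate_and_report() -> int:
    """Start a worker, terminate it from the parent, and return its exit code."""
    ctx = mp.get_context("fork")  # POSIX-only start method, as on a Linux TPU VM
    proc = ctx.Process(target=idle_worker)
    proc.start()
    proc.terminate()  # on POSIX this sends SIGTERM to the child process
    proc.join()
    # multiprocessing reports "killed by signal N" as exit code -N.
    return proc.exitcode


if __name__ == "__main__":
    print(terminate_and_report() == -signal.SIGTERM)  # True on Linux
```

If the SIGTERM messages only show up when preprocessing_num_workers is greater than 1, that would fit this reading. It would also be consistent with the main training process continuing to run.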
The environment is clean. I set it up following the Quickstart Flax guide on Google's help page, and also the one here. JAX is installed correctly and sees all 8 TPU devices. I tried both the standard pip install and the local install that some people suggested in the issue above, but I still get the same behavior.
This error does not kill the training. So, question 1: how do I get rid of this error?
Something else happens that might be related. Training on a dummy 300 MB Wiki dataset only produces the error above, and training keeps going. However, with the full 40 GB dataset, at some point during the first epoch I get:
list([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, .... (many 1s) .. 1, 1, 1])]]' of type <class 'numpy.ndarray'> is not a valid JAX type.
This error kills the training. I've found this related issue, but its last suggestion, increasing max_seq_len, does not apply here: the preprocessing should automatically concatenate the texts and cut them into chunks of the model's length (which is set in the config file). The dataset itself is clean and does not contain long words, odd characters, or anything else weird.
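The concatenate-and-chunk preprocessing mentioned above can be sketched as follows. This is a minimal, hypothetical re-implementation modeled on the group_texts idea in the Hugging Face causal-LM example scripts, not the script's actual code. block_size stands in for the configured model length, and to_batch is an illustrative helper. The point it illustrates is that every chunk must have the same length. If some chunk ended up shorter (for example, if the remainder of a short batch were kept instead of dropped), stacking a batch would give ragged rows. Older NumPy versions turn ragged lists into object arrays, which is the kind of value the JAX error above rejects.

```python
from itertools import chain
from typing import Dict, List

import numpy as np


def group_texts(examples: Dict[str, List[List[int]]], block_size: int) -> Dict[str, List[List[int]]]:
    """Concatenate tokenized texts and split them into chunks of exactly block_size tokens."""
    # Flatten each column (input_ids, attention_mask, ...) into one long token list.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    total_length = len(next(iter(concatenated.values())))
    # Always drop the trailing remainder so no chunk is shorter than block_size.
    # If the whole batch is shorter than block_size, this yields zero chunks.
    total_length = (total_length // block_size) * block_size
    return {
        k: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, tokens in concatenated.items()
    }


def to_batch(chunks: Dict[str, List[List[int]]]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: stack chunks into rectangular int32 arrays for JAX."""
    # An explicit integer dtype makes ragged rows raise a ValueError here,
    # instead of silently producing an object array on older NumPy versions.
    return {k: np.array(v, dtype=np.int32) for k, v in chunks.items()}


if __name__ == "__main__":
    examples = {
        "input_ids": [[11, 12, 13], [14, 15, 16, 17, 18]],
        "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
    }
    chunks = group_texts(examples, block_size=3)
    print(chunks["input_ids"])  # [[11, 12, 13], [14, 15, 16]]
    print(to_batch(chunks)["input_ids"].shape)  # (2, 3)
```

Passing an explicit integer dtype is a deliberate choice: a ragged batch then fails immediately with a NumPy ValueError at stacking time, rather than surfacing later as an invalid JAX type inside the training step.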
So, question 2: any pointers on how to solve this second error?
Unfortunately I can't share the dataset because it's private 😞, so I don't know how to help reproduce this error. I've put both questions in a single issue because there's a chance they are related.
Thanks a bunch!
Update: here is the output of run_clm_flax.py. Because there's a limit on how much text can be pasted, I've deleted a few chunks of repeated lines from the output.
Issue created 2 years ago; 7 comments (4 by maintainers).
I also had the same issue with another dataset while training a T5 model. The problem seems to be related to datasets: I stripped out all of the T5 training code except the data generation part and still got the same "SIGTERM" error on a TPU v4 VM.
I tested it with Python 3.8 and Python 3.7, and the same error occurs with both.
@stefan-it @dumitrescustefan did you find a solution other than setting preprocessing_num_workers to 1? That workaround is extremely slow.
@patil-suraj Is there any solution to this problem?
I've re-opened this issue because I've been seeing this problem for a long time. The reported error still occurs with the latest Datasets and Transformers versions on TPU.