Unable to run model parallel training using jax on TPU-VM
Environment info
- transformers version: 4.9.0.dev0
- Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29
- Python version: 3.8.10
- PyTorch version (GPU?): 1.9.0+cpu (False)
- Tensorflow version (GPU?): 2.7.0-dev20210705 (False)
- Flax version: 0.3.4 (tpu)
- Jax version: 0.2.16
- JaxLib version: 0.1.68
- Using distributed or parallel set-up in script?: Yes
Who can help
examples/research_projects/jax/model_parallel @patil-suraj
Information
Model I am using: GPT-Neo 1.3B (for instance, the variant whose embedding is resized to a multiple of 8, which can be found here)
The problem arises when using:
- the official example scripts
- my own modified scripts: the same error is observed with a customized script
To reproduce
Run the command below in the examples/research_projects/jax-projects/model_parallel folder of a cloned transformers repo:
python run_clm_mp.py \
--model_name_or_path flax-community/gpt-neo-1.3B-resized-embed \
--tokenizer_name gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --do_eval \
--block_size 1024 \
--num_train_epochs 5 \
--learning_rate 4e-6 \
--per_device_train_batch_size 3 --per_device_eval_batch_size 3 \
--overwrite_output_dir --output_dir ~/tmp/flax-clm \
--cache_dir ~/datasets_cache/wikitext --dtype bfloat16 \
--logging_steps 96 --eval_steps 96
Stack trace:
07/16/2021 13:59:13 - INFO - absl - A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
/home/arto/jenv/lib/python3.8/site-packages/jax/experimental/pjit.py:160: UserWarning: pjit is an experimental feature and probably has bugs!
warn("pjit is an experimental feature and probably has bugs!")
07/16/2021 13:59:21 - INFO - __main__ - ***** Running training *****
07/16/2021 13:59:21 - INFO - __main__ - Num examples = 2318
07/16/2021 13:59:21 - INFO - __main__ - Num Epochs = 5
07/16/2021 13:59:21 - INFO - __main__ - Instantaneous batch size per device = 3
07/16/2021 13:59:21 - INFO - __main__ - Total train batch size (w. parallel & distributed) = 24
07/16/2021 13:59:21 - INFO - __main__ - Total optimization steps = 480
Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/sF0716 13:59:49.611617 14290 array.h:414] Check failed: n < sizes_size | 0/96 [00:00<?, ?it/s]
*** Check failure stack trace: ***
@ 0x7efd6d030347 (unknown)
@ 0x7efd6d02eed4 (unknown)
@ 0x7efd6d02e9c3 (unknown)
@ 0x7efd6d030cc9 (unknown)
@ 0x7efd68c98eee (unknown)
@ 0x7efd68c2bb2f (unknown)
@ 0x7efd68c29cc2 (unknown)
@ 0x7efd6c7aedb4 (unknown)
@ 0x7efd6c7b0212 (unknown)
@ 0x7efd6c7ade23 (unknown)
@ 0x7efd62c0956f (unknown)
@ 0x7efd68c54248 (unknown)
@ 0x7efd68c55d2b (unknown)
@ 0x7efd687a302b (unknown)
@ 0x7efd68c94001 (unknown)
@ 0x7efd68c91d6a (unknown)
@ 0x7efd68c918bd (unknown)
@ 0x7efd68c94001 (unknown)
@ 0x7efd68c91d6a (unknown)
@ 0x7efd68c918bd (unknown)
@ 0x7efd6831013f (unknown)
@ 0x7efd6830b52e (unknown)
@ 0x7efd68315292 (unknown)
@ 0x7efd68322ffd (unknown)
@ 0x7efd67f0d6b6 (unknown)
@ 0x7efd67f0d014 TpuCompiler_Compile
@ 0x7efd73180956 xla::(anonymous namespace)::TpuCompiler::Compile()
@ 0x7efd709300d4 xla::Service::BuildExecutables()
@ 0x7efd709261a0 xla::LocalService::CompileExecutables()
@ 0x7efd7086ae07 xla::LocalClient::Compile()
@ 0x7efd708452a0 xla::PjRtStreamExecutorClient::Compile()
@ 0x7efd6e440152 xla::PyClient::Compile()
@ 0x7efd6e1ba5e2 pybind11::detail::argument_loader<>::call_impl<>()
@ 0x7efd6e1baa51 pybind11::cpp_function::initialize<>()::{lambda()#3}::operator()()
@ 0x7efd6e1a1460 pybind11::cpp_function::dispatcher()
@ 0x5f2cc9 PyCFunction_Call
https://symbolize.stripped_domain/r/?trace=7efd6d030347,7efd6d02eed3,7efd6d02e9c2,7efd6d030cc8,7efd68c98eed,7efd68c2bb2e,7efd68c29cc1,7efd6c7aedb3,7efd6c7b0211,7efd6c7ade22,7efd62c0956e,7efd68c54247,7efd68c55d2a,7efd687a302a,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd6831013e,7efd6830b52d,7efd68315291,7efd68322ffc,7efd67f0d6b5,7efd67f0d013,7efd73180955,7efd709300d3,7efd7092619f,7efd7086ae06,7efd7084529f,7efd6e440151,7efd6e1ba5e1,7efd6e1baa50,7efd6e1a145f,5f2cc8&map=20957999b35a518f734e5552ed1ebec946aa0e35:7efd6db3c000-7efd74a2efc0,2a762cd764e70bc90ae4c7f9747c08d7:7efd600de000-7efd6d35f280
https://symbolize.stripped_domain/r/?trace=7eff9cd0b18b,7eff9cd0b20f,7efd6d030487,7efd6d02eed3,7efd6d02e9c2,7efd6d030cc8,7efd68c98eed,7efd68c2bb2e,7efd68c29cc1,7efd6c7aedb3,7efd6c7b0211,7efd6c7ade22,7efd62c0956e,7efd68c54247,7efd68c55d2a,7efd687a302a,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd6831013e,7efd6830b52d,7efd68315291,7efd68322ffc,7efd67f0d6b5,7efd67f0d013,7efd73180955,7efd709300d3,7efd7092619f,7efd7086ae06,7efd7084529f&map=20957999b35a518f734e5552ed1ebec946aa0e35:7efd6db3c000-7efd74a2efc0,2a762cd764e70bc90ae4c7f9747c08d7:7efd600de000-7efd6d35f280
*** SIGABRT received by PID 14290 (TID 14290) on cpu 89 from PID 14290; ***
E0716 13:59:49.681807 14290 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
E0716 13:59:49.681854 14290 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
E0716 13:59:49.681862 14290 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0716 13:59:49.681870 14290 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
E0716 13:59:49.681876 14290 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0716 13:59:49.681886 14290 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0716 13:59:49.681891 14290 coredump_hook.cc:525] RAW: Discarding core.
F0716 13:59:49.611617 14290 array.h:414] Check failed: n < sizes_size
E0716 13:59:49.953522 14290 process_state.cc:771] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
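The batch and step counts in the log follow from simple arithmetic. The sketch below reproduces them under two assumptions not stated in the report: that the script computes the global batch as the per-device batch size times the device count, and that it drops the last partial batch of each epoch. The helper name `training_schedule` is hypothetical, and `device_count=8` is assumed because 3 × 8 matches the logged total of 24 (on the TPU VM itself, `jax.device_count()` reports the real value).

```python
# Hypothetical helper reproducing the arithmetic behind the logged numbers.
# Assumption: global batch = per-device batch * device count, and the final
# partial batch of each epoch is dropped (integer division).

def training_schedule(num_examples: int, per_device_batch: int,
                      device_count: int, num_epochs: int) -> dict:
    total_batch = per_device_batch * device_count    # global batch per step
    steps_per_epoch = num_examples // total_batch    # partial batch dropped
    total_steps = steps_per_epoch * num_epochs       # optimization steps overall
    return {
        "total_batch": total_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
    }

# Values from the log; device_count=8 is an assumption (8 TPU cores).
schedule = training_schedule(num_examples=2318, per_device_batch=3,
                             device_count=8, num_epochs=5)
print(schedule)  # {'total_batch': 24, 'steps_per_epoch': 96, 'total_steps': 480}
```

With 96 steps per epoch, `--logging_steps 96 --eval_steps 96` means logging and evaluation happen once per epoch. The `0/96` progress bar and the `TpuCompiler::Compile` frames in the trace show that the abort happens while XLA compiles the training step, before the first step completes.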
Expected behavior
Training in model parallel mode.
Issue Analytics
- Created 2 years ago
- Comments: 8 (6 by maintainers)
Top GitHub Comments
Thanks for the very clear and detailed report!
This looks like a JAX bug; JAX should never abort like this. This line:
F0716 13:59:49.611617 14290 array.h:414] Check failed: n < sizes_size
indicates that this CHECK is failing: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/array.h#L414
I’ll try the repro and see if I can figure out what’s going on here.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.