
Unable to run model parallel training using jax on TPU-VM

See original GitHub issue

Environment info

  • transformers version: 4.9.0.dev0
  • Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.9.0+cpu (False)
  • Tensorflow version (GPU?): 2.7.0-dev20210705 (False)
  • Flax version: 0.3.4 (tpu)
  • Jax version: 0.2.16
  • JaxLib version: 0.1.68
  • Using distributed or parallel set-up in script?: Yes
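As a quick sanity check that this environment actually exposes the TPU to JAX, a few lines like the following can be run on the TPU VM (a minimal sketch, not part of the original report; the expected values in the comments assume a v3-8 host):

import jax
import jaxlib
import flax

print("jax:", jax.__version__)        # 0.2.16 in this report
print("jaxlib:", jaxlib.__version__)  # 0.1.68 in this report
print("flax:", flax.__version__)      # 0.3.4 in this report
print("backend:", jax.devices()[0].platform)  # should print "tpu" on a TPU VM
print("device count:", jax.device_count())    # 8 cores on a v3-8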

Who can help

examples/research_projects/jax-projects/model_parallel @patil-suraj

Information

Model I am using: GPT-Neo-1.3B (for instance, the variant with its embedding resized to a multiple of 8 can be found here)

The problem arises when using:

  • the official example scripts

  • my own modified scripts: the same error is observed with a customized script

To reproduce

Run the command below from the examples/research_projects/jax-projects/model_parallel folder of a cloned transformers repo:

python run_clm_mp.py \
    --model_name_or_path flax-community/gpt-neo-1.3B-resized-embed  \
    --tokenizer_name gpt2 \
    --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
    --do_train  --do_eval \
    --block_size 1024 \
    --num_train_epochs 5 \
    --learning_rate 4e-6 \
    --per_device_train_batch_size 3 --per_device_eval_batch_size 3 \
    --overwrite_output_dir --output_dir ~/tmp/flax-clm \
    --cache_dir ~/datasets_cache/wikitext --dtype bfloat16 \
    --logging_steps 96 --eval_steps 96
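Before (or instead of) the full training script, the same pjit machinery can be exercised with a much smaller program. The sketch below is not from the issue: the 2x4 mesh, the "dp"/"mp" axis names, and the toy matmul are illustrative assumptions for an 8-core TPU VM, and the import paths follow the JAX 0.2.x API used in this report (they have since moved in newer releases).

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit

# Arrange the 8 TPU cores into a 2x4 data-parallel / model-parallel mesh.
devices = np.array(jax.devices()).reshape(2, 4)

# Shard the first operand's rows over "dp", the second operand's columns
# over "mp", and replicate the result on every device.
matmul = pjit(
    lambda a, b: jnp.dot(a, b),
    in_axis_resources=(P("dp", None), P(None, "mp")),
    out_axis_resources=None,
)

a = jnp.ones((8, 1024), dtype=jnp.bfloat16)
b = jnp.ones((1024, 8), dtype=jnp.bfloat16)

with mesh(devices, ("dp", "mp")):
    print(matmul(a, b).shape)  # expected: (8, 8)

If this minimal case compiles and runs, the compiler CHECK below is more likely tied to the specific partition specs the example script places on the GPT-Neo parameters rather than to pjit itself.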

Stack trace:

07/16/2021 13:59:13 - INFO - absl - A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
/home/arto/jenv/lib/python3.8/site-packages/jax/experimental/pjit.py:160: UserWarning: pjit is an experimental feature and probably has bugs!
  warn("pjit is an experimental feature and probably has bugs!")
07/16/2021 13:59:21 - INFO - __main__ - ***** Running training *****
07/16/2021 13:59:21 - INFO - __main__ -   Num examples = 2318
07/16/2021 13:59:21 - INFO - __main__ -   Num Epochs = 5
07/16/2021 13:59:21 - INFO - __main__ -   Instantaneous batch size per device = 3
07/16/2021 13:59:21 - INFO - __main__ -   Total train batch size (w. parallel & distributed) = 24
07/16/2021 13:59:21 - INFO - __main__ -   Total optimization steps = 480
Epoch ... (1/5):   0%|                                        | 0/5 [00:00<?, ?it/s]
                   0%|                                        | 0/96 [00:00<?, ?it/s]
F0716 13:59:49.611617   14290 array.h:414] Check failed: n < sizes_size
*** Check failure stack trace: ***
    @     0x7efd6d030347  (unknown)
    @     0x7efd6d02eed4  (unknown)
    @     0x7efd6d02e9c3  (unknown)
    @     0x7efd6d030cc9  (unknown)
    @     0x7efd68c98eee  (unknown)
    @     0x7efd68c2bb2f  (unknown)
    @     0x7efd68c29cc2  (unknown)
    @     0x7efd6c7aedb4  (unknown)
    @     0x7efd6c7b0212  (unknown)
    @     0x7efd6c7ade23  (unknown)
    @     0x7efd62c0956f  (unknown)
    @     0x7efd68c54248  (unknown)
    @     0x7efd68c55d2b  (unknown)
    @     0x7efd687a302b  (unknown)
    @     0x7efd68c94001  (unknown)
    @     0x7efd68c91d6a  (unknown)
    @     0x7efd68c918bd  (unknown)
    @     0x7efd68c94001  (unknown)
    @     0x7efd68c91d6a  (unknown)
    @     0x7efd68c918bd  (unknown)
    @     0x7efd6831013f  (unknown)
    @     0x7efd6830b52e  (unknown)
    @     0x7efd68315292  (unknown)
    @     0x7efd68322ffd  (unknown)
    @     0x7efd67f0d6b6  (unknown)
    @     0x7efd67f0d014  TpuCompiler_Compile
    @     0x7efd73180956  xla::(anonymous namespace)::TpuCompiler::Compile()
    @     0x7efd709300d4  xla::Service::BuildExecutables()
    @     0x7efd709261a0  xla::LocalService::CompileExecutables()
    @     0x7efd7086ae07  xla::LocalClient::Compile()
    @     0x7efd708452a0  xla::PjRtStreamExecutorClient::Compile()
    @     0x7efd6e440152  xla::PyClient::Compile()
    @     0x7efd6e1ba5e2  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7efd6e1baa51  pybind11::cpp_function::initialize<>()::{lambda()#3}::operator()()
    @     0x7efd6e1a1460  pybind11::cpp_function::dispatcher()
    @           0x5f2cc9  PyCFunction_Call
https://symbolize.stripped_domain/r/?trace=7efd6d030347,7efd6d02eed3,7efd6d02e9c2,7efd6d030cc8,7efd68c98eed,7efd68c2bb2e,7efd68c29cc1,7efd6c7aedb3,7efd6c7b0211,7efd6c7ade22,7efd62c0956e,7efd68c54247,7efd68c55d2a,7efd687a302a,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd6831013e,7efd6830b52d,7efd68315291,7efd68322ffc,7efd67f0d6b5,7efd67f0d013,7efd73180955,7efd709300d3,7efd7092619f,7efd7086ae06,7efd7084529f,7efd6e440151,7efd6e1ba5e1,7efd6e1baa50,7efd6e1a145f,5f2cc8&map=20957999b35a518f734e5552ed1ebec946aa0e35:7efd6db3c000-7efd74a2efc0,2a762cd764e70bc90ae4c7f9747c08d7:7efd600de000-7efd6d35f280 
https://symbolize.stripped_domain/r/?trace=7eff9cd0b18b,7eff9cd0b20f,7efd6d030487,7efd6d02eed3,7efd6d02e9c2,7efd6d030cc8,7efd68c98eed,7efd68c2bb2e,7efd68c29cc1,7efd6c7aedb3,7efd6c7b0211,7efd6c7ade22,7efd62c0956e,7efd68c54247,7efd68c55d2a,7efd687a302a,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd68c94000,7efd68c91d69,7efd68c918bc,7efd6831013e,7efd6830b52d,7efd68315291,7efd68322ffc,7efd67f0d6b5,7efd67f0d013,7efd73180955,7efd709300d3,7efd7092619f,7efd7086ae06,7efd7084529f&map=20957999b35a518f734e5552ed1ebec946aa0e35:7efd6db3c000-7efd74a2efc0,2a762cd764e70bc90ae4c7f9747c08d7:7efd600de000-7efd6d35f280 
*** SIGABRT received by PID 14290 (TID 14290) on cpu 89 from PID 14290; ***
E0716 13:59:49.681807   14290 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
E0716 13:59:49.681854   14290 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
E0716 13:59:49.681862   14290 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0716 13:59:49.681870   14290 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
E0716 13:59:49.681876   14290 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0716 13:59:49.681886   14290 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0716 13:59:49.681891   14290 coredump_hook.cc:525] RAW: Discarding core.
F0716 13:59:49.611617   14290 array.h:414] Check failed: n < sizes_size 
E0716 13:59:49.953522   14290 process_state.cc:771] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

Expected behavior

Training in model parallel mode.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (6 by maintainers)

Top GitHub Comments

1 reaction
skye commented, Jul 26, 2021

Thanks for the very clear and detailed report!

This looks like a JAX bug; JAX should never abort like this. This line: F0716 13:59:49.611617 14290 array.h:414] Check failed: n < sizes_size indicates that this CHECK is failing: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/array.h#L414

I’ll try the repro and see if I can figure out what’s going on here.
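One way to turn a compiler CHECK like this into something the XLA/JAX teams can act on (a hedged suggestion, not something proposed in the thread) is to dump the HLO that pjit hands to the TPU compiler; the dump directory below is arbitrary:

# Ask XLA to dump every module it compiles, so the module being compiled
# when the abort happens can be attached to a bug report.
import os
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"

# ...then import jax and run the training step as usual; look for the
# *.before_optimizations.txt files under /tmp/xla_dump.

The same flag can also be exported in the shell before launching run_clm_mp.py.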

0 reactions
github-actions[bot] commented, Oct 5, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
