
Cannot load optimizer and lr_scheduler states with TPU training

See original GitHub issue

🐛 Bug

When restarting training and loading the optimizer.pt and scheduler.pt, the training crashes as the existing code does not know how to load it with TPU.

Information

The stack trace:

Exception in device=TPU:5: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)
Traceback (most recent call last):
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
    fn(gindex, *args)
  File "/home/saurabh/<retracted>", line 334, in _mp_fn
    main()
  File "/home/saurabh/<retracted>", line 303, in main
    trainer.train(model_path=model_path)
  File "/home/saurabh/venv/lib/python3.6/site-packages/transformers/trainer.py", line 386, in train
    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 584, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 764, in _legacy_load
    result = unpickler.load()
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 720, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 802, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 179, in default_restore_location
    + location + ")")
RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)

This happens when loading a partially trained model.
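The failure can be illustrated without TPU hardware. Below is a minimal sketch of the pattern and the workaround; it assumes only a plain PyTorch install, guards the `torch_xla` import (which exists only on TPU hosts), and uses an illustrative toy model rather than the actual Trainer code. The checkpoint tensors themselves live on CPU; the crash comes from passing an XLA device string as `map_location`, for which `torch.serialization` has no registered restore hook.

```python
# Minimal sketch: load optimizer state on CPU first, then transfer to the
# XLA device. The torch_xla calls are guarded because they only exist on
# TPU hosts; the tiny model and paths here are illustrative.
import os
import tempfile

import torch

# Build a tiny model/optimizer pair and take one step so the optimizer
# has real state (momentum buffers) to checkpoint.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

ckpt_dir = tempfile.mkdtemp()
torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))

# Key point: deserialize on CPU. torch.serialization has no restore hook
# registered for "xla:N" locations, so map_location must not be an XLA
# device -- that is exactly what raises the RuntimeError above.
state = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location="cpu")
optimizer.load_state_dict(state)

try:
    import torch_xla.core.xla_model as xm  # present only on TPU hosts
    # Transfer the freshly loaded CPU tensors onto the XLA device.
    xm.send_cpu_data_to_device(optimizer, xm.xla_device())
except ImportError:
    pass  # no TPU runtime: the CPU-loaded state is already usable
```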

A reference implementation is here: https://github.com/pytorch-tpu/fairseq/blob/tpu/fairseq/trainer.py#L195, with a discussion at https://github.com/pytorch/xla/issues/1343.

Model I am using (Bert, XLNet …): any model

Language I am using the model on (English, Chinese …):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Train any model on TPU and wait for a checkpoint to be written.
  2. Move the tokenizer files to the checkpoint dir (another bug: the Trainer expects the tokenizer configs to be present in the checkpoint directory, but they are only written at the very end of training, not at the earlier checkpoints).
  3. Restart training from the checkpoint on TPU.

Expected behavior

Trainer loads the optimizer and scheduler to TPU and starts training.

Environment info

  • transformers version: 2.11.0 (master)
  • Platform: Linux-5.3.0-1026-gcp-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.6.0a0+6bdfd6a (False)
  • Tensorflow version (GPU?): 2.2.0 (False)
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: yes, 8 way with xla_spawn.py

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 8 (1 by maintainers)

Top GitHub Comments

2 reactions
Rashwan commented, Oct 15, 2020

@LysandreJik Any updates on this bug? It prevents resuming training from a checkpoint on TPUs.

2 reactions
abhi1nandy2 commented, Sep 20, 2020

I encountered the same issue; it comes down to the script not mapping the optimizer state to the proper TPU device. Here's the line in question: https://github.com/huggingface/transformers/blob/d088d744adb4e5aa45262a34acab3ae9e81de169/src/transformers/trainer.py#L403

My solution was to replace

optimizer.load_state_dict(
    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)

by:

if is_torch_tpu_available():
    # Load the state_dict on CPU first, then transfer it to the XLA device
    # (xm is torch_xla.core.xla_model).
    optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
    xm.send_cpu_data_to_device(optimizer, xm.xla_device())
else:
    optimizer.load_state_dict(
        torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
    )

That seemed to do the trick with torch-xla-nightly. Hope this helps.

This works; however, the progress bar starts from 0 and then takes a long time to reach the step where the checkpoint was saved. How can that be tackled? I am training on a Cloud TPU (v3-8) and using the xla_spawn script to distribute training across the cores.
