
Hi Phil,

This morning I tried to run the decoder training step. I decided to use DecoderTrainer, but I found one issue with the EMA update.

After sampling with decoder_trainer, the next training forward pass throws a RuntimeError:

Traceback (most recent call last):
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 321, in <module>
    main()
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 318, in main
    train(decoder_trainer, train_dl, val_dl, train_config, device)
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 195, in train
    trainer.update(unet_number)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 288, in update
    self.ema_unets[index].update()
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 119, in update
    self.update_moving_average(self.ema_model, self.online_model)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 129, in update_moving_average
    ema_param.data = calculate_ema(self.beta, old_weight, up_weight)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 125, in calculate_ema
    return old * beta + new * (1 - beta)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!
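This is PyTorch's standard error when a single arithmetic op mixes tensors that live on different devices. A minimal sketch of the same failure, with stand-in tensors (guarded so the mismatch is only attempted when CUDA is actually available):

```python
import torch

beta = 0.99
old = torch.zeros(3)  # stand-in for an EMA weight left on the CPU

if torch.cuda.is_available():
    new = torch.ones(3, device="cuda")  # stand-in for an online weight on cuda:0
    try:
        old * beta + new * (1 - beta)  # mixes a cpu tensor with a cuda:0 tensor
    except RuntimeError as err:
        print(err)  # "Expected all tensors to be on the same device, ..."
else:
    new = torch.ones(3)  # on a CPU-only machine both tensors share a device
    print(old * beta + new * (1 - beta))  # the same op succeeds
```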

https://github.com/lucidrains/DALLE2-pytorch/blob/6f76652d118d3da2419bd12084abfff45772553b/dalle2_pytorch/train.py#L108-L118

And I checked up_weight.device (the online model) and old_weight.device (the EMA model): the online model is on cuda:0 but the EMA model is on the CPU. It's really weird. I debugged for a long time and I think it might be caused by the DecoderTrainer.sample() process: when swapping between the EMA and online models, something goes wrong with the device placement. https://github.com/lucidrains/DALLE2-pytorch/blob/6021945fc8e1ec27bbebfa1e181e892a7c4d05fb/dalle2_pytorch/train.py#L298-L308

The way I fixed it was just to add self.ema_model = self.ema_model.to(next(self.online_model.parameters()).device) before calling self.update_moving_average(self.ema_model, self.online_model) (pretty naive, haha).
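The workaround can be sketched as a standalone helper (the names and structure here are illustrative, not the library's exact code): move the EMA model onto the online model's device before blending the weights.

```python
import torch

def update_moving_average(beta, ema_model, online_model):
    # Align devices first: the EMA copy may have been left on the CPU
    # while the online model lives on cuda:0.
    device = next(online_model.parameters()).device
    ema_model.to(device)

    # Standard EMA update: ema = ema * beta + online * (1 - beta)
    with torch.no_grad():
        for ema_param, online_param in zip(ema_model.parameters(),
                                           online_model.parameters()):
            ema_param.data = ema_param.data * beta + online_param.data * (1 - beta)

# Tiny CPU-only demo: with beta = 0.5, zero weights blended with ones give 0.5.
online = torch.nn.Linear(2, 2)
ema = torch.nn.Linear(2, 2)
with torch.no_grad():
    online.weight.fill_(1.0)
    ema.weight.fill_(0.0)
update_moving_average(0.5, ema, online)
print(ema.weight)
```

Note that Module.to() moves parameters in place, so moving the whole EMA module once is cheaper than calling .to() on every tensor inside the update loop.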

Hoping to hear your solution.

Enjoy!

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 1
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
CiaoHe commented, May 12, 2022

@CiaoHe i’ve come full circle and just use a simple test.py in the root directory + print lol

@lucidrains lol. But when moving to a cluster to train, things can get out of control sometimes (I hate bugs)

1 reaction
lucidrains commented, May 12, 2022

@CiaoHe i’ve come full circle and just use a simple test.py in the root directory + print lol
