EMA Bug
Hi Phil,
This morning I tried to run the decoder training part. I decided to use DecoderTrainer,
but found one issue with the EMA update.
After sampling with decoder_trainer, the next training step throws a RuntimeError:
Traceback (most recent call last):
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 321, in <module>
    main()
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 318, in main
    train(decoder_trainer, train_dl, val_dl, train_config, device)
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 195, in train
    trainer.update(unet_number)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 288, in update
    self.ema_unets[index].update()
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 119, in update
    self.update_moving_average(self.ema_model, self.online_model)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 129, in update_moving_average
    ema_param.data = calculate_ema(self.beta, old_weight, up_weight)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 125, in calculate_ema
    return old * beta + new * (1 - beta)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!
I checked up_weight.device (the online model) and old_weight.device (the EMA model), and found that the online model is on cuda:0 while the EMA model is on cpu. It's really weird; I debugged for a long time and I think it is caused by the DecoderTrainer.sample() call.
When sample() swaps between the EMA and online models, something goes wrong with the device placement:
https://github.com/lucidrains/DALLE2-pytorch/blob/6021945fc8e1ec27bbebfa1e181e892a7c4d05fb/dalle2_pytorch/train.py#L298-L308
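For what it's worth, here is a self-contained illustration of the failure mode I suspect. The swap/restore logic and all names below are my own guess at the pattern, not the actual DecoderTrainer code, and it assumes a CUDA device is available:

```python
import torch
from torch import nn

online = nn.Linear(4, 4).cuda()   # online unet: stays on cuda:0 for training
ema = nn.Linear(4, 4).cuda()      # EMA copy: starts out on cuda:0 as well

def sample_with_ema(ema_model):
    # ... sampling would run here using the EMA weights ...
    ema_model.cpu()               # hypothetical restore step that parks the EMA copy on the CPU

sample_with_ema(ema)

# next update: the moving average now mixes cuda:0 and cpu tensors
beta = 0.99
for ema_param, online_param in zip(ema.parameters(), online.parameters()):
    # raises: RuntimeError: Expected all tensors to be on the same device ...
    ema_param.data = ema_param.data * beta + online_param.data * (1 - beta)
```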
The way I fixed it was simply to add
self.ema_model = self.ema_model.to(next(self.online_model.parameters()).device)
before the call to self.update_moving_average(self.ema_model, self.online_model) (pretty naive, haha).
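For reference, a minimal sketch of what that workaround looks like folded into an EMA wrapper. The class and scaffolding below are my own, not the actual dalle2_pytorch code; only the device re-sync line mirrors the fix above:

```python
import copy
import torch
from torch import nn

class EMA(nn.Module):
    """Minimal EMA wrapper that re-syncs devices before every update (sketch)."""

    def __init__(self, online_model: nn.Module, beta: float = 0.99):
        super().__init__()
        self.online_model = online_model
        self.ema_model = copy.deepcopy(online_model)
        self.beta = beta

    @torch.no_grad()
    def update(self):
        # the workaround: put the EMA copy on the same device as the online
        # model before mixing their weights
        device = next(self.online_model.parameters()).device
        self.ema_model.to(device)

        for ema_param, online_param in zip(self.ema_model.parameters(),
                                           self.online_model.parameters()):
            ema_param.data.mul_(self.beta).add_(online_param.data, alpha=1 - self.beta)
```

An alternative would be to make sure the EMA copy always travels with the trainer when it is moved between devices, but the single .to(...) call right before the update was enough to get training going again here.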
Hope to hear your solution.
Enjoy!
Top GitHub Comments
@lucidrains lol. But when moving to a cluster to train, things can get out of control sometimes (I hate bugs)
@CiaoHe I've come full circle and just use a simple test.py in the root directory + print, lol