
Tensor device error when computing kl_div in VAE


Hello, @lucidrains 😃

I saw this commit, https://github.com/lucidrains/DALLE-pytorch/commit/706f06db0c8cede9fe02d16cbfe8ee09144858ea, and got the error below when I trained with it.

Traceback (most recent call last):
  File "/root/.pycharm_helpers/pydev/pydevd.py", line 1434, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 237, in <module>
    main()
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 164, in main
    loss = get_loss(images)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/train_VAE.py", line 145, in get_loss
    loss = F.smooth_l1_loss(images, recons) + vae(images, return_loss=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/shared/workspace/torch_research/text-to-image/dalle-pytorch/models/model_arch.py", line 173, in forward
    kl_div = (qy * (log_qy - g)).sum(dim = (1, 2)).mean()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The reason is that g is created without a device argument, so it stays on the CPU instead of following the device that the model and images are on.
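To illustrate the cause, here is a minimal sketch (not code from the repository; num_tokens = 512 is a hypothetical value) showing that a tensor built with torch.tensor and no device argument lands on the default device, which is the CPU unless it has been changed, no matter where the other tensors live:

```python
import torch

num_tokens = 512  # hypothetical codebook size, for illustration only

# No device argument: allocated on the default device (CPU by default),
# regardless of where the model and inputs live.
g = torch.log(torch.tensor([1. / num_tokens]))
print(g.device)  # cpu (with default settings)

if torch.cuda.is_available():
    qy = torch.softmax(torch.randn(2, 4, num_tokens, device="cuda"), dim=-1)
    try:
        (qy * (torch.log(qy + 1e-10) - g)).sum(dim=(1, 2)).mean()
    except RuntimeError as err:
        # A device-mismatch RuntimeError like the one in the traceback above
        print(err)
```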

https://github.com/lucidrains/DALLE-pytorch/blob/706f06db0c8cede9fe02d16cbfe8ee09144858ea/dalle_pytorch/dalle_pytorch.py#L172

This can be solved by creating g on the same device as the input images:

device = img.device
....
g = torch.log(torch.tensor([1. / num_tokens], device=device))
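The fix above can be sketched as a self-contained function. This is an illustration, not the repository's code: the name kl_to_uniform and the (batch, positions, num_tokens) logits layout are assumptions, chosen so that the final line mirrors the kl_div expression from the traceback.

```python
import torch
import torch.nn.functional as F


def kl_to_uniform(logits: torch.Tensor) -> torch.Tensor:
    """KL(q || uniform prior), summed over dims 1 and 2, averaged over the batch.

    Assumes (for illustration) logits of shape (batch, positions, num_tokens),
    with the token distribution along the last dimension.
    """
    num_tokens = logits.shape[-1]
    device = logits.device  # follow the input rather than the default (CPU)

    qy = F.softmax(logits, dim=-1)
    log_qy = torch.log(qy + 1e-10)
    # log(1 / num_tokens), allocated on the same device as qy
    g = torch.log(torch.tensor([1. / num_tokens], device=device))
    return (qy * (log_qy - g)).sum(dim=(1, 2)).mean()


# Runs unchanged on CPU or GPU, because g follows logits.device.
dev = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(2, 16, 8, device=dev)
print(kl_to_uniform(logits))
```

Deriving the device from the input keeps the function free of any global device setting, so the same code path works for CPU debugging and GPU training.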

Please check whether this is the correct way to fix it 😃
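Another common PyTorch pattern (not necessarily what the maintainer did) is to store the constant as a module buffer. Buffers are moved by module.to(...) and .cuda() together with the parameters, so the prior never needs a device argument. The UniformPriorKL class below is hypothetical, and persistent=False (available in recent PyTorch releases) is optional; it simply keeps this derived constant out of the state_dict.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniformPriorKL(nn.Module):
    """Hypothetical helper: KL(q || uniform) with log(1 / num_tokens) as a buffer."""

    def __init__(self, num_tokens: int):
        super().__init__()
        # Buffers follow the module across .to()/.cuda() calls.
        self.register_buffer(
            "log_uniform",
            torch.log(torch.tensor([1.0 / num_tokens])),
            persistent=False,  # derived constant, excluded from state_dict
        )

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        qy = F.softmax(logits, dim=-1)
        log_qy = torch.log(qy + 1e-10)
        return (qy * (log_qy - self.log_uniform)).sum(dim=(1, 2)).mean()


model = UniformPriorKL(num_tokens=8)
if torch.cuda.is_available():
    model = model.cuda()  # log_uniform moves to the GPU along with the module
logits = torch.randn(2, 4, 8, device=model.log_uniform.device)
print(model(logits))
```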

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
lucidrains commented, Feb 17, 2021

@johnmccain @kobiso I did correct it in the latest version! https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L172 Let me push out another patch just in case.

1 reaction
kobiso commented, Feb 16, 2021

Thanks for the quick revision!


Top Results From Across the Web

how to weight KLD loss vs reconstruction loss in variational ...
The issue I'm having, is that if the balance between my input feature (x) dimensions and latent space (z) dimensions is not 'optimum',...
Why am I getting RuntimeError: CUDA error: device-side assert ...
I tried to implement a simple VAE! but I'm getting this error out of no ... eps = torch.tensor(std.data.new(std.size()).normal_(0,1)) return ...
CUDA Error: Device-Side Assert Triggered: Solved | Built In
A CUDA Error: Device-Side Assert Triggered can either be caused by an inconsistency between the number of labels and output units or an ......
Moving member tensors with module.to() in PyTorch
When I now move the autoencoder to the GPU with net.to('cuda:0') I get an error in forwarding because the noise tensor is not...
On the use of the Kullback–Leibler divergence in Variational ...
If you don't know what is a VAE, you could start by giving a look at that ... Arguments args (tensor): mean and...
