Tensor shape mismatch when computing batch Sinkhorn loss
When I try to compute the loss for a batch of data, I hit this bug.
import torch
from geomloss import SamplesLoss # See also ImagesLoss, VolumesLoss
cuda_device = torch.device("cuda:%d" % 0 if torch.cuda.is_available() else "cpu")
x=torch.randn(100,90,400, requires_grad=True).to(cuda_device)
y=torch.randn(100,90,400).to(cuda_device)
# Define a Sinkhorn (~Wasserstein) loss between sampled measures
loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
L = loss(x, y) # By default, use constant weights = 1/number of samples
print(L.item())
g_x, = torch.autograd.grad(L, [x]) # GeomLoss fully supports autograd!
print(g_x)
Output:
Traceback (most recent call last):
File "/home/lowen/program/pycharm-community-2018.2.3/helpers/pydev/pydevd.py", line 1664, in <module>
main()
File "/home/lowen/program/pycharm-community-2018.2.3/helpers/pydev/pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "/home/lowen/program/pycharm-community-2018.2.3/helpers/pydev/pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "/home/lowen/program/pycharm-community-2018.2.3/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/devdata/new_Relation_Extraction/test_geomloss.py", line 40, in <module>
L = loss(x, y) # By default, use constant weights = 1/number of samples
File "/home/lowen/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "/home/lowen/anaconda3/envs/pytorch/lib/python3.7/site-packages/geomloss/samples_loss.py", line 239, in forward
verbose=self.verbose)
File "/home/lowen/anaconda3/envs/pytorch/lib/python3.7/site-packages/geomloss/sinkhorn_samples.py", line 52, in sinkhorn_tensorized
C_xx, C_yy, C_xy, C_yx, ε_s, ρ, debias=debias )
File "/home/lowen/anaconda3/envs/pytorch/lib/python3.7/site-packages/geomloss/sinkhorn_divergence.py", line 162, in sinkhorn_loop
at_x = λ * softmin(ε, C_xx, α_log + a_x/ε ) # OT(α,α)
RuntimeError: The size of tensor a (90) must match the size of tensor b (9000) at non-singleton dimension 1
So I debugged the code and found that a_x (a tensor of shape [9000]) does not match α_log (a tensor of shape [100, 90]). I then tried applying a view() operation to a_x / ε so that it takes the shape of α_log, like this:
at_x = λ * softmin(ε, C_xx, α_log + (a_x / ε).view(α_log.size()))
After I fixed every bug of this kind with a view() operation, the code successfully returned a batch loss (in my case, a tensor of shape [100]).
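The mismatch and the reshape described above can be reproduced in isolation with plain PyTorch. This is only a minimal sketch: `alpha_log`, `a_x` and `eps` are stand-in tensors built here for illustration (uniform log-weights and a flat vector of zeros), not the values GeomLoss actually computes.

```python
import math
import torch

B, N = 100, 90  # batch size and points per cloud, as in the report

# Assumption: stand-ins for the batched log-weights and the flat potential.
alpha_log = torch.full((B, N), -math.log(N))  # shape [100, 90]
a_x = torch.zeros(B * N)                      # shape [9000]
eps = 0.05                                    # arbitrary positive scalar

# Broadcasting aligns trailing dimensions, and 90 cannot be matched with 9000.
try:
    alpha_log + a_x / eps
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print(err)

# Reshaping the flat vector to [B, N] makes the sum elementwise again.
h = alpha_log + (a_x / eps).view(alpha_log.size())
print(h.shape)
```

The reshape is only meaningful if the flat vector is laid out batch-major, that is, entry `b * N + i` belongs to point `i` of batch item `b`. That is an assumption of this sketch.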
So I was wondering: is this the right way to fix the bug, or is there a better way to compute the Wasserstein loss for a batch of data?
Thx!!!
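For reference, here is a minimal sketch of what a batched computation can look like when written directly in PyTorch. It is not GeomLoss's implementation: the function name `batched_entropic_ot` is hypothetical, it assumes uniform weights, it computes the plain entropic OT cost rather than a debiased Sinkhorn divergence, and `eps` is used directly as the regularization strength (no claim that it corresponds to the `blur` argument of `SamplesLoss`).

```python
import math
import torch

def batched_entropic_ot(x, y, eps=0.05, n_iters=50):
    """Entropic OT cost between uniformly weighted point clouds (sketch).

    x: (B, N, D), y: (B, M, D). Returns a tensor of shape (B,).
    """
    B, N, _ = x.shape
    M = y.shape[1]
    # Cost C[b, i, j] = |x_i - y_j|^2 / 2, built from batched inner products.
    x2 = x.pow(2).sum(-1)[:, :, None]
    y2 = y.pow(2).sum(-1)[:, None, :]
    C = ((x2 + y2 - 2 * torch.bmm(x, y.transpose(1, 2))) / 2).clamp_min(0)
    log_a = -math.log(N)  # uniform log-weights on x
    log_b = -math.log(M)  # uniform log-weights on y
    f = x.new_zeros((B, N))  # dual potential on x
    g = x.new_zeros((B, M))  # dual potential on y
    for _ in range(n_iters):
        # Log-domain Sinkhorn updates; every tensor keeps its batch dimension.
        f = -eps * torch.logsumexp(log_b + (g[:, None, :] - C) / eps, dim=2)
        g = -eps * torch.logsumexp(log_a + (f[:, :, None] - C) / eps, dim=1)
    # Dual value <a, f> + <b, g> with uniform weights, one scalar per batch item.
    return f.mean(dim=1) + g.mean(dim=1)

torch.manual_seed(0)
x = torch.randn(4, 6, 3, requires_grad=True)
y = torch.randn(4, 5, 3)
loss = batched_entropic_ot(x, y, eps=0.5)
print(loss.shape)
loss.sum().backward()  # gradients flow back through the unrolled iterations
```

Keeping the batch axis as the leading dimension of every tensor avoids the flattening that caused the [9000] versus [100, 90] mismatch in the first place.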
Issue Analytics
- State:
- Created 4 years ago
- Comments:6 (3 by maintainers)
Top GitHub Comments
Hi @heslowen, indeed! In fact, I pushed this fix to master just three days ago in #9: it is not yet available on PyPI, but will be up soon 😃 Best regards, Jean
You’re welcome 😃