[Bug] SumBatchLazyTensor size is inconsistent with indices
🐛 Bug
I want to build a KroneckerProductLazyTensor from x batched lazy tensors, wrap it in a SumBatchLazyTensor, select a specific row, and finally evaluate it. The code works if I first evaluate sum_a and then retrieve the row (which is inefficient), but it raises a "size is inconsistent with indices" error if I retrieve the row first and then evaluate.
Interestingly, if I use the same number for dimensions -1 and -2, there is no error.
To reproduce
** Code snippet to reproduce **
import torch
import gpytorch
x = 3
a = torch.rand((x, 5, 2, 3))
lazy_a = gpytorch.lazy.NonLazyTensor(a)
assert lazy_a.shape == torch.Size([3, 5, 2, 3])
prod_a = gpytorch.lazy.KroneckerProductLazyTensor(*lazy_a)
assert prod_a.shape == torch.Size([5, 8, 27])
sum_a = gpytorch.lazy.SumBatchLazyTensor(prod_a)
assert sum_a.shape == torch.Size([8, 27])
assert sum_a.evaluate()[0].shape == torch.Size([27])
assert sum_a[0].evaluate().shape == torch.Size([27])  # raises the error here
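For comparison, the result the last assertion should produce can be computed densely with plain PyTorch. This is only a sketch: `batched_kron` is a hypothetical helper, not part of GPyTorch, and it assumes the Kronecker product is taken batch-wise over the x factors, which is consistent with the asserted shape `[5, 8, 27]`.

```python
import torch


def batched_kron(a, b):
    """Batch-wise Kronecker product: (B, I, J) x (B, K, L) -> (B, I*K, J*L).

    Hypothetical helper for illustration; not a GPyTorch function.
    """
    B, I, J = a.shape
    _, K, L = b.shape
    # Entry [b, i*K + k, j*L + l] equals a[b, i, j] * b[b, k, l].
    return torch.einsum("bij,bkl->bikjl", a, b).reshape(B, I * K, J * L)


torch.manual_seed(0)
x = 3
a = torch.rand((x, 5, 2, 3))

# Fold the x factors together, one batch-wise Kronecker product at a time.
dense_prod = a[0]
for factor in a[1:]:
    dense_prod = batched_kron(dense_prod, factor)

# Summing over the batch dimension mirrors what SumBatchLazyTensor represents.
dense_sum = dense_prod.sum(dim=0)

# Row 0 of the summed matrix: the value sum_a[0] is expected to evaluate to.
expected_row = dense_sum[0]
```

Materializing the full product defeats the purpose of lazy tensors, so this is useful only as a reference to check shapes and values against.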
** Stack trace/error message **
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-251-7cce10ce99d3> in <module>()
8 assert sum_a.shape == torch.Size([8, 27])
9 assert sum_a.evaluate()[0].shape == torch.Size([27])
---> 10 assert sum_a[0].evaluate().shape == torch.Size([27])
11
9 frames
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/lazy_tensor.py in __getitem__(self, index)
1703 # with the appropriate shape
1704 if (squeeze_row or squeeze_col or row_col_are_absorbed):
-> 1705 res = delazify(res)
1706 if squeeze_row:
1707 res = res.squeeze(-2)
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/lazy_tensor.py in delazify(obj)
1753 return obj
1754 elif isinstance(obj, LazyTensor):
-> 1755 return obj.evaluate()
1756 else:
1757 raise TypeError("object of class {} cannot be made into a Tensor".format(obj.__class__.__name__))
/usr/local/lib/python3.6/dist-packages/gpytorch/utils/memoize.py in g(self, *args, **kwargs)
32 cache_name = name if name is not None else method
33 if not is_in_cache(self, cache_name):
---> 34 add_to_cache(self, cache_name, method(self, *args, **kwargs))
35 return get_from_cache(self, cache_name)
36
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/lazy_tensor.py in evaluate(self)
858 eye = torch.eye(num_rows, dtype=self.dtype, device=self.device)
859 eye = eye.expand(*self.batch_shape, num_rows, num_rows)
--> 860 res = self.transpose(-1, -2).matmul(eye).transpose(-1, -2).contiguous()
861 else:
862 eye = torch.eye(num_cols, dtype=self.dtype, device=self.device)
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/lazy_tensor.py in matmul(self, other)
1093
1094 func = Matmul()
-> 1095 return func.apply(self.representation_tree(), other, *self.representation())
1096
1097 @property
/usr/local/lib/python3.6/dist-packages/gpytorch/functions/_matmul.py in forward(ctx, representation_tree, rhs, *matrix_args)
18
19 lazy_tsr = ctx.representation_tree(*matrix_args)
---> 20 res = lazy_tsr._matmul(rhs)
21
22 to_save = [orig_rhs] + list(matrix_args)
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/block_lazy_tensor.py in _matmul(self, rhs)
64
65 rhs = self._add_batch_dim(rhs)
---> 66 res = self.base_lazy_tensor._matmul(rhs)
67 res = self._remove_batch_dim(res)
68
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/interpolated_lazy_tensor.py in _matmul(self, rhs)
157 def _matmul(self, rhs):
158 # Get sparse tensor representations of left/right interp matrices
--> 159 left_interp_t = self._sparse_left_interp_t(self.left_interp_indices, self.left_interp_values)
160 right_interp_t = self._sparse_right_interp_t(self.right_interp_indices, self.right_interp_values)
161
/usr/local/lib/python3.6/dist-packages/gpytorch/lazy/interpolated_lazy_tensor.py in _sparse_left_interp_t(self, left_interp_indices_tensor, left_interp_values_tensor)
309
310 left_interp_t = sparse.make_sparse_from_indices_and_values(
--> 311 left_interp_indices_tensor, left_interp_values_tensor, self.base_lazy_tensor.size()[-1]
312 )
313 self._left_interp_indices_memo = left_interp_indices_tensor
/usr/local/lib/python3.6/dist-packages/gpytorch/utils/sparse.py in make_sparse_from_indices_and_values(interp_indices, interp_values, num_rows)
59 else:
60 cls = getattr(torch.sparse, type_name)
---> 61 res = cls(index_tensor, value_tensor, interp_size)
62
63 # Wrap things as a variable, if necessary
RuntimeError: size is inconsistent with indices: for dim 1, size is 8 but found index 26
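The numbers in this message match the shapes in the reproduction: 8 is the row count of `sum_a` and 26 is its largest valid column index (27 - 1). One possible reading, which nothing in this issue confirms, is that a column index is being checked against the row size somewhere along the indexing path. That would also fit the observation that square factors do not trigger the error. The shape arithmetic can be sketched as follows; `kron_shape` is a hypothetical helper written for illustration only.

```python
from math import prod


def kron_shape(factor_shapes):
    """Shape of a Kronecker product of 2-D factors: rows and columns multiply."""
    rows = prod(r for r, _ in factor_shapes)
    cols = prod(c for _, c in factor_shapes)
    return rows, cols


# Three 2 x 3 factors, as in the reproduction.
rows, cols = kron_shape([(2, 3)] * 3)
largest_col_index = cols - 1

# A column index only fits inside the row size when the matrix is not wider than tall.
col_index_fits_rows = largest_col_index < rows

# Three square 3 x 3 factors: row and column sizes coincide, so a mix-up would go unnoticed.
sq_rows, sq_cols = kron_shape([(3, 3)] * 3)
sq_col_index_fits_rows = (sq_cols - 1) < sq_rows
```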
Expected Behavior
Expected to pass the tests.
System information
Please complete the following information:
- GPyTorch Version 0.3.5
- PyTorch Version 1.2.0
- Ubuntu 18.04.3 LTS
Issue Analytics
- State:
- Created 4 years ago
- Comments:5 (1 by maintainers)
Top GitHub Comments
@panaali I’ll take a look!
Thanks @jacobrgardner, it should work with your commit. In case anyone needs a workaround for the current version, this is working: