Use `torch.__torch_function__` to implement LazyTensor API
See original GitHub issue

We should consider using the `__torch_function__` machinery described in https://pytorch.org/docs/stable/notes/extending.html to make `LazyTensor`s compatible with regular tensors.
This would have a number of benefits:
- much less cumbersome to mix `LazyTensor` and `Tensor` operations
- would get rid of `NonLazyTensor`
- would allow us to recycle a bunch of torch code rather than re-implementing parts of it (like we do for the MVNs)
Here is a minimal example of how this could work for `DiagLazyTensor`:
```python
import functools

import torch
from torch import Tensor

HANDLED_FUNCTIONS = {}
COMPATIBLE_TYPES = {}


def implements(torch_function, types=None):
    """Register a torch function override for DiagLazyTensor"""

    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        COMPATIBLE_TYPES[torch_function] = types
        return func

    return decorator


class DiagLazyTensor(Tensor):
    def __init__(self, diag):
        self.diag = diag

    def __repr__(self):
        return "{n}x{n} DiagLazyTensor".format(n=self.diag.shape[-1])

    def tensor(self):
        return torch.diag_embed(self.diag)

    @implements(torch.cholesky)
    def _cholesky(self, upper=False):
        return DiagLazyTensor(torch.sqrt(self.diag))

    @implements(torch.add)
    def _add(self, other):
        if isinstance(other, DiagLazyTensor):
            return DiagLazyTensor(self.diag + other.diag)
        return other + self.tensor()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
```
With this we can do

```python
d = DiagLazyTensor(torch.rand(3))
torch.cholesky(d)
```

or

```python
torch.add(d, torch.rand(3, 3))  # results in a Tensor
torch.add(d, DiagLazyTensor(torch.rand(3)))  # results in a DiagLazyTensor
```
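As written, `__torch_function__` ignores the `types` argument. A possible refinement (a sketch, not part of the original proposal) would follow the pattern in the extending-PyTorch notes linked above: define it as a classmethod and bail out when an unrelated `Tensor` subclass is involved, replacing the method on `DiagLazyTensor`:

```python
# Sketch only: stricter dispatch, following the pattern in the
# extending-PyTorch notes. Defining __torch_function__ as a classmethod and
# checking `types` lets calls that involve unrelated Tensor subclasses fall
# back to NotImplemented instead of dispatching blindly.
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
        issubclass(t, (Tensor, DiagLazyTensor)) for t in types
    ):
        return NotImplemented
    return HANDLED_FUNCTIONS[func](*args, **kwargs)
```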
If we add the following to some abstract `LazyTensor` base class:

```python
def __add__(self, other):
    return torch.add(self, other)

def __radd__(self, other):
    return torch.add(self, other)
```
then we can just exploit the above implementations, and things like

```python
d + torch.rand(3, 3)
torch.rand(3, 3) + d
```

will work out of the box.
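The same trick should extend to other operators. As a purely hypothetical sketch building on the example above (not in the original proposal), `torch.matmul` could get a diagonal-aware handler, with `__matmul__`/`__rmatmul__` on the base class forwarding to it:

```python
# Hypothetical sketch: a diagonal-aware handler for torch.matmul, registered
# with the same `implements` helper as above. Both argument orders are
# handled because __torch_function__ passes args through unchanged.
@implements(torch.matmul)
def _diag_matmul(a, b):
    if isinstance(a, DiagLazyTensor) and isinstance(b, DiagLazyTensor):
        return DiagLazyTensor(a.diag * b.diag)
    if isinstance(a, DiagLazyTensor):
        return a.diag.unsqueeze(-1) * b  # diag(d) @ B scales the rows of B
    return a * b.diag  # B @ diag(d) scales the columns of B


# and, mirroring __add__/__radd__ on the abstract LazyTensor base class:
def __matmul__(self, other):
    return torch.matmul(self, other)

def __rmatmul__(self, other):
    return torch.matmul(other, self)
```

With that, `d @ torch.rand(3, 3)` and `torch.rand(3, 3) @ d` would dispatch to the structure-aware implementation as well.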
I haven’t really thought much about the downsides of doing this (other than the work involved), but there are probably some. Curious to hear what people think.
Top GitHub Comments
Yeah that seems like a good design pattern. I can see if I can get something going here.
Since this is a larger refactor and will subclass `Tensor`, I wonder if it makes sense to introduce a `PSDLazyTensor` while hacking on this – that way we could cleanly confine the implicit PSD assumption that we’re making to `PSDLazyTensor`, and have things like `TriangularLazyTensor` or possibly other future tensors without worrying about breaking things left and right in unforeseen ways.

@Balandat I can’t really think of any downsides of allowing for this; it’s pretty much the obvious way LazyTensors should be implemented, given that PyTorch went to the trouble to add this as a possibility.
I think a good pattern for getting this done efficiently might be to push the `@implements` decorators up to the `LazyTensor` class itself on public functions (like `add`), and then have `DiagLT` override private versions like `_add`. We already kind of do this for a ton of behavior – `_getitem`, `_solve`, `_cholesky`, etc.

If we did this, I think we might be able to adapt LTs pretty easily by basically pushing all of the `__torch_function__` logic up to the base class, and then having child LTs define or override behavior as they currently do.
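A rough sketch of that layering (all names hypothetical, not the actual GPyTorch API): the base class owns the dispatch table and `__torch_function__`, registers handlers for the public torch functions once, and delegates to private methods that concrete subclasses override.

```python
# Rough sketch only, with hypothetical names: the base class owns dispatch,
# and subclasses override the private _add / _cholesky hooks as they do today.
import torch
from torch import Tensor

HANDLED_FUNCTIONS = {}


def implements(torch_function):
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class LazyTensor(Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, Tensor) for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    # public handlers are registered once, for every LazyTensor subclass
    @implements(torch.add)
    def add(self, other):
        return self._add(other)

    @implements(torch.cholesky)
    def cholesky(self, upper=False):
        return self._cholesky(upper=upper)

    # subclasses override the private versions, as they already do today
    def _add(self, other):
        raise NotImplementedError

    def _cholesky(self, upper=False):
        raise NotImplementedError


class DiagLazyTensor(LazyTensor):
    def __init__(self, diag):
        self.diag = diag

    def _add(self, other):
        if isinstance(other, DiagLazyTensor):
            return DiagLazyTensor(self.diag + other.diag)
        return other + torch.diag_embed(self.diag)

    def _cholesky(self, upper=False):
        return DiagLazyTensor(self.diag.sqrt())
```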