Use `torch.__torch_function__` to implement LazyTensor API

See original GitHub issue

We should consider using the `__torch_function__` machinery as described in https://pytorch.org/docs/stable/notes/extending.html to make LazyTensors compatible with regular tensors.

This would have a number of benefits:

  • much less cumbersome to mix LazyTensor and Tensor operations
  • would get rid of NonLazyTensor
  • would allow us to recycle a bunch of torch code rather than re-implementing parts of it (like we do for the MVNs).

Here is a minimal example of how this could work for the DiagLazyTensor:

import torch 
from torch import Tensor
import functools


HANDLED_FUNCTIONS = {}
COMPATIBLE_TYPES = {}


def implements(torch_function, types=None):
    """Register a torch function override for DiagLazyTensor"""
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        COMPATIBLE_TYPES[torch_function] = types
        return func
    return decorator


class DiagLazyTensor(Tensor):
    def __init__(self, diag):
        self.diag = diag

    def __repr__(self):
        return "{n}x{n} DiagonalTensor".format(n=self.diag.shape[-1])

    def tensor(self):
        return torch.diag_embed(self.diag)

    @implements(torch.cholesky)
    def _cholesky(self, upper=False):
        return DiagLazyTensor(torch.sqrt(self.diag))

    @implements(torch.add)
    def _add(self, other):
        if isinstance(other, DiagLazyTensor):
            return DiagLazyTensor(self.diag + other.diag)
        return other + self.tensor()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Dispatch to the override registered via @implements above.
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

With this we can do

d = DiagLazyTensor(torch.rand(3))
torch.cholesky(d)

or

torch.add(d, torch.rand(3, 3))  # results in Tensor
torch.add(d, DiagLazyTensor(torch.rand(3)))  # results in DiagLazyTensor
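
Note that newer PyTorch releases expect `__torch_function__` on `Tensor` subclasses to be a classmethod rather than an instance method. A drop-in variant of the method above in that style, reusing the same `HANDLED_FUNCTIONS` registry and falling back to default `Tensor` behavior for anything unhandled, might look like:

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # Anything without a registered override falls back to the
        # default Tensor behavior instead of erroring out.
        return super().__torch_function__(func, types, args, kwargs)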

If we add the following in some abstract LazyTensor base class:

    def __add__(self, other):
        return torch.add(self, other)
    
    def __radd__(self, other):
        return torch.add(self, other)

then we can just exploit the above implementations and things like

d + torch.rand(3, 3)
torch.rand(3, 3) + d

will work out of the box.
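
The reflected case works because of how Python resolves the operator. Assuming the `__add__`/`__radd__` above end up on `DiagLazyTensor` (directly or via such a base class), the dispatch chain is roughly:

t = torch.rand(3, 3)
d = DiagLazyTensor(torch.rand(3))

# t + d: since DiagLazyTensor is a Tensor subclass that overrides __radd__,
# Python tries d.__radd__(t) before t.__add__(d). That calls torch.add, which
# sees the __torch_function__ override on d and routes to the handler
# registered via @implements(torch.add).
out = t + d                               # dense Tensor
out = d + DiagLazyTensor(torch.rand(3))   # stays a DiagLazyTensor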

I haven’t really thought much about the downsides of doing this (other than it being a fair amount of work), but there are probably some. Curious to hear what people think.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

2 reactions
Balandat commented, Nov 10, 2020

Yeah that seems like a good design pattern. I can see if I can get something going here.

Since this is a larger refactor and will subclass Tensor, I wonder if it makes sense to introduce a PSDLazyTensor while hacking on this - that way we could cleanly confine the implicit PSD assumption that we’re making to PSDLazyTensor, and have things like TriangularLazyTensor or possibly other future tensors without worrying about breaking things left and right in unforeseen ways.
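
For illustration only, the kind of hierarchy being floated here might look something like the following skeleton (names and placement are speculative, not an actual design):

from torch import Tensor

class LazyTensor(Tensor):
    # shared __torch_function__ machinery and operator dunders; no structural
    # assumptions about the matrix being represented
    ...

class PSDLazyTensor(LazyTensor):
    # anything that implicitly relies on symmetric positive (semi-)definiteness,
    # e.g. Cholesky-based routines, would live here
    ...

class TriangularLazyTensor(LazyTensor):
    # triangular structure: solves via forward/back substitution, no PSD assumption
    ...

class DiagLazyTensor(PSDLazyTensor):
    # a nonnegative diagonal satisfies the PSD assumption
    ...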

2 reactions
jacobrgardner commented, Nov 10, 2020

@Balandat I can’t really think of any downsides of allowing for this; it’s pretty much the obvious way LazyTensors should be implemented, given that PyTorch went to the trouble to add this as a possibility.

I think a good pattern for getting this done efficiently might be to push the @implements decorators up to the LazyTensor class itself on public functions (like add), and then have DiagLT override private versions like _add. We already kind of do this for a ton of behavior – _getitem, _solve, _cholesky, etc etc.

If we did this, I think we might be able to adapt LTs pretty easily by basically pushing all of the __torch_function__ logic up to the base class, and then have child LTs define or override behavior as they currently do.
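
A rough sketch of that pattern, with hypothetical names along the lines of the example in the issue (not gpytorch's actual API), might be:

import torch
from torch import Tensor

HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a torch function override for LazyTensor."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class LazyTensor(Tensor):
    # classmethod form, as newer PyTorch expects
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

    # Public torch functions are registered once on the base class and
    # forward to private hooks that concrete LazyTensors override.
    @implements(torch.add)
    def add(self, other):
        return self._add(other)

    @implements(torch.cholesky)
    def cholesky(self, upper=False):
        return self._cholesky(upper=upper)

    def _add(self, other):
        raise NotImplementedError

    def _cholesky(self, upper=False):
        raise NotImplementedError


class DiagLazyTensor(LazyTensor):
    def __init__(self, diag):
        self.diag = diag

    def _add(self, other):
        if isinstance(other, DiagLazyTensor):
            return DiagLazyTensor(self.diag + other.diag)
        return other + torch.diag_embed(self.diag)

    def _cholesky(self, upper=False):
        return DiagLazyTensor(self.diag.sqrt())

With this, `torch.add(d, t)` and `torch.cholesky(d)` both dispatch through the base class, and a new LazyTensor subclass only has to implement the private hooks.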

