
torch._dynamo.allow_in_graph does not work for python built-in operators #88708

See original GitHub issue

🐛 Describe the bug

torch._dynamo.allow_in_graph should allow operator.lt to appear in the FX graph under any optimize configuration and prevent a graph break.

import operator
import torch
import torch._dynamo.testing
from typing import List


# Wraps a real backend and counts compiled frames and call ops so graph breaks can be asserted on.
class CompileCounterWithBackend:
    def __init__(self, backend, fw_compiler=None, bw_compiler=None):
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.fw_compiler = fw_compiler
        self.bw_compiler = bw_compiler

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        from torch._dynamo.eval_frame import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        if self.backend == "aot_autograd":
            return lookup_backend(self.backend)(
                gm,
                example_inputs,
                fw_compiler=self.fw_compiler,
                bw_compiler=self.bw_compiler,
            )
        return lookup_backend(self.backend)(gm, example_inputs)


def trace_printer(gm, _):
    # Forward-compiler hook: dump the FX graph handed to AOTAutograd so each compiled frame is visible.
    # print(f"{'*'*128}\nFX Graph as Readable:\n{gm.print_readable()}")
    print(f"{'*'*128}\n FX Graph as Tabular\n{'*'*128}")
    gm.graph.print_tabular()  # print_tabular() prints the graph itself and returns None
    return gm


def test_allow_in_graph_with_operator_lt():
    class MyModule(torch.nn.Module):
        def forward(self, a):
            x = torch.add(a, 1)
            y = torch.add(x, 1)
            if x.sum() < 0:  # data-dependent check: this is where Dynamo breaks the graph
                x += torch.add(x, 1)
                y += torch.add(x, 1)
            return x + y

    nopython = False
    data = torch.randn(10)
    model = MyModule()
    torch._dynamo.allow_in_graph(operator.lt)  # expected to keep operator.lt in the graph and avoid the break
    compile_counter = CompileCounterWithBackend(
        "aot_autograd", fw_compiler=trace_printer
    )
    dynamo_model = torch._dynamo.optimize(compile_counter, nopython=nopython)(model)
    dynamo_model(data)
    torch._dynamo.disallow_in_graph(operator.lt)

    assert (
        compile_counter.frame_count == 1
    ), f"{compile_counter.frame_count} graph breaks were found!"


test_allow_in_graph_with_operator_lt()

In the example above the assert fails: Dynamo compiles 2 frames instead of 1, i.e. the graph still breaks at the data-dependent comparison even though operator.lt was allowed in the graph.
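(As an aside, one way to confirm and inspect the break is Dynamo's explain utility. The sketch below is illustrative only and assumes a newer PyTorch, 2.1 or later, where torch._dynamo.explain(fn)(*args) returns an ExplainOutput; older releases, including the nightly used here, had a different call convention.)

import torch

def fn(a):
    x = torch.add(a, 1)
    y = torch.add(x, 1)
    if x.sum() < 0:
        x += torch.add(x, 1)
        y += torch.add(x, 1)
    return x + y

# ExplainOutput exposes graph_count, graph_break_count and break_reasons,
# which should point at the data-dependent if above.
explanation = torch._dynamo.explain(fn)(torch.randn(10))
print(explanation.graph_count, explanation.graph_break_count)
print(explanation.break_reasons)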

Versions

Collecting environment information...
PyTorch version: 1.14.0a0+gita8f40b3
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 10.4.0-4ubuntu1~22.04) 10.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.13 (main, Aug 25 2022, 23:26:10)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.3.109
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA TITAN V
GPU 1: NVIDIA TITAN V

Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.960
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.6
[pip3] torch==1.14.0a0+gita8f40b3
[pip3] torchvision==0.14.0a0+3c9ae0a
[conda] blas                      1.0                         mkl  
[conda] magma-cuda113             2.5.2                         1    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-include               2022.1.0           h06a4308_224  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.6                   pypi_0    pypi
[conda] torch                     1.14.0a0+gita8f40b3           dev_0    <develop>
[conda] torchvision               0.14.0a0+3c9ae0a           dev_0    <develop>

Error logs

There is no actual error log, just unexpected behavior: execution finishes without raising any exception. Below is the script output; note the two FX graph dumps, one per compiled frame, which is how the graph break shows up.

********************************************************************************************************************************
FX Graph as Tabular:

opcode         name    target            args          kwargs
-------------  ------  ----------------  ------------  --------
placeholder    arg0_1  arg0_1            ()            {}
call_function  add     aten.add.Tensor   (arg0_1, 33)  {}
call_function  sum_1   aten.sum.default  (add,)        {}
call_function  lt      aten.lt.Scalar    (sum_1, 0)    {}
output         output  output            ([add, lt],)  {}

................................................................................................................................
/home/thiagofc/dev/github/pytorch-dev1/functorch/_src/aot_autograd.py:350: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
********************************************************************************************************************************
FX Graph as Tabular:

opcode         name    target           args              kwargs
-------------  ------  ---------------  ----------------  --------
placeholder    arg0_1  arg0_1           ()                {}
placeholder    arg1_1  arg1_1           ()                {}
call_function  add     aten.add.Tensor  (arg1_1, arg0_1)  {}
output         output  output           ([add],)          {}

................................................................................................................................
/home/thiagofc/dev/github/pytorch-dev1/functorch/_src/aot_autograd.py:350: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(

Minified repro

There is no actual error; the graph still breaks on an op that was explicitly allowed via allow_in_graph, so there is nothing to minify.

Issue Analytics

  • State: closed
  • Created: 10 months ago
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

1 reaction
ngimel commented, Nov 21, 2022

What are you trying to achieve? There’s a cond operator that supports data-dependent control flow with some limitations (https://github.com/pytorch/pytorch/pull/83154); is that what you are looking for? Just forcing something in the graph with allow_in_graph doesn’t solve soundness problems.
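(For illustration, here is a minimal sketch of the cond-based rewrite ngimel is pointing at. It assumes a recent PyTorch where torch.cond is exposed and where branches may return a tuple of tensors; at the time of this issue the prototype lived in functorch.experimental.control_flow. The in-place updates from the original repro are rewritten functionally because cond branches may not mutate or alias their inputs.)

import torch

class MyCondModule(torch.nn.Module):
    def forward(self, a):
        x = torch.add(a, 1)
        y = torch.add(x, 1)

        def true_fn(x, y):
            # functional version of the in-place branch from the repro
            x2 = x + torch.add(x, 1)
            return x2, y + torch.add(x2, 1)

        def false_fn(x, y):
            # branches may not alias their inputs, hence the clones
            return x.clone(), y.clone()

        # the data-dependent predicate is captured inside the graph instead of breaking it
        x, y = torch.cond(x.sum() < 0, true_fn, false_fn, (x, y))
        return x + y

compiled = torch.compile(MyCondModule(), fullgraph=True)
print(compiled(torch.randn(10)))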

1 reaction
eellison commented, Nov 21, 2022

I think we would accept a PR to make the explanation a bit better, something like reason='generic_jump TensorVariable()' -> reason='generic_jump TensorVariable() is not constant'. wdyt @voznesenskym?
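(Side note on surfacing that reason string: recent PyTorch releases can log every graph break together with its reason. The snippet below is only a sketch assuming the 2.x torch._logging API; setting the environment variable TORCH_LOGS="graph_breaks" is an equivalent switch.)

import torch

# Print a message, including the reason string discussed above,
# every time Dynamo falls back to eager because of a graph break.
torch._logging.set_logs(graph_breaks=True)

@torch.compile
def f(x):
    if x.sum() < 0:  # data-dependent branch -> graph break
        x = x + 1
    return x * 2

f(torch.randn(10))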
