torch._dynamo.allow_in_graph does not work for python built-in operators #88708
🐛 Describe the bug
`torch._dynamo.allow_in_graph(operator.lt)` should allow `operator.lt` to appear in the FX graph under any `optimize` configuration and prevent a graph break.
```python
import operator

import torch
import torch._dynamo  # provides allow_in_graph/disallow_in_graph and optimize


# Variant of torch._dynamo.testing.CompileCounterWithBackend that also
# forwards fw_compiler/bw_compiler to the aot_autograd backend.
class CompileCounterWithBackend:
    def __init__(self, backend, fw_compiler=None, bw_compiler=None):
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.fw_compiler = fw_compiler
        self.bw_compiler = bw_compiler

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        from torch._dynamo.eval_frame import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        if self.backend == "aot_autograd":
            return lookup_backend(self.backend)(
                gm,
                example_inputs,
                fw_compiler=self.fw_compiler,
                bw_compiler=self.bw_compiler,
            )
        return lookup_backend(self.backend)(gm, example_inputs)


def trace_printer(gm, _):
    # print(f"{'*'*128}\nFX Graph as Readable:\n{gm.print_readable()}")
    print(f"{'*'*128}\n FX Graph as Tabular\n{'*'*128}")
    print(gm.graph.print_tabular())
    return gm


def test_allow_in_graph_with_operator_lt():
    class MyModule(torch.nn.Module):
        def forward(self, a):
            x = torch.add(a, 1)
            y = torch.add(x, 1)
            if x.sum() < 0:  # data-dependent branch: `<` resolves to operator.lt
                x += torch.add(x, 1)
                y += torch.add(x, 1)
            return x + y

    nopython = False
    data = torch.randn(10)
    model = MyModule()

    torch._dynamo.allow_in_graph(operator.lt)
    compile_counter = CompileCounterWithBackend(
        "aot_autograd", fw_compiler=trace_printer
    )
    dynamo_model = torch._dynamo.optimize(compile_counter, nopython=nopython)(model)
    dynamo_model(data)
    torch._dynamo.disallow_in_graph(operator.lt)

    assert (
        compile_counter.frame_count == 1
    ), f"{compile_counter.frame_count} frames were compiled; expected 1 (no graph break)!"


test_allow_in_graph_with_operator_lt()
```
In the above example, the assert fails: two frames are compiled (i.e., one graph break at the `if x.sum() < 0` test) instead of the expected single frame.
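As a side check, the stock counter backend shows the same thing; a minimal sketch, assuming `torch._dynamo.testing.CompileCounter` is available and `MyModule` from the repro above is in scope:

```python
import torch
import torch._dynamo
import torch._dynamo.testing

counter = torch._dynamo.testing.CompileCounter()
compiled = torch._dynamo.optimize(counter)(MyModule())  # MyModule from the repro
compiled(torch.randn(10))
# Two compiled frames imply one graph break: the code before the `if` and the
# resume code after it are compiled as separate graphs.
print(counter.frame_count)
```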
Versions
```
Collecting environment information...
PyTorch version: 1.14.0a0+gita8f40b3
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 10.4.0-4ubuntu1~22.04) 10.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.3.109
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA TITAN V
GPU 1: NVIDIA TITAN V
Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.960
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.6
[pip3] torch==1.14.0a0+gita8f40b3
[pip3] torchvision==0.14.0a0+3c9ae0a
[conda] blas 1.0 mkl
[conda] magma-cuda113 2.5.2 1 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-include 2022.1.0 h06a4308_224
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.21.6 pypi_0 pypi
[conda] torch 1.14.0a0+gita8f40b3 dev_0 <develop>
[conda] torchvision 0.14.0a0+3c9ae0a dev_0 <develop>
```
Error logs
There is no actual error log, just unexpected behavior: execution finishes without raising any exception. Below is the script output:
```
********************************************************************************************************************************
FX Graph as Tabular:
opcode name target args kwargs
------------- ------ ---------------- ------------ --------
placeholder arg0_1 arg0_1 () {}
call_function add aten.add.Tensor (arg0_1, 33) {}
call_function sum_1 aten.sum.default (add,) {}
call_function lt aten.lt.Scalar (sum_1, 0) {}
output output output ([add, lt],) {}
................................................................................................................................
/home/thiagofc/dev/github/pytorch-dev1/functorch/_src/aot_autograd.py:350: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
warnings.warn(
********************************************************************************************************************************
FX Graph as Tabular:
opcode name target args kwargs
------------- ------ --------------- ---------------- --------
placeholder arg0_1 arg0_1 () {}
placeholder arg1_1 arg1_1 () {}
call_function add aten.add.Tensor (arg1_1, arg0_1) {}
output output output ([add],) {}
................................................................................................................................
/home/thiagofc/dev/github/pytorch-dev1/functorch/_src/aot_autograd.py:350: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
warnings.warn(
```
Minified repro
There is no minified repro; nothing actually errors out, the graph simply breaks on an operator that was explicitly allowed.
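For diagnosing where the break happens, `torch._dynamo.explain` reports break reasons; a hedged sketch (the call convention and return shape have changed across releases, so treat this as illustrative, and it again assumes `MyModule` from the repro above is in scope):

```python
import torch
import torch._dynamo

# In recent releases explain(fn)(*args) returns an object whose string form
# lists graph break reasons (here, 'generic_jump TensorVariable()'); older
# releases used explain(fn, *args) and returned a tuple instead.
print(torch._dynamo.explain(MyModule())(torch.randn(10)))
```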
Top GitHub Comments
What are you trying to achieve? There's a `cond` operator that supports data-dependent control flow with some limitations (https://github.com/pytorch/pytorch/pull/83154), is that what you are looking for? Just forcing something in the graph with `allow_in_graph` doesn't solve soundness problems; a sketch of the `cond` approach follows below.

I think we would accept a PR to make the explanation a bit better. Something like `reason='generic_jump TensorVariable()'` -> `reason='generic_jump TensorVariable() is not constant'`. wdyt @voznesenskym ?
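A hedged sketch of the `cond` rewrite suggested above, using the operator's current `torch.cond` entry point (it was still a prototype under functorch when this issue was filed, so the exact module path and constraints depend on the PyTorch version):

```python
import torch


class MyCondModule(torch.nn.Module):
    def forward(self, a):
        x = torch.add(a, 1)
        y = torch.add(x, 1)

        def true_fn(x, y):
            # Mirrors the original branch body:
            #   x += torch.add(x, 1); y += torch.add(x, 1)
            x2 = x + torch.add(x, 1)
            return x2, y + torch.add(x2, 1)

        def false_fn(x, y):
            # Both branches must return matching structures, and outputs
            # may not alias inputs, hence the clones.
            return x.clone(), y.clone()

        # The predicate stays a tensor, so tracing never has to resolve it
        # to a Python bool and no graph break is required.
        x, y = torch.cond(x.sum() < 0, true_fn, false_fn, (x, y))
        return x + y
```

Under compilation, both branches are traced into the graph and the selection happens at runtime, which is what keeps data-dependent control flow sound.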