Dynamo cannot optimize a model with MaxPool2d on XLA devices
🐛 Describe the bug
I found this bug while integrating Dynamo with torch_xla for a ResNet model. If the model and its inputs are moved to the XLA device before running Dynamo, we hit this bug. See the minimal repro below.
cc @jansel @wconstab @jackcaog
Error logs
  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
  File "/pytorch/torch/_refs/__init__.py", line 45, in <module>
    from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
  File "/pytorch/torch/fx/experimental/symbolic_shapes.py", line 17, in <module>
    import sympy  # type: ignore[import]
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/__init__.py", line 51, in <module>
    from .core import (sympify, SympifyError, cacheit, Basic, Atom,
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/__init__.py", line 4, in <module>
    from .sympify import sympify, SympifyError
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/sympify.py", line 9, in <module>
    from .compatibility import iterable
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/compatibility.py", line 11, in <module>
    from sympy.external import import_module
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/__init__.py", line 18, in <module>
    from sympy.external.importtools import import_module
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/importtools.py", line 4, in <module>
    from distutils.version import LooseVersion
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 963, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 906, in _find_spec
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
    r = func(*args, **kwargs)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
    frame,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
    out_code = transform_code_object(code, transform)
  File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
    tracer.run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
    super().run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
    and self.step()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
    return inner_fn(self, inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
    **options,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
    r = func(*args, **kwargs)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
    frame,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
    out_code = transform_code_object(code, transform)
  File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
    tracer.run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
    super().run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
    and self.step()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
    return inner_fn(self, inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
    **options,
  File "/pytorch/torch/_dynamo/variables/tensor.py", line 201, in create
    example_value = _get_fake_value(proxy.node, tx)
  File "/pytorch/torch/_dynamo/variables/tensor.py", line 145, in _get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:
from user code:
  File "myscripts/repro_maxpool.py", line 14, in forward
    out = self.pool(out)
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting: torch._dynamo.config.suppress_errors = True
Minified repro
repro_maxpool.py
from torch import nn
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

class MaxPoolModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.pool(out)  # Dynamo fails while tracing this call
        return out

    def get_random_inputs(self):
        return (torch.rand(2, 3, 10, 10),)

# Move both the model and its inputs to the XLA device *before* running Dynamo.
xla_dev = xm.xla_device()
model = MaxPoolModule().to(device=xla_dev)
# Materialize as a tuple (a map object is a one-shot iterator).
inputs = tuple(x.to(device=xla_dev) for x in model.get_random_inputs())
# Trivial backend: return the captured FX GraphModule unchanged.
dynamo.optimize(lambda gm, _: gm)(lambda: model(*inputs))()
Command:
GPU_NUM_DEVICES=1 python repro_maxpool.py
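For reference, the shapes in the repro can be worked out by hand. The failing frame is `_get_fake_value`, which (as its name suggests) computes a fake example value for the `self.pool(out)` node, so it helps to know what that value's shape should be. The helper `pool_out_size` below is a hypothetical name, not a PyTorch API; it applies the floor formula from the `nn.Conv2d` / `nn.MaxPool2d` docs and assumes `ceil_mode=False`.

```python
def pool_out_size(size: int, kernel_size: int, stride: int,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dim for Conv2d/MaxPool2d (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Repro: input (2, 3, 10, 10) -> Conv2d(3, 6, k=3, s=2) -> MaxPool2d(k=3, s=2)
h_conv = pool_out_size(10, kernel_size=3, stride=2)      # 4
h_pool = pool_out_size(h_conv, kernel_size=3, stride=2)  # 1
expected_shape = (2, 6, h_pool, h_pool)                  # (2, 6, 1, 1)
```

So the pooled output should be a `(2, 6, 1, 1)` tensor; the bug is in how Dynamo obtains that value on XLA tensors, not in the shapes themselves.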
Top GitHub Comments
Training for resnet18 now works with the aot_eager or aot_torchxla_trivial backend if MaxPool is replaced with AvgPool.
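The MaxPool-to-AvgPool workaround described in this comment could be applied with a small helper like the one below. This is a sketch, not code from the thread: `replace_maxpool_with_avgpool` is a hypothetical name, and it assumes every swapped pool uses the default `dilation=1` and `return_indices=False`, since `nn.AvgPool2d` supports neither. Average pooling also changes the model's numerics, so this only unblocks tracing; it is not equivalent to the original model.

```python
import torch
from torch import nn


def replace_maxpool_with_avgpool(module: nn.Module) -> nn.Module:
    """Recursively swap each nn.MaxPool2d child for an nn.AvgPool2d with the
    same kernel_size, stride, padding and ceil_mode. Modifies `module` in
    place and also returns it for convenience."""
    for name, child in module.named_children():
        if isinstance(child, nn.MaxPool2d):
            # AvgPool2d has no dilation and cannot return indices,
            # so refuse to swap pools that rely on either feature.
            if child.dilation not in (1, (1, 1)) or child.return_indices:
                raise ValueError(f"cannot swap {name}: dilation/return_indices unsupported")
            setattr(
                module,
                name,
                nn.AvgPool2d(
                    kernel_size=child.kernel_size,
                    stride=child.stride,
                    padding=child.padding,
                    ceil_mode=child.ceil_mode,
                ),
            )
        else:
            # Recurse into containers such as nn.Sequential or ResNet blocks.
            replace_maxpool_with_avgpool(child)
    return module
```

For example, calling `replace_maxpool_with_avgpool(model)` on the `MaxPoolModule` from the repro before `.to(device=xla_dev)` replaces `model.pool` while leaving `model.conv` untouched.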
The aot_torchxla_trace_once backend still has some problems, but they are not related to this issue. Overall, replacing MaxPool with AvgPool can temporarily unblock the project until the real fix lands.

Oh, I'm not at my laptop, but this should be fixable.
Instead of redispatching directly to the Python key, we need to make sure that we hit any other dispatch keys that need to run first, like functionalization.
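To make the ordering point concrete, here is a toy model of key-ordered dispatch in plain Python. It is not PyTorch's dispatcher: the three key names mirror real dispatch keys, but the priority list is a deliberately simplified assumption, and `run_from`, `correct_redispatch`, and `direct_to_python` are hypothetical helpers. It only illustrates why jumping straight to the Python key skips a higher-priority key such as Functionalize.

```python
from typing import List

# Toy dispatch keys, highest priority first (simplified; not PyTorch's real table).
KEYS: List[str] = ["Functionalize", "Python", "XLA"]


def run_from(start_key: str, active: List[str]) -> List[str]:
    """Return, in priority order, the active keys at or below `start_key`,
    i.e. the handlers that would run if dispatch started there."""
    start = KEYS.index(start_key)
    return [k for k in KEYS[start:] if k in active]


def correct_redispatch(active: List[str]) -> List[str]:
    # Start from the top of the priority list, so Functionalize runs first.
    return run_from(KEYS[0], active)


def direct_to_python(active: List[str]) -> List[str]:
    # The problematic pattern: jump straight to the Python key,
    # silently skipping every key ranked above it.
    return run_from("Python", active)
```

With all three keys active, `correct_redispatch` visits Functionalize, Python, then XLA, while `direct_to_python` never reaches Functionalize.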