Support partially dynamic shapes
In update 5 we wrote:
Unfortunately, the problem of dynamic shapes is more complex than one might think. Enabling `torchdynamo.config.dynamic_shapes` will cause new graph breaks. Many models have code like `assert x.shape == (1, 2, 3)`, `if x.size(1) == 10`, `math.sqrt(x.shape[-1])`, etc. This Python code operating on integer shapes is the de facto way to express many things in PyTorch. With static shapes, TorchDynamo can constant-propagate this away; with dynamic shapes, however, it will break the graph.
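As a concrete illustration (not from any particular model), here is the kind of shape-dependent Python that is constant-propagated away under static shapes but would force graph breaks once shapes are dynamic:

```python
import math
import torch

def forward(x: torch.Tensor):
    # Plain-Python logic on integer shapes: with static shapes TorchDynamo can
    # constant-fold all of this, but with dynamic shapes each line forces a
    # graph break because the shape values escape into ordinary Python.
    assert x.shape[0] == 1
    if x.size(1) == 10:
        return x * math.sqrt(x.shape[-1])
    return x
```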
My current thinking is a “partially specialized shapes” mode in TorchDynamo. The basic idea would be that all shapes start as fully dynamic, but then TorchDynamo would convert a tensor’s shapes to static when the user calls `Tensor.size()` and passes the result to a non-PyTorch operation. This would allow dynamic shapes most of the time, but still allow bigger graphs when users operate directly on shapes as integers.
To implement an initial version of this:
First, build the analysis to add a `TensorVariable().input_sources: Set[Source]`.
```python
def foo(a, b):
    c = a + b
```
In this example:
- `a.input_sources = {a.source}`
- `b.input_sources = {b.source}`
- `c.input_sources = {a.source, b.source}`
This is just a straightforward data-flow analysis where sources are combined. It looks similar to the shape propagation currently implemented in `TensorVariable.create`.
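A minimal sketch of that propagation rule (class and function names here are hypothetical stand-ins, not the actual TorchDynamo code):

```python
from typing import Optional, Set

class TensorVar:  # stand-in for TensorVariable
    def __init__(self, source=None, input_sources: Optional[Set] = None):
        self.source = source
        # Inputs depend only on themselves; computed values carry the union of
        # the sources of everything they were derived from.
        self.input_sources = set(input_sources or ())
        if source is not None:
            self.input_sources.add(source)

def combine(*args: TensorVar) -> TensorVar:
    # c = a + b  =>  c.input_sources == a.input_sources | b.input_sources
    out: Set = set()
    for arg in args:
        out |= arg.input_sources
    return TensorVar(input_sources=out)
```

For the `foo` example above, `combine(a, b)` produces a variable whose `input_sources` is `{a.source, b.source}`, which is exactly what guard insertion needs later.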
Next, split `GuardBuilder.TENSOR_MATCH` into `TENSOR_MATCH_STATIC` and `TENSOR_MATCH_DYNAMIC`. The underlying `TensorGuards` object implemented in C++ already has these two modes, so it just requires the generated code to create two instances of that object.
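Conceptually, guard codegen would then hold one `TensorGuards` instance per mode. The sketch below is illustrative only; the constructor and `check` signatures are assumptions, not the real C++ API:

```python
def build_guard_fn(TensorGuards, static_examples, dynamic_examples):
    # Checks dtype/device/requires_grad *and* exact sizes/strides (TENSOR_MATCH_STATIC).
    static_guards = TensorGuards(*static_examples, dynamic_shapes=False)
    # Skips the size/stride equality checks so shapes may vary (TENSOR_MATCH_DYNAMIC).
    dynamic_guards = TensorGuards(*dynamic_examples, dynamic_shapes=True)

    def guard_fn(static_args, dynamic_args):
        return static_guards.check(*static_args) and dynamic_guards.check(*dynamic_args)

    return guard_fn
```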
Finally, modify how `TensorVariable` handles shape specialization. Defer setting `TensorVariable().size` and `TensorVariable().stride` until the user calls `Tensor.size()`. Note there are a few different ways to get the size, so search for usages of `TensorVariable.size`.
When .size is called, add a new guard for TENSOR_MATCH_STATIC on all the input_sources. (You can remove the now redundant TENSOR_MATCH_DYNAMIC guard in guard codegen.)
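A rough sketch of the deferred specialization, with hypothetical names and a callback standing in for the guard builder:

```python
class LazyShapeTensorVariable:
    """Sketch: sizes stay unset until user code actually observes them."""

    def __init__(self, example, input_sources, request_static_guard):
        self.example = example              # example tensor seen at trace time
        self.input_sources = input_sources  # from the data-flow analysis above
        self.request_static_guard = request_static_guard  # guard-builder callback
        self.size = None                    # deferred: not recorded eagerly

    def call_size(self, dim=None):
        if self.size is None:
            self.size = tuple(self.example.size())
            # The integer shape escapes to Python, so every input this value
            # depends on must now be guarded with TENSOR_MATCH_STATIC.
            for source in self.input_sources:
                self.request_static_guard(source)
        return self.size if dim is None else self.size[dim]
```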
This should give you something that works and passes tests.
Improvements to the initial prototype:
- We need to handle dynamic-shape ops like `nonzero`, `where`, `repeat`, etc. Modify the analysis to mark tensors flowing from these ops, and break the graph if the user calls `size()` on them. You can search for `config.dynamic_shapes` to find where we currently conditionally break the graph on those ops.
- If a user passes the size directly to another PyTorch op, for example `torch.empty(x.size())`, we don’t need to shape-specialize and can just put the call to `.size()` in the graph. Similarly, simple math ops on sizes can be included in the graph. To handle this we will need a `SizeVariable()` to track sizes and decide what can go in the graph and what requires specialization (see the sketch after this list).
- We don’t need to specialize every dimension if the user code only uses some dimensions. We need better shape analysis to make this happen, though. @eellison might be able to provide pointers for better shape analysis.
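For the second bullet, a `SizeVariable` could look roughly like this; the method names and the `specialize_inputs()` hook are hypothetical:

```python
class SizeVariable:
    """Sketch: tracks a Tensor.size() result and decides graph vs. specialization."""

    def __init__(self, size_proxy, parent):
        self.size_proxy = size_proxy  # graph node that produced the size
        self.parent = parent          # TensorVariable the size came from

    def as_graph_arg(self):
        # e.g. torch.empty(x.size()): the size only feeds another PyTorch op,
        # so keep the .size() call in the graph and avoid specializing.
        return self.size_proxy

    def as_python_constant(self, example_size):
        # e.g. `if x.size(1) == 10`: the value escapes into Python, so fall back
        # to static specialization of everything the parent tensor depends on.
        self.parent.specialize_inputs()  # hypothetical hook into guard insertion
        return tuple(example_size)
```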
cc @ezyang
I like that idea, and it is pretty close to what I suggested, with the addition of a backoff-to-graph-breaking heuristic on the N-th recompile (where N=2).
Bailouts wouldn’t be that hard at the TorchDynamo level; however, compiler backends don’t support them. Bailouts would also help in the case of data-dependent control flow and remove the need for a graph break in those cases, tracing-JIT style. @csarofeen, any interest in adding bailout support to nvFuser?
The ordering of the decisions can also be flipped around, giving something like:
- A dimension value never affects Python execution -> no need to specialize
- A dimension value does affect Python execution:
  - It is observed as static -> specialize
  - It is observed as changing -> graph break
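In pseudocode, that flipped ordering would be roughly (purely illustrative):

```python
def shape_policy(dim_affects_python: bool, observed_values: set) -> str:
    if not dim_affects_python:
        return "keep dynamic"   # the dimension never reaches Python logic
    if len(observed_values) <= 1:
        return "specialize"     # observed as static so far
    return "graph break"        # observed changing across recompiles
```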
However, primitive decomposition will add additional places where we have to make a size-based decision in rare cases. Also, if we know a static size, compilers have to do less work when further specializing for sizes (e.g. when tuning a matmul). So there is some value in having “assume static until proven otherwise” go first.