jax2tf polymorphic_shapes with structured input
I’m trying to use polymorphic shapes in jax2tf. From the docs it seems that structured inputs are supported:
It should be a Python object with the same pytree structure as, or a prefix of, the tuple of arguments to the function, but with a shape specification corresponding to each argument.
Yet when I tried to use it with a NamedTuple, I got the error:
InconclusiveDimensionOperation: Dimension polynomial 'b' is not constant
Did I do something wrong? Is this supported?
from jax.experimental import jax2tf
import tensorflow.compat.v1 as tf
from typing import Tuple, NamedTuple

tf.disable_eager_execution()  # tf.placeholder only works in graph mode

class G(NamedTuple):
    a: Tuple[int]
    b: Tuple[int]

def fn(a):
    return a.a[0] + a.b[0]

# One spec for the whole NamedTuple argument vs. one spec per field,
# with either "..." or "_" for the trailing dimension.
g = jax2tf.convert(fn, polymorphic_shapes=["(b, ...)"])
h = jax2tf.convert(fn, polymorphic_shapes=[G(a="(b, ...)", b="(b, ...)")])
l = jax2tf.convert(fn, polymorphic_shapes=[G(a="(b, _)", b="(b, _)")])
m = jax2tf.convert(fn, polymorphic_shapes=["(b, _)"])

place = G(a=tf.placeholder(tf.int32, shape=(10, 5)),
          b=tf.placeholder(tf.int32, shape=(10, 5)))
g(place), h(place), ...
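The docs sentence quoted in the question says the spec may be a prefix of the arguments tuple. The pure-Python sketch below is a simplified, hypothetical model of that rule, not jax2tf's implementation: ArgShape, expand_prefix, resolve_spec and resolve_all are made-up names; lists, tuples and NamedTuples are treated interchangeably (real pytree matching also compares node types); and "_" and a trailing "..." are read as "take this size from the actual argument". In this model, the four specs used for g, h, l and m all expand to the same per-leaf shapes.

```python
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple, Union

# Hypothetical stand-in for an array argument: only its shape matters here.
@dataclass(frozen=True)
class ArgShape:
    shape: Tuple[int, ...]

Dim = Union[int, str]  # a concrete size or a dimension-variable name

def is_container(x: Any) -> bool:
    # NamedTuples are tuple subclasses, so they count as containers too.
    return isinstance(x, (tuple, list))

def expand_prefix(spec: Any, tree: Any) -> List[Tuple[Optional[str], ArgShape]]:
    """Pair every leaf of `tree` with the spec string (or None) that covers it."""
    if spec is None or isinstance(spec, str):
        # A spec leaf applies to every leaf of the subtree beneath it.
        if is_container(tree):
            return [pair for sub in tree for pair in expand_prefix(spec, sub)]
        return [(spec, tree)]
    if not is_container(tree) or len(spec) != len(tree):
        raise ValueError(f"{spec!r} is not a prefix of the argument structure")
    return [pair for s, t in zip(spec, tree) for pair in expand_prefix(s, t)]

def resolve_spec(spec: Optional[str], shape: Tuple[int, ...]) -> Tuple[Dim, ...]:
    """Replace '_' and a trailing '...' with the argument's concrete sizes."""
    if spec is None:
        return tuple(shape)  # assumed here: None means a fully static shape
    parts = [p.strip() for p in spec.strip().strip("()").split(",") if p.strip()]
    if parts and parts[-1] == "...":
        # '...' stands for all remaining dimensions, each taken as '_'.
        parts = parts[:-1] + ["_"] * (len(shape) - len(parts) + 1)
    if len(parts) != len(shape):
        raise ValueError(f"{spec!r} does not match rank {len(shape)}")
    return tuple(size if p == "_" else p for p, size in zip(parts, shape))

def resolve_all(polymorphic_shapes: Any, args: tuple) -> List[Tuple[Dim, ...]]:
    pairs = expand_prefix(polymorphic_shapes, args)
    return [resolve_spec(s, leaf.shape) for s, leaf in pairs]

class G(NamedTuple):
    a: Any
    b: Any

# The single positional argument from the question: G of two (10, 5) arrays.
args = (G(a=ArgShape((10, 5)), b=ArgShape((10, 5))),)
for spec in (["(b, ...)"],
             [G(a="(b, ...)", b="(b, ...)")],
             [G(a="(b, _)", b="(b, _)")],
             ["(b, _)"]):
    print(resolve_all(spec, args))  # [('b', 5), ('b', 5)] for every spec
```

This model only covers how specs are matched to arguments; it does not model the symbolic tracing of fn, which is where a dimension error such as InconclusiveDimensionOperation would be raised.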
Top GitHub Comments
I was able to reproduce your use case, and indeed PR #7317 should solve this issue. You should be able to patch this PR locally: clone the JAX sources, apply the PR, and then run pip install -e . in the JAX directory.
I asked “why do you think your code is shape polymorphic” because I did not remember that your function is the result of jax.vmap. In that case, we aim to extend JAX to recognise that shape polymorphism. It is possible that we are not there yet, which would explain what you are encountering.
From the stack trace, it seems that the problem is in the invocation of lax.dynamic_slice (line 888). Can you get the shapes of the arguments of lax.dynamic_slice? In parallel, I can try to construct a repro example.
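One way to collect the lax.dynamic_slice argument shapes asked for above is to record them just before the call. Python code inside a traced function runs at trace time and sees abstract shapes, and inside jax.vmap the function is traced on a single example, so the batch axis is not visible there. The sketch below assumes a hypothetical per-example function take_window (not the code from the issue) and uses jax.eval_shape so that only tracing happens and no values are computed.

```python
import jax
import jax.numpy as jnp
from jax import lax

recorded = []  # (operand shape, start-index shape), captured at trace time

def take_window(x, start):
    """Hypothetical per-example function: a length-3 window of x at start."""
    # Runs while JAX traces the function, so it records abstract shapes.
    recorded.append((x.shape, start.shape))
    return lax.dynamic_slice(x, (start,), (3,))

xs = jnp.arange(20.0).reshape(4, 5)  # batch of 4 rows, each of length 5
starts = jnp.array([0, 1, 2, 1])     # one start index per row

# eval_shape traces the vmapped function without computing any values.
out = jax.eval_shape(jax.vmap(take_window), xs, starts)

print(recorded)   # [((5,), ())]: per-example shapes, vmap hides the batch axis
print(out.shape)  # (4, 3): the batch axis reappears in the output
```

The same recording would also run while jax2tf.convert traces a function, which is where the shapes relevant to this issue would be captured.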