jax2tf polymorphic_shapes with structured input

I’m trying to use polymorphic shapes in jax2tf. From the docs it seems that structured inputs are supported:

It should be a Python object with the same pytree structure as, or a prefix of, the tuple of arguments to the function, but with a shape specification corresponding to each argument.

Yet when I tried to use it with a NamedTuple, I got the error: InconclusiveDimensionOperation: Dimension polynomial 'b' is not constant

Did I do something wrong? Is this supported?

from typing import NamedTuple, Tuple

import tensorflow.compat.v1 as tf
from jax.experimental import jax2tf

# tf.placeholder only works in graph mode, so disable TF2's default eager execution.
tf.disable_eager_execution()

class G(NamedTuple):
  a: Tuple[int]
  b: Tuple[int]

def fn(a):
  # Add the first row of each field of the NamedTuple.
  return a.a[0] + a.b[0]

# One spec string, used as a prefix that covers both fields of the NamedTuple.
g = jax2tf.convert(fn, polymorphic_shapes=["(b, ...)"])
# One spec per NamedTuple field.
h = jax2tf.convert(fn, polymorphic_shapes=[G(a="(b, ...)", b="(b, ...)")])
l = jax2tf.convert(fn, polymorphic_shapes=[G(a="(b, _)", b="(b, _)")])
m = jax2tf.convert(fn, polymorphic_shapes=["(b, _)"])
place = G(a=tf.placeholder(tf.int32, shape=(10, 5)), b=tf.placeholder(tf.int32, shape=(10, 5)))

g(place), h(place), ...
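The prefix rule quoted from the docs can be illustrated with JAX's pytree utilities. This is a minimal sketch, assuming a hypothetical helper `broadcast_prefix` (not part of jax2tf) and assuming that the list of specs is matched against the tuple of positional arguments. It shows how a single spec string would be copied to both NamedTuple fields, while a `G(...)` of specs assigns one spec per field; it mirrors the documented matching rule, not jax2tf's actual implementation.

```python
from typing import NamedTuple

import jax
import numpy as np


class G(NamedTuple):
    a: np.ndarray
    b: np.ndarray


def broadcast_prefix(prefix_specs, args):
    """Hypothetical helper: expand a prefix tree of shape specs into one spec per leaf of `args`."""
    def expand(spec, subtree):
        # `subtree` is the part of `args` covered by this one spec; copy the spec to all its leaves.
        return jax.tree_util.tree_map(lambda _: spec, subtree)

    # Assumption: the list of specs is treated as a tuple, matching the tuple of positional args.
    return jax.tree_util.tree_map(expand, tuple(prefix_specs), args)


x = np.zeros((10, 5), np.int32)
args = (G(a=x, b=x),)  # the tuple of positional arguments to fn

print(broadcast_prefix(["(b, ...)"], args))
print(broadcast_prefix([G(a="(b, _)", b="(b, _)")], args))
```

Under this rule, both forms in the question end up giving the same spec to each field, so the pytree structure of the input is not, by itself, what triggers the error.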

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 23

Top GitHub Comments

gnecula commented, Jul 26, 2021 (1 reaction)

I was able to reproduce your use case, and indeed PR #7317 should solve this issue.

You should be able to try this PR locally: clone the JAX sources, apply the PR, and then run pip install -e . in the JAX directory.

gnecula commented, Jul 26, 2021 (1 reaction)

I asked “why do you think your code is shape polymorphic” because I did not remember that your function is the result of jax.vmap. In that case, we aim to extend JAX to recognize that shape polymorphism. It is possible that we are not there yet, which would explain the error you are seeing.
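In newer JAX releases (roughly 0.4.30 onward, which is an assumption about the reader's version and postdates this thread), whether a vmapped function stays shape polymorphic can be checked without TensorFlow: `jax.export.symbolic_shape` creates a symbolic batch dimension and `jax.eval_shape` traces the function on it. This is a minimal sketch that reuses the question's function body on a batched NamedTuple; `per_example` and the `(3, 5)` per-example shape are illustrative choices, not taken from the thread.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import export


class G(NamedTuple):
    a: jax.Array
    b: jax.Array


def per_example(g):
    # Same body as fn in the question, applied to a single (unbatched) example.
    return g.a[0] + g.b[0]


batched = jax.vmap(per_example)

# "b" is a symbolic batch size; each example is a (3, 5) int32 array.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 3, 5"), jnp.int32)

# Trace abstractly: no real data is needed, only shapes and dtypes.
out = jax.eval_shape(batched, G(a=spec, b=spec))
print(out.shape)
```

If vmap handles the symbolic batch size, the output shape should keep `b` as its leading dimension, followed by the per-example result size of 5.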

From the stack trace, it seems that the problem is in the invocation of lax.dynamic_slice (line 888). Can you print the shapes of the arguments passed to lax.dynamic_slice? In parallel, I can try to construct a repro example.
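One way to collect those shapes without a debugger is to trace the function with `jax.make_jaxpr` and look for `dynamic_slice` equations. This is a sketch assuming a hypothetical helper `dynamic_slice_arg_shapes` (not a JAX API); it detects nested jaxprs (for example, under `jit`) by duck typing on their `eqns` attribute rather than relying on JAX's internal classes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dynamic_slice_arg_shapes(f, *args):
    """Hypothetical helper: input shapes of every dynamic_slice emitted when tracing f(*args)."""
    found = []

    def walk(jaxpr):
        for eqn in jaxpr.eqns:
            if eqn.primitive.name == "dynamic_slice":
                # Operand first, then one scalar per start index.
                found.append([tuple(v.aval.shape) for v in eqn.invars])
            # Recurse into sub-jaxprs stored in equation params (e.g. jit bodies, cond branches).
            for param in eqn.params.values():
                for sub in param if isinstance(param, (tuple, list)) else (param,):
                    if hasattr(sub, "eqns"):
                        walk(sub)
                    elif hasattr(sub, "jaxpr") and hasattr(sub.jaxpr, "eqns"):
                        walk(sub.jaxpr)

    walk(jax.make_jaxpr(f)(*args).jaxpr)
    return found


x = jnp.zeros((10, 5), jnp.int32)
shapes = dynamic_slice_arg_shapes(lambda x, i: lax.dynamic_slice(x, (i, 0), (2, 5)), x, 3)
print(shapes)
```

Applied to the function that fails under jax2tf, the same helper would list the operand and start-index shapes at each `dynamic_slice` call, which is the information requested above.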
