[jax2tf] NotImplementedError: Call to scatter add cannot be converted with enable_xla=False
Conversions for XlaScatter are currently unsupported when using enable_xla=False. I'm wondering if support could be added?
Here’s the full error that I’m seeing:
NotImplementedError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in op(*arg, **kwargs)
52
53 def op(*arg, **kwargs):
---> 54 raise _xla_disabled_error(name)
55
56 return op
NotImplementedError: Call to scatter add cannot be converted with enable_xla=False.
Here is some code that reproduces this error:
!pip install --upgrade flax
!pip install git+https://github.com/josephrocca/transformers.git@patch-2
import jax
from jax.experimental import jax2tf
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
from transformers import FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def score(pixel_values, input_ids, attention_mask):
    pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
    inputs = {"pixel_values": jnp.array([pixel_values]), "input_ids": input_ids, "attention_mask": attention_mask}
    outputs = clip(**inputs)
    return outputs.logits_per_image[0][0][0]

score_tf = jax2tf.convert(jax.grad(score), enable_xla=False)

my_model = tf.Module()
my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
    tf.TensorSpec([3, 40, 40], tf.float32),
    tf.TensorSpec([1, 30], tf.int32),
    tf.TensorSpec([1, 30], tf.int32),
])
model_name = 'pixel_text_score_grad'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
Here’s a public colab with that code: https://colab.research.google.com/drive/1HjpRXsa8Ue9KWiKbVVWUX6DlXoIYx2r8?usp=sharing You can click “Runtime > Run all” to see the error.
Thanks!
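For readers wondering where a scatter add comes from when the model itself has no scatter: in JAX, the gradient of an integer-array indexing (gather) op is a scatter-add. The sketch below is not from the thread, and the function and array names are illustrative. On the differentiated path of score above, a plausible but unconfirmed source is the "nearest" jax.image.resize of pixel_values, which we assume selects pixels by index.

```python
import jax
import jax.numpy as jnp

# Toy lookup: integer-array indexing lowers to a gather.
def lookup_sum(table, ids):
    return table[ids].sum()

table = jnp.arange(4.0)        # illustrative stand-in for a table / image row
ids = jnp.array([0, 2, 2])     # illustrative indices (note the repeated 2)

# Differentiating w.r.t. `table` transposes the gather into a scatter-add.
grad_fn = jax.grad(lookup_sum)
grad = grad_fn(table, ids)
jaxpr_text = str(jax.make_jaxpr(grad_fn)(table, ids))

print(grad)                          # [1. 0. 2. 0.]
print("scatter-add" in jaxpr_text)   # True
```

Repeated indices accumulate (index 2 receives a gradient of 2), which is why the transpose has to be a scatter-add rather than a plain scatter.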
Issue Analytics
- Created 2 years ago
- Comments:35 (24 by maintainers)
One advantage, for now, of the jax2tf + TFLite converter path is that it is much more extensively used and tested than tflite.experimental_from_jax. jax2tf sees significant use for serving, and the TFLite converter is also heavily used. The corner cases where this path fails are likely to be a subset of those where experimental_from_jax fails too, because in both cases the difficulty is covering all the corner cases of complex operations such as gather, scatter, and convolution. In the long term, it is quite possible that the direct HLO-to-TFLite path will become the simpler one.
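Until scatter-add is supported with enable_xla=False, one workaround worth considering (our suggestion, not something proposed in the thread) is to rewrite a small integer-indexed lookup as a one-hot matrix multiply, so that both the forward pass and its gradient use dot_general instead of gather/scatter. The table and indices below are hypothetical, and the sketch only checks the JAX side; whether a full model then converts depends on the JAX version and on the model's other ops.

```python
import jax
import jax.numpy as jnp

def gather_lookup(table, ids):
    # Integer-array indexing: a gather whose gradient is a scatter-add.
    return table[ids]

def one_hot_lookup(table, ids):
    # Same values via a dense one-hot matmul (dot_general);
    # its gradient is another matmul, with no scatter involved.
    one_hot = jax.nn.one_hot(ids, table.shape[0], dtype=table.dtype)
    return one_hot @ table

table = jnp.arange(12.0).reshape(4, 3)   # hypothetical 4-row table
ids = jnp.array([0, 2, 2])               # hypothetical indices

loss_gather = lambda t: gather_lookup(t, ids).sum()
loss_one_hot = lambda t: one_hot_lookup(t, ids).sum()

same_values = jnp.allclose(gather_lookup(table, ids), one_hot_lookup(table, ids))
same_grads = jnp.allclose(jax.grad(loss_gather)(table), jax.grad(loss_one_hot)(table))
one_hot_jaxpr = str(jax.make_jaxpr(jax.grad(loss_one_hot))(table))

print(bool(same_values), bool(same_grads))   # True True
print("scatter-add" in one_hot_jaxpr)        # False
```

The trade-off is memory and compute: the one-hot matrix has one row per index and one column per table row, so this only makes sense for small tables or a handful of lookups.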
Another update: I think I was wrong in saying that you should always convert JAX --> TFLite through from_concrete_function. I think the main use case for going through a SavedModel is support for shape polymorphism. This doesn't seem to work with from_concrete_function because the signature looks quite odd, and I think it only supports a single shape. Moreover, TFLite itself recommends the SavedModel path here! Our current MNIST example does not show how to go from a JAX model to a TFLite model with shape polymorphism. I have filed another issue saying we should improve this (#10821).