[jax2tf] NotImplementedError: Call to scatter add cannot be converted with enable_xla=False


Conversion of XlaScatter is currently unsupported when using enable_xla=False. Could support for it be added?

Here’s the full error that I’m seeing:

NotImplementedError                       Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in op(*arg, **kwargs)
     52 
     53   def op(*arg, **kwargs):
---> 54     raise _xla_disabled_error(name)
     55 
     56   return op

NotImplementedError: Call to scatter add cannot be converted with enable_xla=False.
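The error can be surprising because the repro never calls a scatter explicitly. Here is a minimal sketch of where one can come from, assuming only `jax`. JAX differentiates integer-array indexing (a gather) by transposing it into a scatter-add, so a forward pass can be scatter-free while `jax.grad(...)` of it is not. The function `lookup_sum` and the toy table are hypothetical and are not part of the original report.

```python
import jax
import jax.numpy as jnp


def lookup_sum(table, ids):
    # Integer-array indexing such as table[ids] lowers to a gather.
    return table[ids].sum()


table = jnp.ones((10, 4), dtype=jnp.float32)  # hypothetical 10-row table
ids = jnp.array([1, 3, 3], dtype=jnp.int32)   # row 3 is read twice

forward_text = str(jax.make_jaxpr(lookup_sum)(table, ids))
backward_text = str(jax.make_jaxpr(jax.grad(lookup_sum))(table, ids))

# The forward pass only gathers; the gradient accumulates cotangents back into
# the rows that were read, which JAX expresses as a scatter-add.
print("forward has scatter-add: ", "scatter-add" in forward_text)
print("gradient has scatter-add:", "scatter-add" in backward_text)

grads = jax.grad(lookup_sum)(table, ids)
print(grads[:, 0])  # row 1 -> 1.0, row 3 -> 2.0, every other row -> 0.0
```

Inspecting `jax.make_jaxpr(jax.grad(f))` this way is a quick check for which primitives the converter will actually have to handle.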

Here is some code that reproduces this error:

!pip install --upgrade flax
!pip install git+https://github.com/josephrocca/transformers.git@patch-2
import jax
from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

from transformers import FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

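def score(pixel_values, input_ids, attention_mask):
    # Upscale the 40x40 input to the 224x224 resolution CLIP expects.
    pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
    inputs = {"pixel_values":jnp.array([pixel_values]), "input_ids":input_ids, "attention_mask":attention_mask}
    outputs = clip(**inputs)
    return outputs.logits_per_image[0][0][0]

# jax.grad differentiates w.r.t. the first argument (pixel_values).
# enable_xla=False asks jax2tf to avoid XLA-specific TF ops; this is where
# the scatter-add conversion fails.
score_tf = jax2tf.convert(jax.grad(score), enable_xla=False)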

my_model = tf.Module()
my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
  tf.TensorSpec([3, 40, 40], tf.float32),
  tf.TensorSpec([1, 30], tf.int32),
  tf.TensorSpec([1, 30], tf.int32),
])

model_name = 'pixel_text_score_grad'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

Here’s a public Colab with that code (click “Runtime > Run all” to see the error): https://colab.research.google.com/drive/1HjpRXsa8Ue9KWiKbVVWUX6DlXoIYx2r8?usp=sharing

Thanks!
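In the repro above, `jax.grad(score)` differentiates with respect to `pixel_values`, and one plausible (unverified) source of the scatter-add is the `"nearest"` resize: as far as I can tell, current JAX implements `jax.image.resize(..., "nearest")` with integer-array indexing, whereas `"linear"` uses a weighted contraction. The sketch below checks only that resize step on toy shapes (a 3×4×4 image upsampled to 3×8×8 rather than 40 to 224). It does not show that the full CLIP model becomes scatter-free, and switching methods changes the preprocessing, so the scores would change too. `resized_sum` and `grad_uses_scatter_add` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

pixels = jnp.ones((3, 4, 4), dtype=jnp.float32)  # toy channels-first image


def resized_sum(image, method):
    # Upsample the two spatial dimensions, as score() does with (3, 224, 224).
    return jax.image.resize(image, (3, 8, 8), method).sum()


def grad_uses_scatter_add(method):
    # Trace the gradient w.r.t. the image and search the jaxpr text.
    grad_fn = jax.grad(lambda image: resized_sum(image, method))
    return "scatter-add" in str(jax.make_jaxpr(grad_fn)(pixels))


for method in ("nearest", "linear"):
    print(method, "-> scatter-add in gradient:", grad_uses_scatter_add(method))

# With "nearest" and an exact 2x upsample, every input pixel is copied to a
# 2x2 block of outputs, so each gradient entry of the sum is 4.0.
nearest_grad = jax.grad(lambda image: resized_sum(image, "nearest"))(pixels)
```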

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 35 (24 by maintainers)

Top GitHub Comments

2 reactions
gnecula commented, May 17, 2022

One current advantage of the jax2tf + TFLite converter path is that it is much more extensively used and tested than tflite.experimental_from_jax. jax2tf sees significant use for serving, and the TFLite converter is also used a lot. The corner cases where this path does not work are likely a subset of those where experimental_from_jax does not work either, because in both cases the difficulty lies in covering all the corner cases of complex operations such as gather, scatter, and convolution.

In the long term, it is quite possible that the direct HLO-to-TFLite path will evolve into the simpler one.
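Since gather and scatter are named here as the hard cases, one general workaround (a standard trick, not something proposed in this thread) is to keep them out of the model itself, for example by writing an embedding lookup as a one-hot matrix multiply, whose gradient is another matmul rather than a scatter-add. The sketch below assumes a hypothetical 10×4 table and compares the two formulations. The one-hot version costs O(len(ids) × vocabulary size) extra memory and compute, so it is usually only reasonable for small vocabularies, and whether it helps a given conversion depends on where the gathers and scatters actually come from.

```python
import jax
import jax.numpy as jnp

VOCAB, DIM = 10, 4  # hypothetical embedding-table size


def embed_by_index(table, ids):
    # Gather-based lookup: its gradient w.r.t. table is a scatter-add.
    return table[ids].sum()


def embed_by_one_hot(table, ids):
    # One-hot lookup: a (len(ids), VOCAB) x (VOCAB, DIM) matmul, so the
    # gradient w.r.t. table is another matmul instead of a scatter-add.
    one_hot = jax.nn.one_hot(ids, VOCAB, dtype=table.dtype)
    return (one_hot @ table).sum()


table = jnp.arange(VOCAB * DIM, dtype=jnp.float32).reshape(VOCAB, DIM)
ids = jnp.array([1, 3, 3], dtype=jnp.int32)

g_index = jax.grad(embed_by_index)(table, ids)
g_one_hot = jax.grad(embed_by_one_hot)(table, ids)
one_hot_grad_jaxpr = str(jax.make_jaxpr(jax.grad(embed_by_one_hot))(table, ids))

print("same gradients:", bool(jnp.allclose(g_index, g_one_hot)))
print("scatter-add in one-hot gradient:", "scatter-add" in one_hot_grad_jaxpr)
```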

1 reaction
marcvanzee commented, May 25, 2022

Another update: I think I was actually wrong in saying that you should always convert JAX to TFLite through from_concrete_functions. I think the main use case for going through a SavedModel is supporting shape polymorphism. This doesn’t seem to work with from_concrete_functions: the resulting signature looks quite odd, and I think it supports only a single shape. Moreover, TFLite itself recommends the SavedModel path. Our current MNIST example does not show how to go from a JAX model to a TFLite model with shape polymorphism.

I have filed another issue saying we should improve this (#10821).

