[jax2tf] NotImplementedError: Call to gather cannot be converted with enable_xla=False
I see that there is only partial support for the `XlaGather` op conversion when using `enable_xla=False` (needed in my case because I want to convert the saved model to TFLite/TF.js). I’m wondering whether broader support for this op is on the roadmap?
Here’s the full error that I’m seeing:
```
NotImplementedError                       Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in _gather_using_tf_gather(operand, start_indices, dimension_numbers, slice_sizes, _in_avals)
    568     raise _xla_disabled_error(
    569         "gather",
--> 570         f"unsupported dimension_numbers '{dimension_numbers}'; op_shape={op_shape}."
    571     )
    572   # We added a trailing dimension of size 1

NotImplementedError: Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))'; op_shape=(1, 30, 512).
```
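A minimal sketch for narrowing down which operation triggers this gather. It assumes (not confirmed here) that the gather comes from integer-array indexing on the first two axes of the (1, 30, 512) text hidden states, for example picking one token position per batch row. The names `hidden`, `input_ids`, and `pooled` are hypothetical stand-ins, not the actual CLIP internals. Printing the jaxpr shows the `gather` equation and its dimension numbers, which can be compared with the ones in the error; the exact fields printed depend on the JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins shaped like the failing operand: (batch, seq_len, hidden).
hidden = jnp.zeros((1, 30, 512), dtype=jnp.float32)
input_ids = jnp.ones((1, 30), dtype=jnp.int32)

def pooled(hidden, input_ids):
    # Assumed pattern: choose one token position per batch row
    # (e.g. the position of the largest token id) via integer-array indexing.
    positions = jnp.argmax(input_ids, axis=-1)             # shape (batch,)
    return hidden[jnp.arange(hidden.shape[0]), positions]  # shape (batch, hidden)

jaxpr = jax.make_jaxpr(pooled)(hidden, input_ids)
print(jaxpr)  # look for the `gather` equation and its dimension_numbers
```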
Below is some code that reproduces this error. I’ve only just started playing around with `jax2tf`, so apologies if I’m doing something silly here.
```python
!pip install --upgrade flax transformers

import jax
from jax.experimental import jax2tf
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
from transformers import FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def score(pixel_values, input_ids, attention_mask):
    # Resize the (3, H, W) image to the 224x224 resolution CLIP expects.
    pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
    # Add a batch dimension to the image; the text inputs are already batched.
    inputs = {"pixel_values": jnp.array([pixel_values]), "input_ids": input_ids, "attention_mask": attention_mask}
    outputs = clip(**inputs)
    # Similarity score for the single image/text pair.
    return outputs.logits_per_image[0][0]

# enable_xla=False is needed for the later TFLite/TF.js conversion.
score_tf = jax2tf.convert(score, enable_xla=False)

my_model = tf.Module()
my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
    tf.TensorSpec([3, 40, 40], tf.float32),  # pixel_values
    tf.TensorSpec([1, 30], tf.int32),        # input_ids
    tf.TensorSpec([1, 30], tf.int32),        # attention_mask
])

model_name = 'pixel_text_score'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```
Here’s a public colab with that code: https://colab.research.google.com/drive/18jFruauFcKEJ_SjBZqn5z6AiPqQAfa1j?usp=sharing
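If the failing gather does come from per-row position selection of that kind (an assumption, not confirmed here), one possible rewrite avoids `gather` entirely: build a one-hot mask over the sequence axis and contract it with `einsum`, which lowers to `dot_general` instead. `select_positions_onehot` is a hypothetical helper; using it in this model would mean patching the corresponding indexing inside the Flax CLIP code, and whether the rest of the model then converts with `enable_xla=False` has not been checked.

```python
import jax
import jax.numpy as jnp

def select_positions_onehot(hidden, positions):
    """Hypothetical helper: return hidden[b, positions[b], :] for each row b, without a gather.

    hidden:    (batch, seq_len, hidden_dim) float array
    positions: (batch,) integer array of token positions
    """
    seq_len = hidden.shape[1]
    # (batch, seq_len) mask with a single 1 at each selected position.
    onehot = (positions.reshape(-1, 1) == jnp.arange(seq_len).reshape(1, -1)).astype(hidden.dtype)
    # Contract over seq_len; HIGHEST precision keeps the selection exact on
    # accelerators that would otherwise use reduced-precision matmuls.
    return jnp.einsum("bs,bsh->bh", onehot, hidden, precision=jax.lax.Precision.HIGHEST)

# Usage: compare against ordinary integer-array indexing.
key = jax.random.PRNGKey(0)
hidden = jax.random.normal(key, (2, 30, 512))
positions = jnp.array([3, 29], dtype=jnp.int32)
expected = hidden[jnp.arange(2), positions]
result = select_positions_onehot(hidden, positions)
print(jnp.allclose(result, expected))
```

The one-hot form does batch × seq_len × hidden_dim extra multiply-adds, which is small at a sequence length of 30, in exchange for using only elementwise ops and a matmul.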
Thanks!
Issue created 2 years ago; 8 comments (2 by maintainers).
@marcvanzee Is there a way to describe what uses of XlaGather are supported? We could then add this to an FAQ and include a link in the error message.
Thanks for the feedback! Please do let us know when you manage to get any models running in the browser using jax2tf; we will gladly link to your examples from our README. The project is still in an experimental phase, so we can learn a lot from early users like you. Thank you in return for experimenting with this! 😄