[jax2tf] NotImplementedError: Call to gather cannot be converted with enable_xla=False

I see that the XlaGather op conversion is only partially supported when using enable_xla=False (which I need because I want to convert the saved model to tflite/tfjs). Is broader support for this op on the roadmap?

Here’s the full error that I’m seeing:

NotImplementedError                       Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in _gather_using_tf_gather(operand, start_indices, dimension_numbers, slice_sizes, _in_avals)
    568     raise _xla_disabled_error(
    569         "gather",
--> 570         f"unsupported dimension_numbers '{dimension_numbers}'; op_shape={op_shape}."
    571     )
    572   # We added a trailing dimension of size 1

NotImplementedError: Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))'; op_shape=(1, 30, 512).
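For readers unfamiliar with gather dimension numbers, here is a minimal pure-Python sketch of what I understand this particular gather to compute. It is an illustration of the semantics only, not jax2tf's implementation: `gather_rows` and the toy shapes are made up for this example, and the real operand has shape (1, 30, 512). With collapsed_slice_dims=(0, 1) and start_index_map=(0, 1), each index pair (i, j) selects the full vector operand[i, j, :], and offset_dims=(1,) places that vector in output dimension 1. This pattern typically comes from NumPy-style advanced indexing such as x[rows, cols] on a rank-3 array.

```python
# Pure-Python sketch (hypothetical helper, not part of JAX or jax2tf) of a gather with
#   offset_dims=(1,), collapsed_slice_dims=(0, 1), start_index_map=(0, 1)
# applied to a rank-3 operand stored as nested lists.

def gather_rows(operand, start_indices):
    """For each (i, j) pair, return the full slice operand[i][j][:].

    Dimensions 0 and 1 are indexed and collapsed; the remaining
    dimension 2 becomes dimension 1 of the output.
    """
    return [list(operand[i][j]) for i, j in start_indices]

# Toy operand with shape (1, 3, 4), where operand[0][j][k] == 10 * j + k.
operand = [[[10 * j + k for k in range(4)] for j in range(3)]]

# Two index pairs, so the output has shape (2, 4).
start_indices = [(0, 2), (0, 0)]
result = gather_rows(operand, start_indices)
print(result)
```

In the failing model the equivalent of `start_indices` is computed at run time, which is presumably why the no-XLA code path cannot rewrite it as a static slice.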

Below is some code that reproduces this error. I’ve only just started playing around with jax2tf so apologies if I’m doing something silly here.

!pip install --upgrade flax transformers
import jax
from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

from transformers import FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def score(pixel_values, input_ids, attention_mask):
    pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
    inputs = {"pixel_values":jnp.array([pixel_values]), "input_ids":input_ids, "attention_mask":attention_mask}
    outputs = clip(**inputs)
    return outputs.logits_per_image[0][0]

score_tf = jax2tf.convert(score, enable_xla=False)

my_model = tf.Module()
my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
  tf.TensorSpec([3, 40, 40], tf.float32),
  tf.TensorSpec([1, 30], tf.int32),
  tf.TensorSpec([1, 30], tf.int32),
])

model_name = 'pixel_text_score'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

Here’s a public colab with that code: https://colab.research.google.com/drive/18jFruauFcKEJ_SjBZqn5z6AiPqQAfa1j?usp=sharing

Thanks!

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (2 by maintainers)

Top GitHub Comments

1 reaction
gnecula commented, Feb 22, 2022

@marcvanzee Is there a way to describe what uses of XlaGather are supported? We could then add this to an FAQ and include a link in the error message.

1 reaction
marcvanzee commented, Feb 22, 2022

Thanks for the feedback! Please do let us know when you manage to get any models running in the browser using jax2tf; we will gladly link to your examples from our README. This project is still in an experimental phase, so we learn and benefit a lot from early users like you. So thank you back for experimenting with this! 😄
