--mega / is_mega raises TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
I tried running the following in Google Colab:
image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)
This caused an exception:
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
loading flax decoder
sampling image tokens
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-5-53d46ed9885c> in <module>()
2
----> 3 image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)
4 display(image)
UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
/content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
38 keys_state,
39 self.k_proj(decoder_state).reshape(shape_split),
---> 40 state_index
41 )
42 values_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
The same thing happened when I ran the command-line script locally:
python3 image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100
NOTE: I had to add the following line to the Setup cell of the notebook:
! wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
Issue Analytics
- State:
- Created a year ago
- Reactions:2
- Comments:5 (2 by maintainers)
Top GitHub Comments
It should work if you run:
pip install flax==0.4.2
I need to address what is causing the dtype mismatch in the latest flax version.

Ok, it should work with the latest flax version now.
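For context on the error itself: `lax.dynamic_update_slice` does not promote dtypes, so the operand and the update must already match. The traceback shows the operand is `keys_state` and the update is the `k_proj` output, and the order in the message ("got float32, float16") suggests a float32 cache being written with float16 projections from the fp16 mega checkpoint. The sketch below reproduces that situation and shows two generic ways to avoid it. It is an illustration under those assumptions, not min-dalle's actual code or the fix the maintainer applied; the shapes and the helpers `update_cache_mismatched`, `update_cache_cast` and `init_cache` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical cache layout; the real min-dalle shapes are not shown in the issue.
batch, max_len, heads, head_dim = 2, 4, 2, 8

# Assumed: the key cache is allocated as float32 ...
keys_state = jnp.zeros((batch, max_len, heads, head_dim), dtype=jnp.float32)
# ... while the fp16 checkpoint produces float16 projections for one new token.
new_keys = jnp.ones((batch, 1, heads, head_dim), dtype=jnp.float16)
state_index = (0, 1, 0, 0)  # write the new token into sequence slot 1


def update_cache_mismatched(cache, update, index):
    # Reproduces the bug: operand and update dtypes differ, so JAX raises TypeError.
    return lax.dynamic_update_slice(cache, update, index)


def update_cache_cast(cache, update, index):
    # Option A: cast the update to the cache's dtype before writing it.
    return lax.dynamic_update_slice(cache, update.astype(cache.dtype), index)


def init_cache(shape, compute_dtype):
    # Option B: allocate the cache in the same dtype the projections produce.
    return jnp.zeros(shape, dtype=compute_dtype)


try:
    update_cache_mismatched(keys_state, new_keys, state_index)
except TypeError as err:
    print("TypeError:", err)

fixed = update_cache_cast(keys_state, new_keys, state_index)
print(fixed.dtype, float(fixed[:, 1].sum()))

fp16_cache = init_cache(keys_state.shape, new_keys.dtype)
fixed16 = lax.dynamic_update_slice(fp16_cache, new_keys, state_index)
print(fixed16.dtype)
```

Option A keeps the cache in float32, which doubles its memory compared with float16; option B keeps everything in the checkpoint's precision. The thread does not say which approach the later flax-compatible fix used.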