Incorrect dtypes error with the Mega model
Hey, I'm seeing the following error when passing the --mega option to use the mega model:
(min-dalle) ➜ min-dalle git:(main) ✗ python image_from_text.py --text="a comfy chair that looks like an avocado" --mega
/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lib/__init__.py:34: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Namespace(image_path='generated', image_token_count=256, mega=True, seed=0, text='a comfy chair that looks like an avocado', torch=False)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
File "image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
_, image_tokens = lax.scan(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1159, in apply
return apply(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
y = fn(root, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1535, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper
y, out_variable_groups_xs_t = fn(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 770, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 754, in scanned
c, y = fn(scope, c, *args)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 307, in core_fn
res = fn(cloned, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
return dynamic_update_slice_p.bind(operand, update, *start_indices)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 323, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 326, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
out_aval, effects = primitive.abstract_eval(*avals, **params)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 359, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
dtype_rule(*avals, **kwargs), weak_type=weak_type,
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "image_from_text.py", line 44, in <module>
image = generate_image_from_text(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
image_tokens[...] = generate_image_tokens_flax(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
image_tokens = decode_flax(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
image_tokens = decoder.sample_image_tokens(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
_, image_tokens = lax.scan(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
logits, keys_state, values_state = self.apply(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
decoder_state, (keys_state, values_state) = self.layers(
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
decoder_state, keys_values_state = self.self_attn(
File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
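The final error can be reproduced in isolation. The sketch below is an illustration, not the min-dalle code: the array shapes, the names `cache` and `new_keys`, and the helper `update_cache` are all made up for the example, and it assumes the state being updated is float32 while the incoming slice is float16 (the order the message reports). It shows that `lax.dynamic_update_slice` rejects mixed dtypes and that casting the update to the operand's dtype avoids the error.

```python
import jax.numpy as jnp
from jax import lax


def update_cache(cache, update, index):
    """Write `update` into `cache` starting at row `index` (hypothetical helper)."""
    # lax primitives do not promote dtypes implicitly, so cast the update
    # to the cache's dtype before slicing it in.
    update = update.astype(cache.dtype)
    return lax.dynamic_update_slice(cache, update, (index, 0))


# Assumed setup: a float32 state buffer and a float16 slice to write into it.
cache = jnp.zeros((4, 8), dtype=jnp.float32)
new_keys = jnp.ones((1, 8), dtype=jnp.float16)

try:
    lax.dynamic_update_slice(cache, new_keys, (1, 0))
    mismatch_raised = False
except TypeError as err:
    mismatch_raised = True
    print(err)

updated = update_cache(cache, new_keys, 1)
print(updated.dtype)  # float32
```

Whether the cast belongs on the update or on the state depends on which precision the model is meant to run in, which the traceback alone does not settle.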
Top GitHub Comments
Changing requirements.txt to have flax==0.4.2 seems to fix this issue.
Edit: Running on an M1 Pro in a Python 3.9.13 venv.
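To confirm which flax version a virtualenv actually picked up after editing requirements.txt, a check along these lines may help. It is a minimal sketch using only the standard library; the helper names `installed_version` and `matches_pin` are made up for illustration, and 0.4.2 is simply the version reported to work in the comment above.

```python
from importlib import metadata

REQUIRED_FLAX = "0.4.2"  # version reported to work in the comment above


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_pin(found, required=REQUIRED_FLAX):
    """True only when a version was found and it equals the pinned one."""
    return found is not None and found == required


found = installed_version("flax")
print(f"flax installed: {found}, pinned: {REQUIRED_FLAX}, ok: {matches_pin(found)}")
```

Run it inside the same environment that runs image_from_text.py, since a different interpreter may see a different set of packages.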
Just a note, I had the same issue, and I resolved it by doing the following
Then generate the image using torch
–
Edit: toabi's solution above also works for me without the --torch flag, though it is very, very slow for me, as it does not seem to be able to find my GPU without the --torch flag.
Ahhh, I was able to get the GPU working with jax by following their README (https://github.com/google/jax#pip-installation-gpu-cuda). Turns out I did not have cuDNN installed. Though I then ran into the error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 201326592 bytes.
Guessing my computer is not beefy enough or such.
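For scale, the failed allocation is fairly small, so the GPU may already have been mostly full when it was requested. The sketch below converts the byte count from the error and sets `XLA_PYTHON_CLIENT_PREALLOCATE`, an environment variable JAX documents for turning off its up-front GPU memory preallocation. Whether that is enough for the mega model on this machine is an assumption that has not been tested here, and `bytes_to_mib` is a made-up helper name.

```python
import os

# Must be set before jax is imported for it to take effect.
# setdefault leaves any value the user has already exported untouched.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


def bytes_to_mib(n_bytes):
    """Convert a byte count to mebibytes (1 MiB = 1024 * 1024 bytes)."""
    return n_bytes / (1024 ** 2)


failed_allocation = 201326592  # taken from the RESOURCE_EXHAUSTED message
print(f"{bytes_to_mib(failed_allocation):.0f} MiB")  # 192 MiB
```

If the error persists with preallocation disabled, the model likely does not fit in the available GPU memory, which matches the guess above.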