Incorrect dtypes error with the Mega model

Hey, I’m seeing the following error when passing the --mega option to use the Mega model:

(min-dalle) ➜  min-dalle git:(main) ✗ python image_from_text.py --text="a comfy chair that looks like an avocado" --mega

/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lib/__init__.py:34: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
Namespace(image_path='generated', image_token_count=256, mega=True, seed=0, text='a comfy chair that looks like an avocado', torch=False)
parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġa']
['Ġcomfy']
['Ġchair']
['Ġthat']
['Ġlooks']
['Ġlike']
['Ġan']
['Ġavocado']
text tokens [0, 58, 29872, 2408, 766, 4126, 1572, 101, 16632, 2]
loading flax encoder
encoding text tokens
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
loading flax decoder
sampling image tokens
Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1498, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1484, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 82, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1159, in apply
    return apply(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/scope.py", line 831, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 1535, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 314, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 218, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 770, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/lift.py", line 754, in scanned
    c, y = fn(scope, c, *args)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 307, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 147, in dynamic_update_slice
    return dynamic_update_slice_p.bind(operand, update, *start_indices)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 166, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 177, in default_process_primitive
    out_aval, effects = primitive.abstract_eval(*avals, **params)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/core.py", line 359, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 933, in _dynamic_update_slice_dtype_rule
    lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype,
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4373, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "image_from_text.py", line 44, in <module>
    image = generate_image_from_text(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/generate_image.py", line 66, in generate_image_from_text
    image_tokens[...] = generate_image_tokens_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 70, in generate_image_tokens_flax
    image_tokens = decode_flax(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/min_dalle_flax.py", line 49, in decode_flax
    image_tokens = decoder.sample_image_tokens(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/transforms.py", line 1246, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 352, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/linen/module.py", line 651, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 254, in sample_image_tokens
    _, image_tokens = lax.scan(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 213, in sample_next_image_token
    logits, keys_state, values_state = self.apply(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 189, in __call__
    decoder_state, (keys_state, values_state) = self.layers(
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 138, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/Users/benjamineaston/.local/share/virtualenvs/min-dalle-2jE6kTVH/lib/python3.8/site-packages/flax/core/axes_scan.py", line 114, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 95, in __call__
    decoder_state, keys_values_state = self.self_attn(
  File "/Users/benjamineaston/LocalDocuments/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py", line 37, in __call__
    keys_state = lax.dynamic_update_slice(
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
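The final frame shows the problem. lax.dynamic_update_slice does not promote dtypes the way most jax.numpy operations do, and the message ("got float32, float16") suggests that a float16 update was being written into a float32 keys_state cache. The sketch below reproduces that mismatch in isolation and shows one possible workaround: cast the update to the cache's dtype before writing it. The helper write_into_cache, the names cache and new_keys, and the shapes are hypothetical stand-ins and are not min-dalle's actual code. The alternative would be to allocate the cache in the same dtype that the decoder computes in.

```python
import jax.numpy as jnp
from jax import lax


def write_into_cache(cache, update, index):
    """Hypothetical helper: write `update` into `cache` along axis 0 at row `index`.

    lax.dynamic_update_slice requires operand and update to share a dtype,
    so the update is cast to the cache's dtype before the write.
    """
    update = update.astype(cache.dtype)
    # One start index per dimension of the cache; only axis 0 is offset.
    start = (index,) + (0,) * (cache.ndim - 1)
    return lax.dynamic_update_slice(cache, update, start)


# Hypothetical shapes: a 4-step cache of 2-dim key vectors.
cache = jnp.zeros((4, 2), dtype=jnp.float32)    # cache allocated in float32
new_keys = jnp.ones((1, 2), dtype=jnp.float16)  # keys computed in float16

try:
    # Passing the mismatched dtypes directly should reproduce the reported TypeError.
    lax.dynamic_update_slice(cache, new_keys, (1, 0))
except TypeError as err:
    print("mismatch:", err)

fixed = write_into_cache(cache, new_keys, 1)
print(fixed.dtype)  # float32
```

The cast in this sketch downcasts nothing. It only widens float16 to float32, so the cache keeps its original precision.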

Issue Analytics

  • State: closed
  • Created a year ago
  • Reactions: 26
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

5 reactions
toabi commented on Jun 28, 2022

Changing the requirements.txt to have flax==0.4.2 seems to fix this issue.

Edit: Running on an M1 Pro in a Python 3.9.13 venv.
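A pin in requirements.txt only helps if that version is actually the one installed in the active virtual environment. The following is a small sketch for confirming this. It assumes Python 3.8+ (for importlib.metadata), and the helper name check_pin is hypothetical. It only checks the installed version and says nothing about why flax 0.4.2 avoids the error.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pin(package: str, pinned: str) -> bool:
    """Hypothetical helper: True if `package` is installed at exactly `pinned`."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package} is not installed in this environment")
        return False
    if installed != pinned:
        print(f"{package} {installed} is installed, but the pin is {pinned}")
        return False
    return True


if __name__ == "__main__":
    # Run inside the min-dalle venv after `pip install -r requirements.txt`.
    print(check_pin("flax", "0.4.2"))
```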

2 reactions
adamryman commented on Jun 28, 2022

Just a note: I had the same issue and resolved it by doing the following:

  • Use the CUDA version of jax
pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  • Use the CUDA version of torch
pip uninstall torch torchvision torchaudio
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117

Then generate the image using torch

 python image_from_text.py --text="a comfy chair that looks like an avocado" --seed=4 --mega --torch

Edit: toabi’s solution above also works for me without the --torch flag. However, it is very slow, because without --torch it does not seem to be able to find my GPU.

Ahhh, I was able to get the GPU working with jax by following their README (https://github.com/google/jax#pip-installation-gpu-cuda). It turns out I did not have cuDNN installed. However, I then ran into this error: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 201326592 bytes. I'm guessing my computer is not beefy enough.
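Both the "No GPU/TPU found, falling back to CPU" warning in the original log and the slowness described above point to JAX running on its CPU backend. After installing the CUDA build of jaxlib and cuDNN, you can query JAX directly to see which backend it picked up. This sketch relies only on jax.default_backend() and jax.devices(). The helper name describe_backend is hypothetical.

```python
import jax


def describe_backend() -> str:
    """Hypothetical helper: summarize JAX's active backend and visible devices."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    devices = jax.devices()
    names = ", ".join(str(d) for d in devices)
    return f"backend={backend}; {len(devices)} device(s): {names}"


if __name__ == "__main__":
    print(describe_backend())
    if jax.default_backend() == "cpu":
        print("JAX is on the CPU; check the CUDA jaxlib build and cuDNN install.")
```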
