
bug in jax.nn.glu


There is a bug in jax.nn.glu in the latest release.

https://github.com/google/jax/blob/728e4fd3fad334e551dcd1a13b47b8d51aae5a9f/jax/_src/nn/functions.py#L261

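Here, static_argnames should be ("axis",) instead of ("glu",). "glu" is the function's name, not one of its parameters, so axis is traced instead of being treated as a static argument. Right now, calling glu with an explicit axis results in a TracerIntegerConversionError: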

import jax
import jax.numpy as jnp

jax.nn.glu(jnp.ones((1, 4)), axis=1)

Full error message and traceback:

TracerIntegerConversionError              Traceback (most recent call last)
Input In [8], in <module>
----> 1 jax.nn.glu(jnp.ones((1, 4)), 1)

    [... skipping hidden 14 frame]

File ~/transformers-env/lib/python3.8/site-packages/jax/_src/nn/functions.py:269, in glu(x, axis)
    261 @partial(jax.jit, static_argnames=("glu",))
    262 def glu(x: Array, axis: int = -1) -> Array:
    263   """Gated linear unit activation function.
    264 
    265   Args:
    266     x : input array
    267     axis: the axis along which the split should be computed (default: -1)
    268   """
--> 269   size = x.shape[axis]
    270   assert size % 2 == 0, "axis size must be divisible by 2"
    271   x1, x2 = jnp.split(x, 2, axis)

File ~/transformers-env/lib/python3.8/site-packages/jax/core.py:473, in Tracer.__index__(self)
    472 def __index__(self):
--> 473   raise TracerIntegerConversionError(self)

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
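The sketch below illustrates, assuming a recent JAX install, both the failure mechanism and the fix proposed in the issue. take_axis_size and glu_fixed are hypothetical names used only for illustration. The body of glu_fixed copies the lines shown in the traceback. Its return line (x1 * sigmoid(x2)) is an assumption based on the standard GLU definition, because the traceback stops after the split.

```python
from functools import partial

import jax
import jax.numpy as jnp


def take_axis_size(x, axis):
  # x.shape is a tuple of Python ints; indexing it needs a concrete int.
  return x.shape[axis]


# Without static arguments, jit traces `axis`, so tuple indexing calls
# Tracer.__index__ and raises, as in the traceback from the issue.
try:
  jax.jit(take_axis_size)(jnp.ones((1, 4)), 1)
  traced_axis_failed = False
except jax.errors.TracerIntegerConversionError:
  traced_axis_failed = True


# Hypothetical fixed version: static_argnames names the parameter ("axis"),
# not the function ("glu"), so jit treats axis as a compile-time Python int.
@partial(jax.jit, static_argnames=("axis",))
def glu_fixed(x, axis: int = -1):
  """Gated linear unit: split x in two along `axis`, gate one half with sigmoid."""
  size = x.shape[axis]  # fine: axis is a concrete int here
  assert size % 2 == 0, "axis size must be divisible by 2"
  x1, x2 = jnp.split(x, 2, axis)
  return x1 * jax.nn.sigmoid(x2)  # assumed gating, per the usual GLU definition


print(traced_axis_failed)  # True
out = glu_fixed(jnp.ones((1, 4)), axis=1)
print(out.shape)  # (1, 2)
```

Because jit maps static_argnames to positions using the function signature, the positional call from the traceback, glu_fixed(jnp.ones((1, 4)), 1), should also work with this decorator.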

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
patrickvonplaten commented, Mar 17, 2022

Do you think you could do a patch release for this one? We have a couple of Flax/JAX models in Transformers failing due to this on master at the moment, and we are also looking into doing another release soon.

1 reaction
patrickvonplaten commented, Mar 18, 2022

Thanks a lot for helping so quickly!


Top Results From Across the Web

jax.nn.glu — JAX documentation
jax.nn.glu(x, axis=-1). Gated linear unit activation function. Parameters: x (Any) – input array; axis (int) – the axis...
map of pmap IndexError bug · Issue #2822 · google/jax - GitHub
Map of pmap seems to fail due to a strange shape bug. See repro below. import jax.numpy as np from jax import lax,...
