`jnp.floor_divide()` returns incorrect results when used with `scan`
```python
import jax
from jax import lax
import jax.numpy as jnp

x = jnp.array([2, 2, -3])

def reduce_for_loop(func, arr):
    # Left fold written as a plain Python loop.
    val = arr[0]
    for i in range(1, len(arr)):
        val = func(val, arr[i])
    return val

def reduce_scan(func, arr):
    # Same left fold via lax.fori_loop; with static (Python int) bounds,
    # fori_loop is implemented with scan.
    val = arr[0]
    body_func = lambda i, val: func(val, arr[i])
    return lax.fori_loop(1, len(arr), body_func, val)

print(reduce_for_loop(jnp.floor_divide, x))  # prints -1 (expected)
print(reduce_scan(jnp.floor_divide, x))      # prints 0 (incorrect)

with jax.disable_jit():
    print(reduce_scan(jnp.floor_divide, x))  # prints -1
```
Came up when working on #9529
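The expected value can be checked by hand: `floor_divide(2, 2)` is 1, and `floor_divide(1, -3)` rounds -1/3 toward negative infinity, giving -1. The incorrect 0 is what round-toward-zero (truncating) division would produce at the second step. The plain-Python sketch below illustrates that difference. `floor_div` and `trunc_div` are hypothetical helpers written for illustration, not JAX APIs, and the match with truncation is an observation about the output, not a confirmed diagnosis of the bug.

```python
import functools

def floor_div(a: int, b: int) -> int:
    # Python's // rounds toward negative infinity, the same semantics
    # NumPy and jax.numpy use for floor_divide on integers.
    return a // b

def trunc_div(a: int, b: int) -> int:
    # Hypothetical helper: round-toward-zero division (C-style integer division).
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

x = [2, 2, -3]

# functools.reduce performs the same left fold as reduce_for_loop in the repro.
print(functools.reduce(floor_div, x))  # -1: 2 // 2 = 1, then 1 // -3 = -1
print(functools.reduce(trunc_div, x))  # 0: -1/3 truncated toward zero is 0
```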
Issue Analytics
- Created 2 years ago
- Comments: 11 (5 by maintainers)
Fixed with https://github.com/tensorflow/tensorflow/commit/6d6d7fece1752523699993d7a30ee7790e84adba. I will leave it to Jake to close this issue.
Looks like this is fixed in the jaxlib 0.3.14 release