Does vmap or lax.scan do caching?
Hi everyone,
I’m asking because I noticed some very weird behaviour in jax-unirep, and even after a good amount of effort trying to narrow it down, I can’t seem to find the cause.
The weird behaviour is minimally reproduced in a GitHub gist: the output of a jax vmapped function (vmap over “sample” axes) that internally uses a lax.scan (RNN component) oscillates between two values even when the exact same inputs are provided. The behaviour was first reported by @hhefzi on our issue tracker.
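For readers without the gist at hand, the structure described above can be sketched roughly as follows. This is a hypothetical minimal RNN, not jax-unirep's actual code (the names rnn_embed, params, W, U and all shapes are assumptions): lax.scan runs the recurrence over the time axis of one sequence, and jax.vmap maps it over the leading "sample" axis. Because the function is pure, repeated calls with identical inputs are expected to return identical values, which is the property the reported oscillation appears to violate.

```python
import jax
import jax.numpy as jnp
from jax import lax

def rnn_embed(params, seq):
    """Hypothetical minimal RNN: scan over the time steps of one sequence, return the mean hidden state."""
    W, U = params["W"], params["U"]

    def step(h, x):
        h_new = jnp.tanh(W @ h + U @ x)
        return h_new, h_new  # (new carry, per-step output)

    h0 = jnp.zeros(W.shape[0])
    _, hs = lax.scan(step, h0, seq)  # hs has shape (time, hidden)
    return hs.mean(axis=0)

# Map over the leading "sample" axis of the batch; params are shared (in_axes=None)
embed_batch = jax.vmap(rnn_embed, in_axes=(None, 0))

key_w, key_u, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W": 0.1 * jax.random.normal(key_w, (8, 8)),
    "U": 0.1 * jax.random.normal(key_u, (8, 4)),
}
batch = jax.random.normal(key_x, (3, 5, 4))  # (samples, time, features)

# With a pure function, repeated calls on identical inputs should give identical outputs
sums = [float(embed_batch(params, batch).sum()) for _ in range(4)]
print(sums)
```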
I’ve inserted logging statements to check whether the same Python objects passed into the top-level function are propagated down into the functions that do the RNN math. They are, which suggests to me that the “state” of the program is correct, i.e. the program consistently passes the correct inputs down from the top-level functions being called.
This leaves the JAX parts of jax-unirep, which I’m not quite sure how to debug. Hence the motivation for the question: is there caching involved in vmap or lax.scan? If so, that might explain the oscillatory behaviour.
If not, I’m not sure what else could be causing the issue; I’m kind of running dry on ideas, having already poured a few of my best hours of brainpower into getting to the root of it. Might you all have some alternative ideas?
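Since the logging statements are the main evidence about what reaches the RNN code, one caveat may be worth ruling out: a plain Python print or logging call inside a function that lax.scan (or jit) traces runs only while JAX traces that function, not on every call. The minimal sketch below contrasts an ordinary Python side effect with jax.debug.callback, which is staged into the computation and fires each time a scan step executes. The names step, run and record are hypothetical and not taken from jax-unirep.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_log = []    # filled by plain Python code: runs only while JAX traces `step`
runtime_log = []  # filled via jax.debug.callback: runs every time a scan step executes

def record(x):
    runtime_log.append(float(x))

def step(carry, x):
    trace_log.append("traced")     # ordinary Python side effect, trace time only
    jax.debug.callback(record, x)  # staged into the computation, fires at run time
    return carry + x, None

@jax.jit
def run(xs):
    total, _ = lax.scan(step, jnp.float32(0.0), xs)
    return total

xs = jnp.arange(3.0)
run(xs)
run(xs)                # second call hits the jit cache, so `step` is not traced again
jax.effects_barrier()  # wait for any pending debug callbacks to finish

print(len(trace_log), len(runtime_log))
```

If a log line only ever appears once per input shape, it is probably reporting what was seen at trace time rather than on each call.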
Hey Eric!
vmap doesn’t do any caching. lax.scan has caching so that scanning the same function twice over inputs with the same shapes doesn’t recompile; it’s like jit in that way. The caching is sound so long as the functions involved are pure, i.e. they don’t have side effects. Could there be side effects in the unirep code?
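To illustrate why purity matters for this kind of caching, here is a minimal sketch (not jax-unirep code; scale, step and run are hypothetical names) in which the scanned function reads a Python global. The global's value is captured when the function is first traced, so after the global changes, the cached computation silently keeps using the old value.

```python
import jax
import jax.numpy as jnp
from jax import lax

scale = 1.0  # hypothetical global state read inside the scanned function (impure)

def step(carry, x):
    # `scale` is read at trace time and baked into the compiled computation as a constant
    return carry + scale * x, None

@jax.jit
def run(xs):
    total, _ = lax.scan(step, jnp.float32(0.0), xs)
    return total

xs = jnp.arange(4.0)
first = run(xs)   # traced with scale == 1.0: 0 + 1 + 2 + 3 = 6.0
scale = 2.0
second = run(xs)  # cache hit: still computed with scale == 1.0, so 6.0 rather than 12.0
print(float(first), float(second))
```

Passing scale into run as an argument (and closing over it inside run) makes the dependency explicit, so a new value yields a new result instead of a stale one.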
You can disable lax.scan’s caching by commenting out the @cache() decorators here and here. You can further disable all jit caching by commenting out the decorators here and here and here. That might help debug whether caching is an issue.

I think this should work without any additional files needed (with jax-unirep installed). Same oscillatory behavior.
Output:
------Iteration 0------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 1------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.6575164794922
------Iteration 2------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 3------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.6575164794922
------Iteration 4------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 5------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.65750122070312
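Along the lines of the maintainer's debugging suggestion, a quicker (if coarser) check than commenting out the @cache() decorators is to compare a normal call against one made inside the jax.disable_jit() context manager. Whether this also bypasses every internal cache that lax.scan keeps may depend on the JAX version, so a match is only weak evidence. The sketch below uses a hypothetical stand-in model (rnn_sum, W and xs are assumptions, not jax-unirep code).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def rnn_sum(W, xs):
    # Hypothetical stand-in for the model: a scan-based RNN reduced to one number per sample
    def step(h, x):
        h = jnp.tanh(W @ h + x)
        return h, h
    _, hs = lax.scan(step, jnp.zeros(W.shape[0]), xs)
    return hs.sum()

# Map over the leading "sample" axis; W is shared across samples
batched = jax.vmap(rnn_sum, in_axes=(None, 0))

W = 0.1 * jnp.eye(4)
xs = jnp.ones((2, 6, 4))  # (samples, time, features)

normal = np.asarray(batched(W, xs))     # usual path, with tracing caches in play
with jax.disable_jit():
    eager = np.asarray(batched(W, xs))  # jit disabled: jit-wrapped functions run op by op

print(np.allclose(normal, eager))
```

If the normal path oscillates across repeated calls while the disable_jit path stays fixed, caching or tracing becomes a stronger suspect; if both oscillate, the cause more likely lies elsewhere.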