
Does vmap or lax.scan do caching?


Hi everyone,

I’m asking because I noticed a very weird behaviour in jax-unirep whose cause I cannot find, even after a good amount of effort trying to narrow it down.

The weird behaviour is minimally reproduced in a GitHub gist: the output of a vmapped JAX function (vmap over the “sample” axis) that internally uses lax.scan (the RNN component) oscillates between two values even when exactly the same inputs are provided. The behaviour was first reported by @hhefzi on our issue tracker.
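
To make the setup concrete, here is a minimal sketch of the structure described above: a jax.vmap over the sample axis wrapping a lax.scan RNN. It is not the jax-unirep code; rnn_forward, the tanh cell, and the parameter shapes are hypothetical stand-ins. Because every function involved is pure, repeated calls with identical inputs should return identical arrays.

```python
import jax
import jax.numpy as jnp
from jax import lax


def rnn_forward(params, xs):
    """Hypothetical tanh RNN over one sequence xs of shape (seq_len, n_in)."""
    W, U, b = params  # W: (hidden, n_in), U: (hidden, hidden), b: (hidden,)
    h0 = jnp.zeros(U.shape[0])

    def step(h, x_t):
        # Pure step: depends only on its arguments and the closed-over params.
        h_new = jnp.tanh(W @ x_t + U @ h + b)
        return h_new, h_new

    _, hs = lax.scan(step, h0, xs)  # hs: (seq_len, hidden)
    return hs.mean(axis=0)  # average hidden state used as the "embedding"


# vmap over the leading sample axis of xs; params are shared (in_axes=None).
batched_forward = jax.vmap(rnn_forward, in_axes=(None, 0))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n_in, hidden = 4, 8
params = (
    jax.random.normal(k1, (hidden, n_in)),
    jax.random.normal(k2, (hidden, hidden)),
    jnp.zeros(hidden),
)
xs = jax.random.normal(k3, (3, 5, n_in))  # 3 samples, 5 time steps each

# Call the batched function repeatedly on the same inputs and compare.
outputs = [batched_forward(params, xs) for _ in range(6)]
print(all(bool(jnp.array_equal(outputs[0], o)) for o in outputs[1:]))
```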

I’ve inserted logging statements to check whether the Python objects passed into the top-level function are the same ones that reach the functions doing the RNN math. They are, which suggests to me that the “state” of the program is correct, i.e. the program consistently passes the correct inputs down from the top-level functions being called.

That leaves the JAX parts of jax-unirep, which I’m not quite sure how to debug. Hence the question: is there caching involved in vmap or lax.scan? If so, that might explain the oscillatory behaviour.

If not, I’m not sure what else could be causing the issue. I’m running dry on ideas, having already poured a few of my best hours of brainpower into getting to the root of it. Might you all have some alternative ideas?

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

4 reactions
mattjj commented, Jul 21, 2020

Hey Eric!

vmap doesn’t do any caching. lax.scan has caching so that scanning the same function twice over inputs with the same shapes doesn’t recompile; it’s like jit in that way.
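
The shape-keyed caching described here can be observed by counting traces. The sketch below wraps a lax.scan in jax.jit, since the comment likens scan's caching to jit's; the trace_count counter and cumulative_sum function are hypothetical instrumentation, relying on the fact that a Python side effect runs only while JAX traces, not on cached calls.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0  # Python-level counter; only incremented while JAX traces


@jax.jit
def cumulative_sum(xs):
    global trace_count
    trace_count += 1  # side effect: runs at trace time, skipped on cache hits

    def step(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


cumulative_sum(jnp.ones(5))   # traces and compiles for shape (5,), float32
cumulative_sum(jnp.zeros(5))  # cache hit: same shape and dtype
cumulative_sum(jnp.ones(7))   # new shape (7,): traces again
print(trace_count)
```

Only the first and third calls trace, so the counter ends at 2 even though the function was called three times.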

The caching is sound so long as the functions involved are pure, i.e. they don’t have side effects. Could there be side effects in the unirep code?
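
A minimal sketch of why purity matters, using jax.jit: a function that reads module-level state has that value baked in at trace time, so a later cache hit silently reuses the old value. The global scale and both functions are hypothetical examples; this shows a stale result rather than the exact oscillation reported, but it is the same class of bug.

```python
import jax
import jax.numpy as jnp

scale = 1.0  # hypothetical module-level state read inside a jitted function


@jax.jit
def impure_scale(x):
    # `scale` is read once, at trace time, and baked into the compiled function.
    return x * scale


x = jnp.ones(3)
first = impure_scale(x)   # traces and compiles with scale == 1.0
scale = 2.0
second = impure_scale(x)  # same shape/dtype: cache hit, so scale == 1.0 is reused
print(first, second)


@jax.jit
def pure_scale(x, s):
    # Pure version: everything the result depends on is an explicit argument.
    return x * s


third = pure_scale(x, scale)  # uses the current value, 2.0
print(third)
```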

You can disable lax.scan’s caching by commenting out the @cache() decorators here and here. You can further disable all jit caching by commenting out the decorators here and here and here. That might help debug whether caching is an issue.
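
Short of editing JAX's source, one way to test the caching hypothesis is to run the same call once compiled and once under the jax.disable_jit() context manager, which executes jitted functions op by op in Python. The helper jit_vs_eager and the global offset below are hypothetical; a mismatch between the two results points at cached traces combined with hidden state.

```python
import jax
import jax.numpy as jnp

offset = 0.0  # hypothetical hidden state that changes between calls


@jax.jit
def add_offset(x):
    return x + offset  # read at trace time when compiled


def jit_vs_eager(fn, *args):
    """Hypothetical helper: return (compiled result, eager result) for one call."""
    compiled = fn(*args)
    with jax.disable_jit():
        eager = fn(*args)  # runs op by op, so Python-level state is re-read
    return compiled, eager


x = jnp.zeros(2)
add_offset(x)  # first call: trace and cache with offset == 0.0
offset = 5.0
compiled, eager = jit_vs_eager(add_offset, x)
print(compiled, eager)
```

Here the compiled result still reflects offset == 0.0 while the eager one sees 5.0; for a pure function the two would agree.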

2 reactions
hhefzi commented, Jul 22, 2020

I think this should work without any additional files (with jax-unirep installed). Same oscillatory behavior.

# Embed the same sequence repeatedly, with default and with custom-fitted parameters.
import numpy as np
from jax_unirep import fit, get_reps

sequences = ["MKLVIPJ", "MMLVIKJP", "MKLVIJJ"]

# Fit parameters on the toy sequences for 10 epochs.
params = fit(params=None, sequences=sequences, n_epochs=10)

mut_seq = sequences[0]

# The inputs are identical on every iteration, so every printed sum should match.
for i in range(6):
    print(f"------Iteration {i}------")
    print(f"Default parameters-sum of embeddings: {np.sum(get_reps(mut_seq)[0])}")
    print(f"Custom parameters-sum of embeddings: {np.sum(get_reps(mut_seq, params=params[0])[0])}")

Output:
------Iteration 0------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 1------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.6575164794922
------Iteration 2------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 3------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.6575164794922
------Iteration 4------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 220.5883331298828
------Iteration 5------
Default parameters-sum of embeddings: 220.5883331298828
Custom parameters-sum of embeddings: 218.65750122070312
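
When bisecting a bug like this, a small harness that calls a zero-argument function repeatedly and collects the distinct results makes an oscillation easy to spot. The helper distinct_results and the stable/oscillating stand-ins below are hypothetical; the toggling function only mimics the reported symptom and uses made-up values.

```python
import itertools

import numpy as np


def distinct_results(fn, n_calls=6):
    """Call fn() n_calls times and return the distinct sums seen, in first-seen order."""
    seen = []
    for _ in range(n_calls):
        total = float(np.sum(fn()))
        if total not in seen:
            seen.append(total)
    return seen


# Hypothetical stand-ins: a deterministic function and one that alternates
# between two outputs, mimicking the reported oscillation.
def stable():
    return np.array([1.0, 2.0])


_toggle = itertools.cycle([np.array([1.0, 2.0]), np.array([1.0, 1.5])])


def oscillating():
    return next(_toggle)


print(distinct_results(stable))
print(distinct_results(oscillating))
```

Against the reproduction above, one could pass something like lambda: get_reps(mut_seq, params=params[0])[0]; a healthy pipeline should report a single value.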

