Regression in performance of samplers
See original GitHub issue.

With jax versions > 0.1.24, random samplers are slow. I think this is due to recent changes in how jit handles static args/kwargs. A script to reproduce:

```python
import time
from jax import random

t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)

t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)
```

which returns 0.12923526763916016 and 0.12221717834472656.
However, if we wrap these samplers in some function, then it is fast. For example,

```python
def f():
    return random.normal(random.PRNGKey(0), shape=(1,))

t = time.time()
f()
print(time.time() - t)

t = time.time()
f()
print(time.time() - t)
```

will return 0.12787413597106934 and 0.0010831356048583984.
I think there is a small bug somewhere that forces the sampler to recompile. If this is expected behaviour, which function should we use to wrap these samplers so that they are not recompiled? cc @neerajprad
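For context on why the two snippets above might behave differently: two shape tuples written at separate call sites compare equal but are generally distinct objects, which matters if any caching keys on object identity rather than equality. A minimal illustration (plain Python, independent of JAX; the tuples are built at runtime to avoid CPython's constant interning):

```python
# Two equal tuples built separately are distinct objects.
a = tuple([1])
b = tuple([1])
print(a == b)  # True: the values are equal
print(a is b)  # False: they are different objects
```

An identity-keyed cache would therefore miss on `b` even after seeing `a`, while an equality-keyed cache would hit.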
Issue Analytics
- Created 4 years ago
- Comments: 7 (7 by maintainers)
This was an issue with all uses of `static_argnums`, though it only cropped up recently in random.py because of how we changed random.py following the kwargs-handling change. #692 should fix it.

The use of `static_argnums` was only resulting in compilation cache hits based on object identity. That meant that random.py functions were only caching on object identity (i.e. `x is y` instead of equality `x == y`) of their shape parameters. In your first code example, which passed in `(1,)` the first time and a fresh literal `(1,)` the second time, those two objects didn't have the same identity (though they would have compared equal). In your second code example, the same literal was used twice, and so the cache was being hit on object identity.

The fix in #692 is just to make `static_argnums` use equality checks when possible (i.e. when the argument objects corresponding to static argnums have `__hash__` and, by implication, `__eq__`), and fall back to object identity when not. This should still allow arbitrary objects to be passed as static args (even unhashable ones), but also let us get cache hits where appropriate, especially for shape args like in random.py.

Does that make sense? What do you think?
I can’t thank you enough for spotting this and providing such a clear repro. Your sleuthing made this an easy fix, and without it we would have missed it for who knows how long!
JAX bugs fixed in 30 minutes or less, or your money* back!
*JAX costs $0