Regression in performance of samplers

With JAX versions after 0.1.24, the random samplers are slow. I think this is due to the recent changes in how jit handles static args and kwargs. A script to reproduce:

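import time
from jax import random

# First call: the timing includes compilation.
t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)

# Second call with an equal, freshly written shape tuple.
t = time.time()
random.normal(random.PRNGKey(0), shape=(1,))
print(time.time() - t)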

which prints 0.12923526763916016 and 0.12221717834472656, so the second call is no faster than the first.

However, if we wrap the sampler in a function, the second call is fast. For example:

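def f():
    # In CPython the (1,) literal is a constant of f, so every call passes the same tuple object.
    return random.normal(random.PRNGKey(0), shape=(1,))

# First call: the timing includes compilation.
t = time.time()
f()
print(time.time() - t)

# Second call: reuses the compiled computation.
t = time.time()
f()
print(time.time() - t)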

prints 0.12787413597106934 and 0.0010831356048583984, so only the first call is slow.

I think there is a small bug somewhere that forces the sampler to recompile. If this is expected behaviour, which function should we use to wrap these samplers so that they are not recompiled? cc @neerajprad

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

3 reactions
mattjj commented, May 10, 2019

This was an issue with all uses of static_argnums, though it only cropped up recently in random.py because of how we changed random.py following the kwargs-handling change.

#692 should fix it.

The use of static_argnums was only producing compilation cache hits based on object identity. That meant random.py functions cached on the identity of their shape parameters (x is y) rather than on equality (x == y). In your first code example, the first call passed in (1,) and the second passed a fresh literal (1,); those two objects didn’t have the same identity, though they would have compared equal. In your second code example, the same literal was used on both calls, so the cache was hit on object identity.
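To make the identity-versus-equality distinction concrete, here is a toy, pure-Python model of a compilation cache that matches static arguments by identity, as described above for the pre-#692 behaviour. It does not use JAX: `sample_identity_cached` and `_fake_compile` are hypothetical stand-ins for a sampler and its XLA compilation. The shapes are built with `tuple([1])` rather than written as `(1,)` because CPython may reuse one constant tuple object for repeated literals inside a single code object, which would hide the effect.

```python
# Toy model of a compile cache keyed on object identity (x is y).
# Hypothetical names; this is NOT JAX's real internal code.

compile_count = 0
# Entries are (static_arg, compiled_fn). Keeping the arg alive here means
# a later, different object can never accidentally share its identity.
_identity_cache = []


def _fake_compile(shape):
    """Stand-in for an expensive compilation; it only counts calls."""
    global compile_count
    compile_count += 1
    return lambda: [0.0] * shape[0]


def sample_identity_cached(shape):
    # A hit requires the *same object*, not merely an equal one.
    for cached_shape, compiled in _identity_cache:
        if cached_shape is shape:
            return compiled()
    compiled = _fake_compile(shape)
    _identity_cache.append((shape, compiled))
    return compiled()


# Scenario 1: a fresh but equal tuple on each call, so every call misses.
sample_identity_cached(tuple([1]))
sample_identity_cached(tuple([1]))
misses_with_fresh_tuples = compile_count

# Scenario 2: the same tuple object reused (like the literal inside f).
compile_count = 0
shared_shape = tuple([1])
sample_identity_cached(shared_shape)
sample_identity_cached(shared_shape)
misses_with_shared_tuple = compile_count

print(misses_with_fresh_tuples, misses_with_shared_tuple)  # 2 1
```

The first scenario mirrors the top-level repro (two compilations for equal shapes), and the second mirrors the wrapped version in f (one compilation, then a hit).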

The fix in #692 is just to make static_argnums use equality checks when possible (i.e. when argument objects corresponding to static argnums have __hash__ and, by implication, __eq__), and fall back to object identity equivalence when not. This should still allow arbitrary objects to be passed as static args (i.e. even unhashable ones), but also let us get cache hits where appropriate, especially for shape args like in random.py.

Does that make sense? What do you think?

I can’t thank you enough for spotting this and providing such a clear repro. Your sleuthing made this an easy fix, and without it we would have missed it for who knows how long!

2 reactions
mattjj commented, May 10, 2019

JAX bugs fixed in 30 minutes or less, or your money* back!

*JAX costs $0
