Performance issue
Hello everyone,
I’m testing the speed of the JAX framework, and for that purpose I wrote a Kalman filter. The gold standard is a MATLAB version that runs the algorithm for 50k steps in ~0.1 s. The TensorFlow 2.1 version with tf.function takes ~2.0 s. Unfortunately, the JAX version hangs forever.
The JAX version of the Kalman filter in Colab is here, and the TensorFlow 2.1 version is here. I would appreciate any help in resolving the issue or understanding what is wrong with the code.
Thanks, Artem
Issue created 4 years ago; 8 comments (2 by maintainers).
I think you need to use block_until_ready in your Colab microbenchmark. This is because JAX dispatches operations asynchronously (usually a good thing!), so you’re only measuring the dispatch time here. See https://jax.readthedocs.io/en/latest/async_dispatch.html for more details.
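A minimal timing sketch of the block_until_ready suggestion. It assumes a hypothetical jitted function step standing in for the benchmarked code (the Colab code is not reproduced here). The first call is timed separately because it includes tracing and XLA compilation. block_until_ready() makes the timer wait until the result has actually been computed, not merely dispatched.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for the function being benchmarked.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((500, 500))

# First call: includes tracing + XLA compilation, so time it separately.
t0 = time.perf_counter()
step(x).block_until_ready()
compile_and_run = time.perf_counter() - t0

# Later calls reuse the compiled program. Without block_until_ready this
# would mostly measure dispatch, since JAX returns to Python before the
# computation has finished.
t0 = time.perf_counter()
result = step(x).block_until_ready()
run_only = time.perf_counter() - t0

print(f"first call (compile + run): {compile_and_run:.4f} s")
print(f"subsequent call (run only): {run_only:.4f} s")
```

When comparing against MATLAB or TensorFlow, the second number is usually the one to report. The first number also includes the one-time compilation cost.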
Regarding the unrolled for-loop, this is an artifact of how JAX traces functions to be jitted. It’s a little nuanced, but basically, JAX only “sees” certain overloaded operations (e.g. jax.numpy functions), and Python doesn’t provide a good way to overload control flow, such as for-loops. As a result, the loop runs as usual (JAX doesn’t “see” it at all) and JAX traces the operations inside each iteration, but can’t tell they’re all from the same loop body. With 50k steps, the jitted program therefore contains 50k copies of the loop body, which is likely why compilation appears to hang. That’s why we provide control flow operations like scan, so JAX can “see” the loop itself. This talk goes over JAX’s tracing in more detail if you’re curious.

I am closing this, assuming that the issue was clarified. One lesson is that it would help if the documentation made it very clear that JAX is meant to optimize only certain kinds of Python code.
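The following sketches the scan-based approach for this problem. It assumes a standard linear-Gaussian Kalman filter with hypothetical names: A, H, Q, and R are the transition, observation, and noise covariance matrices, and m0 and P0 are the initial mean and covariance. This is an illustration, not a patch to the original Colab code. The hypothetical helpers unrolled, scanned, and num_eqns use jax.make_jaxpr to show the tracing behavior described above. A Python for-loop adds one copy of its body per iteration, while lax.scan traces the body once.

```python
import jax
import jax.numpy as jnp
from jax import lax


def kalman_filter(A, H, Q, R, m0, P0, ys):
    """Linear Kalman filter using lax.scan (illustrative model):
    x_t = A x_{t-1} + w_t, w_t ~ N(0, Q);  y_t = H x_t + v_t, v_t ~ N(0, R).
    """

    def step(carry, y):
        m, P = carry
        # Predict.
        m_pred = A @ m
        P_pred = A @ P @ A.T + Q
        # Update: K = P_pred H^T S^{-1}, using symmetry of S and P_pred.
        S = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(S, H @ P_pred).T
        m_new = m_pred + K @ (y - H @ m_pred)
        P_new = P_pred - K @ S @ K.T
        return (m_new, P_new), m_new

    # scan traces `step` once, so program size does not grow with len(ys).
    _, means = lax.scan(step, (m0, P0), ys)
    return means


kalman_filter_jit = jax.jit(kalman_filter)

# 1-D random-walk example with 50,000 synthetic observations.
n_steps = 50_000
key = jax.random.PRNGKey(0)
A, H = jnp.eye(1), jnp.eye(1)
Q, R = 0.01 * jnp.eye(1), 0.1 * jnp.eye(1)
m0, P0 = jnp.zeros(1), jnp.eye(1)
ys = jnp.cumsum(0.1 * jax.random.normal(key, (n_steps, 1)), axis=0)
means = kalman_filter_jit(A, H, Q, R, m0, P0, ys).block_until_ready()
print(means.shape)  # (50000, 1)


def unrolled(x, n):
    # Python loop: JAX never sees the loop, only n copies of its body.
    for _ in range(n):
        x = jnp.sin(x)
    return x


def scanned(x, n):
    # Same computation expressed as a single scan primitive.
    final, _ = lax.scan(lambda c, _: (jnp.sin(c), None), x, None, length=n)
    return final


def num_eqns(f, n):
    # Number of top-level operations in the traced program.
    return len(jax.make_jaxpr(lambda x: f(x, n))(jnp.float32(1.0)).jaxpr.eqns)


print(num_eqns(unrolled, 10), num_eqns(unrolled, 1000))  # grows with n
print(num_eqns(scanned, 10), num_eqns(scanned, 1000))    # stays constant
```

Because the loop body is traced only once, compile time no longer depends on the number of steps. The first call still includes compilation, so a fair benchmark times a second call with block_until_ready.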