How does Alpa lower control-flow jaxprs into XLA HLO?

For example, lax.cond:

>>> from jax import lax, make_jaxpr
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In this situation, how does Alpa deal with the cond equation when lowering it into XLA HLO? @zhisbug
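To make the question concrete, the sketch below models in plain Python what the cond equation in that jaxpr computes. The boolean predicate is converted to an int32 index, and the branches are stored in index order, so the false branch (sub e 3.0) sits at position 0 and the true branch (add g 3.0) at position 1. The helpers eval_cond and func7_as_jaxpr are hypothetical illustrations, not part of JAX or Alpa.

```python
# A toy, pure-Python model of the jaxpr printed above.
# eval_cond and func7_as_jaxpr are hypothetical helpers for illustration only;
# they are not part of JAX or Alpa.
from typing import Callable, Sequence, Tuple

Branch = Callable[[float], Tuple[float, ...]]


def eval_cond(index: int, branches: Sequence[Branch], operand: float) -> Tuple[float, ...]:
    """Run the branch selected by an integer index, like the jaxpr `cond` equation."""
    if not 0 <= index < len(branches):
        raise IndexError(f"branch index {index} out of range")
    return branches[index](operand)


def func7_as_jaxpr(a: float) -> float:
    b = a >= 0.0   # b:bool[] = ge a 0.0
    c = int(b)     # c:i32[] = convert_element_type[new_dtype=int32] b
    branches = (
        lambda e: (e - 3.0,),  # index 0, false branch: { lambda ; e. let f = sub e 3.0 in (f,) }
        lambda g: (g + 3.0,),  # index 1, true branch:  { lambda ; g. let h = add g 3.0 in (h,) }
    )
    (d,) = eval_cond(c, branches, a)  # d:f32[] = cond[branches=...] c a
    return d


print(func7_as_jaxpr(5.0))   # 8.0
print(func7_as_jaxpr(-1.0))  # -4.0
```

Any lowering has to preserve this index-to-branch mapping. XLA's Conditional operation has an indexed form that takes an int32 branch index and a list of branch computations, which matches the shape of this equation.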

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

1 reaction
merrymercy commented, Jun 7, 2022

This is not supported yet, but we are working on supporting it. See also https://github.com/alpa-projects/alpa/issues/400

0 reactions
merrymercy commented, Jun 11, 2022

I think the original question is answered. We can close this issue, move control flow related discussion to #400, and move other questions to new issues.
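For comparison, stock JAX (outside Alpa) can already lower the same function through jax.jit(...).lower(...). The sketch below prints the lowered MLIR text so you can see which op the cond equation becomes. It shows JAX's own lowering path only, not how Alpa handles control flow, and it assumes a JAX version that provides Lowered.as_text(); the exact dialect and op name in the output (for example stablehlo.case in recent releases) depend on the JAX version.

```python
# Inspect how stock JAX (not Alpa) lowers lax.cond.
# Assumes a JAX version that provides jax.jit(...).lower(...).as_text().
import jax
from jax import lax


def func7(arg):
    return lax.cond(arg >= 0.,
                    lambda xtrue: xtrue + 3.,
                    lambda xfalse: xfalse - 3.,
                    arg)


lowered = jax.jit(func7).lower(5.0)  # trace and lower for a scalar float input, without executing
text = lowered.as_text()             # MLIR text of the lowered module
print(text)                          # the cond appears as a case-style op with one region per branch
```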
