
Prevent custom calls with side effects from being optimized out


I am currently experimenting with implementing MPI send / recv as custom XLA calls.

It works fine in most cases, but a function like this leads to a deadlock:

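import jax
from mpi4py import MPI  # assumes mpi4py supplies the communicator

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
# Send and Recv are the custom XLA calls being implemented (definitions not shown).

@jax.jit
def send_recv(x):
    if rank == 0:
        x = Recv(x, comm=comm)  # rank 0 receives into x
    else:
        Send(x, 0, comm=comm)   # other ranks send x to rank 0; the result is unused
        # works if doing x = Send(x, 0, comm=comm)
    return x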

I guess this is because the return value of Send is not used in the computational graph, so the whole call is optimized away, despite having side effects.

Is there a way to prevent this?
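A toy model may help make the suspected mechanism concrete. The sketch below is not JAX's or XLA's implementation: it assumes a traced function is a flat list of operations and that the compiler runs a dead-code elimination (DCE) pass that keeps only operations whose results feed the outputs. `Op`, `dce`, `effectful` and `respect_effects` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """One traced operation in a hypothetical toy IR (not a JAX/XLA type)."""
    name: str
    inputs: tuple
    output: str
    effectful: bool = False  # hypothetical flag meaning "has side effects"


def dce(ops, outputs, respect_effects=False):
    """Toy dead-code elimination: walk the ops backwards and keep those whose
    output is live (needed to compute `outputs`). If respect_effects is True,
    ops flagged as effectful are kept even when their output is unused."""
    live = set(outputs)
    kept = []
    for op in reversed(ops):
        if op.output in live or (respect_effects and op.effectful):
            kept.append(op)
            live.update(op.inputs)  # its inputs are now needed too
    kept.reverse()
    return kept


# rank != 0 branch as written: Send's result "s" is discarded and x is returned.
discarded = [Op("send", ("x",), "s", effectful=True)]
# Workaround from the question: x = Send(x, ...), so Send's result is returned.
threaded = [Op("send", ("x",), "x1", effectful=True)]

print([op.name for op in dce(discarded, ["x"])])                        # []
print([op.name for op in dce(threaded, ["x1"])])                        # ['send']
print([op.name for op in dce(discarded, ["x"], respect_effects=True)])  # ['send']
```

A purely data-driven pass keeps the send only when its result is threaded into the output, which matches the observation that `x = Send(x, 0, comm=comm)` works; a pass that is aware of side effects (the third call) keeps it even though nothing consumes its result.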

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Reactions: 2
  • Comments: 22 (18 by maintainers)

Top GitHub Comments

mattjj commented, Jul 24, 2020 (2 reactions)

Yes, that’s right for now, but with #3370 neither tokens nor tie_in will be necessary.

mattjj commented, Jul 23, 2020 (1 reaction)

Here’s a more self-contained example if you prefer:

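from jax import make_jaxpr
from jax.core import Primitive

# A primitive whose abstract evaluation rule returns None as the output's abstract value.
send_p = Primitive('send')
send_p.def_abstract_eval(lambda x: None)

def Send(x):
  return send_p.bind(x)

def f(x):
  Send(x)  # result discarded

print(make_jaxpr(f)(2))

Output (the jaxpr contains no equations):

{ lambda  ; a.
  let
  in () }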

If you want to convince yourself that XLA never sees the bound primitive, change make_jaxpr to jit and add a print(built.as_hlo_text()) after the line in xla.py where `built` is created.


