dev timeline for more batched ops

Hi JAX Team,

I’m interested in using your library to optimize, via gradients, functions whose terms include functions g() of Jacobians of functions f(). There’s no good way to do this in PyTorch for f() with output dimension > 1, and certainly no good way to do it in a batched setting.

In JAX, I’m able to very easily get, e.g., the trace of the Jacobian of a plain ReLU neural net. So cool! Things break down when I try to use vmap to get a batched version of the computation. Do you have a rough guess at the timeline for when I might be able to vmap in this kind of setting? Note: I’ve successfully used JAX to vmap functions g() of Jacobians of functions f() for simple choices of f().

import jax.numpy as np
from jax import jit, jacrev, vmap

W = np.eye(3)
b = np.zeros(3)

def relu(x):
    return np.maximum(0,x)

def NN(x):
    return relu(np.dot(W,x) + b)

Jx_NN = jit(jacrev(NN))

# By the way, is there a recommended way to compose already JIT'ed functions?
def trace_Jx_NN(x):
    J = Jx_NN(x)
    return np.trace(J) 


x1 = np.array([2.0,2.0,2.0])
x2 = np.array([3.0,3.0,3.0])
X = np.vstack((x1,x2))

# this works
print(trace_Jx_NN(x1)) 

# this line executes
batched_trace_Jx_NN = jit(vmap(trace_Jx_NN, in_axes=0))

# raises "NotImplementedError  # TODO(schsam, mattjj): Handle more cases."
print(batched_trace_Jx_NN(X)) 

Thanks! mark
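For comparison with the failing ReLU case, here is a minimal sketch of the kind of "simple f()" setup the post says already worked. The post does not say which f() was used, so the tanh layer and the names `f`, `trace_jac_f`, and `batched_trace_jac_f` are assumptions for illustration only. The trace of the Jacobian is computed per example, and vmap maps that function over the leading batch axis.

```python
import jax.numpy as jnp
from jax import jit, jacrev, vmap

W = jnp.eye(3)
b = jnp.zeros(3)

def f(x):
    # Assumed smooth stand-in for a "simple" f(); not taken from the original post.
    return jnp.tanh(jnp.dot(W, x) + b)

def trace_jac_f(x):
    # g(J) = trace(J), where J is the Jacobian of f at a single input x.
    return jnp.trace(jacrev(f)(x))

# Map the per-example function over axis 0 of the batch, then compile once.
batched_trace_jac_f = jit(vmap(trace_jac_f, in_axes=0))

X = jnp.array([[2.0, 2.0, 2.0],
               [3.0, 3.0, 3.0]])
out = batched_trace_jac_f(X)  # one trace per row of X
print(out)
```

On the side question in the snippet above: calling an already-jitted function inside another jitted function is allowed in JAX, and the outer jit simply traces through the inner one. Applying jit once at the outermost level, as in this sketch, is usually sufficient.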

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
marikgoldstein commented, Feb 3, 2019

Oh no! Actually, I’ve never worked in industry, so I didn’t realize the “dev timeline” phrasing had such connotations, sorry! I just meant to ask whether it was something you had particular plans for 😃

Yes, your completion of my code example was correct. I will be sure to post a fully-reproducible example next time.

Thanks so much for pursuing this case! Your message above about handling batching clarifies some things in the code for me, and I’ll try to understand it a bit more.

More generally, I’m looking forward to reading the code base more in the next couple of weeks and contributing where possible. I’m doing grad school with a PL advisor and stat/ML advisor, and was thinking JAX is a good place to find some cool research problems!

1 reaction
mattjj commented, Feb 3, 2019

#312 should cover all the cases for batching lax.select, including this issue, and we can probably merge it in the next hour. That’s a pretty good dev timeline!

(I should add another commit or two with more test cases, and maybe more efficient transpose-avoiding lowerings for special cases, like the special case we had covered before.)
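As a rough check of the scenario from the issue, the sketch below reruns the batched trace-of-Jacobian computation for the ReLU network. It assumes a recent JAX release in which the relevant batching rule exists; it has not been checked against the JAX version from the time of the issue, and it makes no claim about how the fix was implemented. The third batch row and the names `nn`, `trace_jac_nn`, and `batched_trace_jac_nn` are additions for illustration.

```python
import jax.numpy as jnp
from jax import jit, jacrev, vmap

W = jnp.eye(3)
b = jnp.zeros(3)

def relu(x):
    return jnp.maximum(0.0, x)

def nn(x):
    return relu(jnp.dot(W, x) + b)

def trace_jac_nn(x):
    # Trace of the Jacobian of the ReLU network at a single input x.
    return jnp.trace(jacrev(nn)(x))

# Batched version: vmap over the leading axis, compiled with jit.
batched_trace_jac_nn = jit(vmap(trace_jac_nn))

X = jnp.array([[2.0, 2.0, 2.0],
               [3.0, 3.0, 3.0],
               [-1.0, 2.0, -3.0]])  # extra row with inactive units (illustrative)
traces = batched_trace_jac_nn(X)

# Cross-check against the unbatched function applied row by row.
unbatched = jnp.stack([trace_jac_nn(x) for x in X])
print(traces, unbatched)
```

With W set to the identity, each diagonal Jacobian entry is 1 where the corresponding input is positive and 0 where it is negative, so the trace counts the active units in each row.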
