dev timeline for more batched ops
Hi JAX Team,
I’m interested in using your library, e.g. to use gradients to optimize functions whose terms include functions g() of jacobians of functions f(). There’s no good way in PyTorch to do this for f() with output dim > 1, and certainly no good way to do this in a batched setting.
In JAX, I’m able to very easily get e.g. the trace of the Jacobian of a plain ReLU neural net. So cool! Things break down when I try to use vmap to get a batched version of the computation. Do you have a rough guess at the timeline for when I might be able to vmap in this kind of setting? Note: I’ve successfully used JAX to vmap functions g() of Jacobians of functions f() for simple choices of f().
import jax.numpy as np
from jax import jit, jacrev, vmap

W = np.eye(3)
b = np.zeros(3)

def relu(x):
    return np.maximum(0, x)

def NN(x):
    return relu(np.dot(W, x) + b)

Jx_NN = jit(jacrev(NN))

# By the way, is there a recommended way to compose already JIT'ed functions?
def trace_Jx_NN(x):
    J = Jx_NN(x)
    return np.trace(J)

x1 = np.array([2.0, 2.0, 2.0])
x2 = np.array([3.0, 3.0, 3.0])
X = np.vstack((x1, x2))

# this works
print(trace_Jx_NN(x1))

# this line executes
batched_trace_Jx_NN = jit(vmap(trace_Jx_NN, in_axes=0))

# raises "NotImplementedError # TODO(schsam, mattjj): Handle more cases."
print(batched_trace_Jx_NN(X))
Thanks! mark
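For anyone landing on this issue later, here is a minimal sketch of the same computation, assuming a JAX version in which the missing batching rule has since been implemented. It uses the `jnp` alias instead of `np` and applies `jit` once at the outermost level; both are choices made for this sketch, not recommendations taken from the thread. Since W is the identity and both inputs are strictly positive, the Jacobian of NN at each input should be the 3x3 identity, so each trace should be 3.

```python
import jax.numpy as jnp
from jax import jit, jacrev, vmap

W = jnp.eye(3)
b = jnp.zeros(3)

def relu(x):
    return jnp.maximum(0, x)

def NN(x):
    return relu(jnp.dot(W, x) + b)

def trace_Jx_NN(x):
    # Jacobian of NN with respect to a single (unbatched) input, then its trace.
    return jnp.trace(jacrev(NN)(x))

# vmap maps the per-example function over the leading axis of X;
# jit is applied once, to the whole batched function.
batched_trace_Jx_NN = jit(vmap(trace_Jx_NN, in_axes=0))

X = jnp.array([[2.0, 2.0, 2.0],
               [3.0, 3.0, 3.0]])
traces = batched_trace_Jx_NN(X)
print(traces)
```

The per-example function is written as if no batch dimension existed, and `vmap` adds the batch axis afterwards, which is the pattern the original post was aiming for.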
Issue Analytics
- State:
- Created 5 years ago
- Comments:5 (4 by maintainers)
Oh no! Actually, I’ve never worked in industry, so I didn’t realize the “dev timeline” phrasing had such connotations, sorry! I just meant to ask, is it something you had particular plans about 😃
Yes, your completion of my code example was correct. I will be sure to post a fully-reproducible example next time.
Thanks so much for pursuing this case! Your message above about handling batching clarifies some things in the code for me, and I’ll try to understand it a bit more.
More generally, I’m looking forward to reading the code base more in the next couple of weeks and contributing where possible. I’m doing grad school with a PL advisor and stat/ML advisor, and was thinking JAX is a good place to find some cool research problems!
#312 should cover all the cases for batching lax.select, including this issue, and we can probably merge it in the next hour. That’s a pretty good dev timeline! (I should add another commit or two with more test cases, and maybe more efficient transpose-avoiding lowerings for special cases, like the special case we had covered before.)
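To illustrate what "batching lax.select" can mean in practice, here is a small sketch, assuming a recent JAX release. It maps `lax.select` over a batched predicate while the two operands stay unbatched. The array values and the `in_axes` choice are made up for this example, and the thread does not spell out exactly which cases #312 handled, so this is only one plausible instance.

```python
import jax.numpy as jnp
from jax import lax, vmap

# Batched predicate: one row of booleans per example.
pred = jnp.array([[True, False, True],
                  [False, False, True]])

# Unbatched operands, shared across every example in the batch.
on_true = jnp.array([1.0, 2.0, 3.0])
on_false = jnp.zeros(3)

# Map over axis 0 of pred only; None marks an argument as not batched.
batched_select = vmap(lax.select, in_axes=(0, None, None))
out = batched_select(pred, on_true, on_false)
print(out)
```

Each output row picks from `on_true` where that row of `pred` is True and from `on_false` elsewhere. A ReLU's derivative involves this kind of elementwise choice, which is presumably why the Jacobian example above ran into the select batching rule.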