Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging third-party libraries. It collects links to all the places you might be looking while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

jax.autograd very slow, and jax.numpy slower than standard numpy

See original GitHub issue

I am not sure if this is the correct forum for this issue, so please correct me if I am wrong. I just started using jax in order to take advantage of its GPU-backed autograd capability. However, for my function(s), the gradient is either extremely slow to calculate or, worse, crashes the kernel (always the case with jit on). Moreover, even the original scalar function (and the functions it calls) is many times slower with jax.numpy than with standard numpy. Using jit significantly speeds up the scalar function, but it does not affect the performance of the gradient.

Can someone please point out any significant bottlenecks in my code, perhaps constructs that are unsuited to jax? The code is given below. The scalar function is bendE__triangles(), which calls bendEdensity__triangles() and getCurvatures(); the latter contains all the complicated operations. The inputs are all standard numpy ndarrays.

import jax.numpy as np
from jax import grad, vmap, jit
from jax.config import config
config.update("jax_enable_x64", True)
@jit 
def getCurvatures(xyz, triSimplicesInterior, triNL, W):
    ## initialise arrays
    xyzTri = np.zeros((6,3), dtype=np.float64) 
    zProjected = np.zeros(6, dtype=np.float64) 
    centroid = np.zeros(3, dtype=np.float64) 
    aVector = np.zeros(6, dtype=np.float64) 
    curvVector = []
    
    for t, triangle in enumerate(triSimplicesInterior):

        p1, p2, p3 = triangle[0], triangle[1], triangle[2]
        
        x1, y1, z1 = xyz[p1][0], xyz[p1][1], xyz[p1][2]
        x2, y2, z2 = xyz[p2][0], xyz[p2][1], xyz[p2][2]
        x3, y3, z3 = xyz[p3][0], xyz[p3][1], xyz[p3][2]
        
        N1 = np.cross(xyz[p2]-xyz[p1] , xyz[p3]-xyz[p2])
        n1 = N1 / np.linalg.norm(N1)
  
        xyzTri = np.vstack((xyz[triangle], xyz[triNL[t]])) 
        
        centroid = np.array([(x1+x2+x3)/3, (y1+y2+y3)/3, (z1+z2+z3)/3])
        zProjected = np.dot(xyzTri-centroid, n1)

        aVector = np.dot(W[t], zProjected)

        c11 = 2*aVector[3]
        c12 = aVector[4]
        c22 = 2*aVector[5]
        curvVector.append(np.array([c11,c12,c22]))
    
    curvVector = np.array(curvVector)
    return curvVector

@jit
def bendEdensity__triangles(xyz0, triSimplicesInterior, triNL, W, triAreasInterior):
    
    xyz = np.reshape(xyz0, (xyz0.size//3, 3))

    curvMatrix = getCurvatures(xyz, triSimplicesInterior, triNL, W)

    c11, c22 = 0., 0.
    energy = 0.
    
    bendEdensity = []
    for t in range(curvMatrix.shape[0]):
        c11, c22 = curvMatrix[t][0], curvMatrix[t][2]
        bendEdensity.append(0.5*(c11+c22)**2) # = 1/2(TraceC)**2 
    
    bendEdensity = np.array(bendEdensity)
    return bendEdensity

@jit
def bendE__triangles(xyz0, triSimplicesInterior, triNL, W, triAreasInterior):
    
    xyz = np.reshape(xyz0, (xyz0.size//3, 3))

    bendEdensity = bendEdensity__triangles(xyz, triSimplicesInterior, triNL, W, triAreasInterior)
    energy = np.sum(bendEdensity * triAreasInterior) # = Sum( 1/2(TraceC)**2 * area) 

    return energy


## scalar energy function
def bendE(xyz):
    return bendE__triangles(xyz, triSimplicesInterior, triNLinterior, W0, triAreasInterior)

## gradient using jax.grad
der_energy = grad(bendE)

print(bendE(xyz0)) # sped up by using jit
print(der_energy(xyz0)) # either takes a long time, or kills the kernel
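
A likely source of the slowdown, consistent with the maintainer suggestions below, is that JAX traces Python-level for loops by unrolling them: every iteration of the loop in getCurvatures() adds its own operations to the traced program, so both jit compilation time and the size of the gradient computation grow with the number of triangles. A minimal sketch (not from the issue) makes the unrolling visible with jax.make_jaxpr:

import jax.numpy as jnp
from jax import make_jaxpr

def loop_sum(x):
    total = 0.0
    for i in range(x.shape[0]):  # unrolled during tracing, one step per element
        total = total + x[i]
    return total

print(make_jaxpr(loop_sum)(jnp.arange(5.0)))
# the printed jaxpr contains five separate index-and-add steps; a larger
# input yields a proportionally larger program, and its gradient larger still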

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 2
  • Comments: 13 (6 by maintainers)

Top GitHub Comments

1 reaction
sschoenholz commented, Sep 10, 2019

You can even drop the in_axes, vmap(partialFunc)(triSimplicesInterior, triNLinterior).
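
In other words, in_axes=(0, 0) is already vmap's default, so the explicit and bare forms below agree. This sketch uses a toy stand-in for partialFunc, whose real definition never appears in the thread:

import jax.numpy as jnp
from jax import vmap

def partialFunc(tri, nl):  # toy stand-in; the thread's real per-triangle function is not shown
    return jnp.sum(tri) + jnp.sum(nl)

triSimplicesInterior = jnp.zeros((10, 3), dtype=jnp.int32)
triNLinterior = jnp.zeros((10, 3), dtype=jnp.int32)

out_explicit = vmap(partialFunc, in_axes=(0, 0))(triSimplicesInterior, triNLinterior)
out_default = vmap(partialFunc)(triSimplicesInterior, triNLinterior)  # identical result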

1 reaction
jekbradbury commented, Sep 10, 2019

[partialFunc(x, y) for x,y in zip(triSimplicesInterior, triNLinterior)] corresponds to vmap(partialFunc, in_axes=(0, 0))(triSimplicesInterior, triNLinterior).
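
Applied to this issue, that correspondence means the Python loop in getCurvatures() can be replaced by a function that computes the curvature of one triangle, vmapped over all triangles. The sketch below is one way to write it, assuming W[t] is a (6, 6) matrix and triNL[t] holds three neighbour-vertex indices (both consistent with the code above); curvature_one is a hypothetical name, not from the thread:

import jax.numpy as jnp
from jax import vmap, jit

def curvature_one(xyz, W_t, triangle, neighbors):
    # unit normal of the triangle, from two edge vectors
    n = jnp.cross(xyz[triangle[1]] - xyz[triangle[0]],
                  xyz[triangle[2]] - xyz[triangle[1]])
    n = n / jnp.linalg.norm(n)
    # the 3 triangle vertices plus the 3 neighbour vertices, as in xyzTri
    pts = jnp.concatenate([xyz[triangle], xyz[neighbors]])  # (6, 3)
    centroid = jnp.mean(xyz[triangle], axis=0)
    z_projected = jnp.dot(pts - centroid, n)                 # (6,)
    a = jnp.dot(W_t, z_projected)
    return jnp.array([2 * a[3], a[4], 2 * a[5]])             # c11, c12, c22

# xyz is shared across triangles (in_axes=None); the rest map over axis 0
getCurvatures_vmapped = jit(vmap(curvature_one, in_axes=(None, 0, 0, 0)))
# usage: curvMatrix = getCurvatures_vmapped(xyz, W, triSimplicesInterior, triNL)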

Read more comments on GitHub >

Top Results From Across the Web

Why is this function slower in JAX vs numpy? - Stack Overflow
I tested the problem with perfplot across a range of problem sizes. Result: jax is ever so slightly faster.

JAX Quickstart - JAX documentation - Read the Docs
What's new is that JAX uses XLA to compile and run your NumPy code on accelerators, ... That's slower because it has to...

JAX - The Examples Book
JAX is a Google research project built upon native Python and NumPy functions to improve machine learning research. The official JAX page describes...

[D] Why Learn Jax? : r/MachineLearning - Reddit
The main reason it's starting to gain traction is because you can often replace numpy with jax.numpy and expect auto-differentiation for all ...

JAX is a simplified TensorFlow-style library that combines Autograd and XLA
What's new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood...
