jax.autograd very slow, and jax.numpy slower than standard numpy
I am not sure if this is the correct forum for this issue, so please correct me if I am wrong. I just started using JAX in order to take advantage of its GPU-backed autograd capability. However, for my function(s), the gradient is either extremely slow to calculate or, worse, makes the kernel crash (always the case with `jit` on). Moreover, even the original scalar function (and the functions used within it) is many times slower with `jax.numpy` than with standard NumPy. Using `jit` significantly speeds up the scalar function, but it does not improve the performance of the gradient.
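One thing worth ruling out before comparing speeds: without `jit`, each `jax.numpy` call is dispatched as a separate operation, which typically carries more per-call overhead than NumPy on small arrays, and a jitted function pays for tracing and XLA compilation on its first call. JAX also returns results asynchronously, so a fair timing has to call `block_until_ready()`. A minimal sketch of such a measurement, where `f` is a hypothetical stand-in workload rather than the code in this issue:

```python
import time

import numpy as onp
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in workload: a few element-wise ops and a reduction.
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)


f_jit = jax.jit(f)
x = jnp.asarray(onp.random.default_rng(0).random(1000))


def timed(fn, *args):
    """Return (result, elapsed seconds), waiting for JAX's async dispatch."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # without this, mostly dispatch time is measured
    return out, time.perf_counter() - start


_, first = timed(f_jit, x)   # includes tracing + XLA compilation
_, second = timed(f_jit, x)  # reuses the cached executable for this shape/dtype
_, eager = timed(f, x)       # op-by-op dispatch, no jit
print(f"jit first: {first:.4f}s  jit cached: {second:.4f}s  eager: {eager:.4f}s")
```

Comparing the cached call, not the first one, against NumPy gives the steady-state picture; the compile cost is paid again whenever input shapes or dtypes change.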
Can someone please point out any significant bottlenecks in my code, perhaps things that are unsuited to JAX? The code is given below. The scalar function is `bendE__triangles()`, which calls `bendEdensity__triangles()`, which in turn calls `getCurvatures()`. This last function contains all the complicated operations. The inputs are all standard NumPy ndarrays.
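For reference, this is the scalar energy as read off the code (an interpretation, not stated by the author). For interior triangle $t$ with vertex indices $p_1, p_2, p_3$, let $\mathbf{y}_{t,0..2}$ be its three vertices and $\mathbf{y}_{t,3..5}$ the three points indexed by `triNL[t]` (together, the rows of `xyzTri`); components of $\mathbf{a}_t$ are zero-based as in the code, and $A_t$ is the triangle area from `triAreasInterior`:

```latex
\begin{aligned}
\mathbf{n}_t &= \frac{(\mathbf{x}_{p_2}-\mathbf{x}_{p_1})\times(\mathbf{x}_{p_3}-\mathbf{x}_{p_2})}
                     {\lVert(\mathbf{x}_{p_2}-\mathbf{x}_{p_1})\times(\mathbf{x}_{p_3}-\mathbf{x}_{p_2})\rVert},
\qquad \bar{\mathbf{x}}_t = \tfrac{1}{3}\,(\mathbf{x}_{p_1}+\mathbf{x}_{p_2}+\mathbf{x}_{p_3}),\\
z_{t,k} &= (\mathbf{y}_{t,k}-\bar{\mathbf{x}}_t)\cdot\mathbf{n}_t,\quad k=0,\dots,5,
\qquad \mathbf{a}_t = W_t\,\mathbf{z}_t,\\
c_{11}^{(t)} &= 2a_{t,3},\qquad c_{12}^{(t)} = a_{t,4},\qquad c_{22}^{(t)} = 2a_{t,5},\\
E(\mathbf{x}) &= \sum_t A_t\cdot\tfrac{1}{2}\bigl(c_{11}^{(t)}+c_{22}^{(t)}\bigr)^2 .
\end{aligned}
```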
```python
import jax
import jax.numpy as np
from jax import grad, vmap, jit

# jax.config.update replaces the older `from jax.config import config` import
jax.config.update("jax_enable_x64", True)

@jit
def getCurvatures(xyz, triSimplicesInterior, triNL, W):
    ## initialise arrays
    xyzTri = np.zeros((6,3), dtype=np.float64)
    zProjected = np.zeros(6, dtype=np.float64)
    centroid = np.zeros(3, dtype=np.float64)
    aVector = np.zeros(6, dtype=np.float64)
    curvVector = []
    # one iteration per interior triangle
    for t, triangle in enumerate(triSimplicesInterior):
        p1, p2, p3 = triangle[0], triangle[1], triangle[2]
        x1, y1, z1 = xyz[p1][0], xyz[p1][1], xyz[p1][2]
        x2, y2, z2 = xyz[p2][0], xyz[p2][1], xyz[p2][2]
        x3, y3, z3 = xyz[p3][0], xyz[p3][1], xyz[p3][2]
        N1 = np.cross(xyz[p2]-xyz[p1], xyz[p3]-xyz[p2])
        n1 = N1 / np.linalg.norm(N1)  # unit normal of the triangle
        xyzTri = np.vstack((xyz[triangle], xyz[triNL[t]]))
        centroid = np.array([(x1+x2+x3)/3, (y1+y2+y3)/3, (z1+z2+z3)/3])
        zProjected = np.dot(xyzTri-centroid, n1)
        aVector = np.dot(W[t], zProjected)
        c11 = 2*aVector[3]
        c12 = aVector[4]
        c22 = 2*aVector[5]
        curvVector.append(np.array([c11,c12,c22]))
    curvVector = np.array(curvVector)
    return curvVector

@jit
def bendEdensity__triangles(xyz0, triSimplicesInterior, triNL, W, triAreasInterior):
    xyz = np.reshape(xyz0, (xyz0.size//3, 3))
    curvMatrix = getCurvatures(xyz, triSimplicesInterior, triNL, W)
    c11, c22 = 0., 0.
    energy = 0.
    bendEdensity = []
    for t in range(curvMatrix.shape[0]):
        c11, c22 = curvMatrix[t][0], curvMatrix[t][2]
        bendEdensity.append(0.5*(c11+c22)**2)  # = 1/2(TraceC)**2
    bendEdensity = np.array(bendEdensity)
    return bendEdensity

@jit
def bendE__triangles(xyz0, triSimplicesInterior, triNL, W, triAreasInterior):
    xyz = np.reshape(xyz0, (xyz0.size//3, 3))
    bendEdensity = bendEdensity__triangles(xyz, triSimplicesInterior, triNL, W, triAreasInterior)
    energy = np.sum(bendEdensity * triAreasInterior)  # = Sum( 1/2(TraceC)**2 * area)
    return energy

## scalar energy function
# triSimplicesInterior, triNLinterior, W0, triAreasInterior and xyz0 are NumPy input arrays (not shown)
def bendE(xyz):
    return bendE__triangles(xyz, triSimplicesInterior, triNLinterior, W0, triAreasInterior)

## gradient using jax.grad
der_energy = grad(bendE)
print(bendE(xyz0))  # sped up by using jit
print(der_energy(xyz0))  # either takes a long time, or kills the kernel
```
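A likely reason the gradient is so expensive is that the Python `for` loops over triangles are unrolled when traced under `jit` and `grad`, so the traced program, its compile time, and its memory use all grow with the number of triangles. A minimal sketch of the same curvature and energy computation written with `jax.vmap` instead, assuming `triSimplicesInterior` and `triNLinterior` are integer arrays of shape (T, 3) and `W0` has shape (T, 6, 6) (inferred from the `(6, 3)` shape of `xyzTri` and the use of `aVector[3]` to `aVector[5]`); `curvature_one`, `bend_energy`, and the random toy mesh are hypothetical names and data, not the original code:

```python
import numpy as onp
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def curvature_one(xyz, tri, nl, w):
    """(c11, c12, c22) for one triangle.

    xyz: (N, 3) vertex positions; tri: 3 vertex indices; nl: 3 neighbour
    indices (assumed); w: the triangle's fitting matrix (assumed (6, 6)).
    """
    p1, p2, p3 = xyz[tri[0]], xyz[tri[1]], xyz[tri[2]]
    N = jnp.cross(p2 - p1, p3 - p2)
    n = N / jnp.linalg.norm(N)                          # unit normal
    pts = jnp.concatenate([xyz[tri], xyz[nl]], axis=0)  # (6, 3), like xyzTri
    centroid = (p1 + p2 + p3) / 3.0
    z = (pts - centroid) @ n                            # heights along n, (6,)
    a = w @ z
    return jnp.array([2 * a[3], a[4], 2 * a[5]])


# Map over axis 0 of tri, nl and w; xyz is shared by every triangle.
curvatures = jax.vmap(curvature_one, in_axes=(None, 0, 0, 0))


@jax.jit
def bend_energy(xyz_flat, tris, nls, W, areas):
    xyz = xyz_flat.reshape(-1, 3)
    c = curvatures(xyz, tris, nls, W)          # (T, 3)
    density = 0.5 * (c[:, 0] + c[:, 2]) ** 2   # 1/2 (trace C)^2 per triangle
    return jnp.sum(density * areas)


bend_energy_grad = jax.jit(jax.grad(bend_energy))  # gradient w.r.t. xyz_flat

# Hypothetical toy mesh: 10 random points, 2 triangles.
rng = onp.random.default_rng(0)
xyz0 = rng.random((10, 3)).ravel()
tris = onp.array([[0, 1, 2], [3, 4, 5]])
nls = onp.array([[3, 4, 5], [6, 7, 8]])
W = rng.random((2, 6, 6))
areas = onp.ones(2)

E = bend_energy(xyz0, tris, nls, W, areas)
g = bend_energy_grad(xyz0, tris, nls, W, areas)
print(E, g.shape)
```

Here `vmap` traces `curvature_one` once, so the size of the compiled program no longer depends on the number of triangles; the density loop is replaced by array slicing for the same reason.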
`[partialFunc(x, y) for x, y in zip(triSimplicesInterior, triNLinterior)]` corresponds to `vmap(partialFunc, in_axes=(0, 0))(triSimplicesInterior, triNLinterior)`. You can even drop the `in_axes`, since `vmap` maps over axis 0 of every argument by default: `vmap(partialFunc)(triSimplicesInterior, triNLinterior)`.
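A small check of the correspondence described above, using a hypothetical `partialFunc` and toy inputs (the actual `partialFunc` from the thread is not shown). The list comprehension, the explicit `in_axes=(0, 0)`, and the default `in_axes` all give the same values; `vmap` simply returns them stacked into one array instead of a Python list:

```python
import jax
import jax.numpy as jnp


def partialFunc(x, y):
    # Hypothetical per-row function: x and y are single rows of shape (3,).
    return jnp.dot(x, y) + jnp.sum(x)


a = jnp.arange(12.0).reshape(4, 3)  # 4 rows, e.g. 4 triangles
b = jnp.ones((4, 3))

looped = jnp.stack([partialFunc(x, y) for x, y in zip(a, b)])
explicit = jax.vmap(partialFunc, in_axes=(0, 0))(a, b)
default = jax.vmap(partialFunc)(a, b)  # in_axes=0 for every argument by default
print(looped, explicit, default)
```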