
Studies on autodiff with JAX


I’ll collect here some experiments I’m currently doing with JAX, inspired by #439.

First, I tried computing the plain gradient of an ordinary linear problem:

import numpy as np
from skfem import *
from jax import jit, grad
from jax.numpy import vectorize


# Dirichlet energy density; its gradient w.r.t. (du0, du1) is simply (du0, du1)
def energy(du0, du1):
    return .5 * (du0 ** 2 + du1 ** 2)

# gradient of the energy density, vectorized over the quadrature points
@jit
def jacf(du0, du1):
    return vectorize(grad(energy, (0, 1)))(du0, du1)

m = MeshTri()
basis = InteriorBasis(m, ElementTriP1())

# weak form built from the autodiff'd gradient
@BilinearForm
def bilinf1(u, v, w):
    from skfem.helpers import dot, grad
    Ju = jacf(*u.grad)
    return dot(Ju, grad(v))

# reference: the standard Laplace weak form
@BilinearForm
def bilinf2(u, v, w):
    from skfem.helpers import dot, grad
    return dot(grad(u), grad(v))

A = bilinf1.assemble(basis)
B = bilinf2.assemble(basis)

This gives

In [1]: (A - B).todense()                                                       
Out[1]: 
matrix([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
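
i.e. the autodiff-assembled matrix coincides with the hand-written Laplace matrix, as expected: the gradient of .5 * (du0 ** 2 + du1 ** 2) with respect to (du0, du1) is just (du0, du1). A quick standalone check of that identity (my own sketch, not part of the original issue) could look like:

import jax.numpy as jnp
from jax import grad
from jax.numpy import vectorize

def energy(du0, du1):
    return .5 * (du0 ** 2 + du1 ** 2)

# gradient w.r.t. both arguments, vectorized over arrays of evaluation points
jacf = vectorize(grad(energy, (0, 1)))

du0 = jnp.array([0.1, 0.2, 0.3])
du1 = jnp.array([1.0, 2.0, 3.0])
g0, g1 = jacf(du0, du1)
print(jnp.allclose(g0, du0), jnp.allclose(g1, du1))  # True True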

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

2 reactions
kinnala commented, Jul 21, 2020

Now I was able to implement ex10.py using JAX:

# Minimal surface problem

import numpy as np
from skfem import *
from jax import jit, grad
from jax.numpy import vectorize
import jax.numpy as jnp


# area element of the graph of u: sqrt(1 + |grad u|^2)
def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# gradient of F, evaluated componentwise over the quadrature points
def jac_eval(du0, du1):
    out = np.zeros((2,) + du0.shape)
    for i in range(2):
        out[i] = vectorize(grad(F, i))(du0, du1)
    return out

# Hessian of F, evaluated entrywise over the quadrature points
def hess_eval(du0, du1):
    out = np.zeros((2, 2) + du0.shape)
    for i in range(2):
        for j in range(2):
            out[i, j] = vectorize(grad(grad(F, i), j))(du0, du1)
    return out

m = MeshTri()
m.refine(5)
basis = InteriorBasis(m, ElementTriP1())

# residual: -(grad F(grad u), grad v)
@LinearForm
def linf_rhs(v, w):
    from skfem.helpers import dot, grad
    Jw = np.array(jac_eval(*w['prev'].grad))
    return -dot(Jw, grad(v))

# linearization: (hess F(grad u) grad du, grad v)
@BilinearForm
def bilinf_hess(u, v, w):
    from skfem.helpers import ddot, grad, prod
    Hw = np.array(hess_eval(*w['prev'].grad))
    return ddot(Hw, prod(grad(u), grad(v)))

x = np.zeros(basis.N)
I = m.interior_nodes()
D = m.boundary_nodes()
x[D] = np.sin(np.pi * m.p[0, D])  # Dirichlet boundary condition

# damped Newton iteration
for itr in range(100):
    prev = basis.interpolate(x)
    K = asm(bilinf_hess, basis, prev=prev)
    f = asm(linf_rhs, basis, prev=prev)
    x_prev = x.copy()
    x += .7 * solve(*condense(K, f, I=I))
    if np.linalg.norm(x - x_prev) < 1e-7:
        break
    print(np.linalg.norm(x - x_prev))

from skfem.visuals.matplotlib import plot3, show
plot3(m, x)
show()

There is some overhead from autodiff, which I think could be reduced by learning JAX better, but it’s not huge. The result is correct.
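
One plausible way to shave that overhead (my own sketch, not something tested in the thread) is to build and jit the derivative evaluators once, outside the forms, and stack the components with jnp instead of filling a NumPy buffer on every call:

import jax.numpy as jnp
from jax import jit, grad
from jax.numpy import vectorize

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

# derivative evaluators are constructed once at import time
_jac = [vectorize(grad(F, i)) for i in range(2)]
_hess = [[vectorize(grad(grad(F, i), j)) for j in range(2)] for i in range(2)]

@jit
def jac_eval(du0, du1):
    # shape (2,) + du0.shape
    return jnp.stack([f(du0, du1) for f in _jac])

@jit
def hess_eval(du0, du1):
    # shape (2, 2) + du0.shape
    return jnp.stack([jnp.stack([f(du0, du1) for f in row]) for row in _hess])

These return JAX arrays rather than NumPy arrays, so the forms would still wrap the results in np.array(...) as above.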

1 reaction
kinnala commented, Jul 21, 2020

I think this is more a question of “how to use JAX”. Trying to use vectorize on the Hessian results in

ValueError: output shape (2,) does not match core dimensions () on vectorized function with excluded=frozenset() and signature=None

I was trying to run

import numpy as np
from skfem import *
from jax import jit, grad, jacrev, jacfwd
from jax.numpy import vectorize
import jax.numpy as jnp

def F(du0, du1):
    return jnp.sqrt(1 + du0 ** 2 + du1 ** 2)

vectorize(jacfwd(jacrev(F, (0, 1)), (0, 1)))(0.1, 0.1)
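
The error appears because jacfwd(jacrev(F, (0, 1)), (0, 1)) returns a nested tuple of second derivatives per point, while vectorize with signature=None expects each output to be a scalar. One workaround (my own suggestion, not from the thread) is to sidestep jnp.vectorize entirely: write F over a single stacked argument and batch jax.hessian with jax.vmap, which handles non-scalar outputs without a signature:

import jax
import jax.numpy as jnp

def F(du):
    # du has shape (2,): the two components of the gradient
    return jnp.sqrt(1. + du[0] ** 2 + du[1] ** 2)

# Hessian at a single point: shape (2, 2)
hessF = jax.hessian(F)

# batch over a flat list of points: (n, 2) -> (n, 2, 2)
hessF_batch = jax.vmap(hessF)

du = jnp.array([[0.1, 0.1],
                [0.2, 0.3]])
print(hessF_batch(du).shape)  # (2, 2, 2)

In the forms above, w['prev'].grad has shape (2, nelems, nqp), so one would presumably reshape it to (nelems * nqp, 2) before the vmap and move the axes back afterwards.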