Studies on autodiff with JAX
I’ll collect here some tests I’m currently doing with JAX, inspired by #439.
First, I tried computing the plain gradient of the energy for a standard linear problem (the Laplace operator):
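```python
import numpy as np
import jax
from jax import jit
from jax.numpy import vectorize
from skfem import *
from skfem.helpers import dot, grad  # skfem's grad; JAX's is used as jax.grad


def energy(du0, du1):
    return .5 * (du0 ** 2 + du1 ** 2)


@jit
def jacf(du0, du1):
    # differentiate the energy density w.r.t. both gradient components,
    # vectorized over all elements and quadrature points
    return vectorize(jax.grad(energy, (0, 1)))(du0, du1)


m = MeshTri()
basis = InteriorBasis(m, ElementTriP1())  # newer scikit-fem versions call this Basis


@BilinearForm
def bilinf1(u, v, w):
    # flux obtained by autodiff of the energy density
    Ju = jacf(*u.grad)
    return dot(Ju, grad(v))


@BilinearForm
def bilinf2(u, v, w):
    # hand-written Laplace form for comparison
    return dot(grad(u), grad(v))


A = bilinf1.assemble(basis)
B = bilinf2.assemble(basis)
```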
This gives a zero difference, i.e. both forms assemble to the same matrix:

```
In [1]: (A - B).todense()
Out[1]:
matrix([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
```
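Why the difference vanishes: for the quadratic energy ½(du0² + du1²), the gradient with respect to (du0, du1) is (du0, du1) itself, so `jacf` just reproduces `u.grad` and the two forms coincide. A minimal sketch checking this without scikit-fem, on random arrays standing in for quadrature-point data (the (3, 4) shape is an arbitrary assumption, not skfem's actual layout):

```python
import jax
import jax.numpy as jnp
import numpy as np


def energy(du0, du1):
    return 0.5 * (du0 ** 2 + du1 ** 2)


@jax.jit
def jacf(du0, du1):
    # jax.grad with argnums (0, 1) returns the tuple (dE/d du0, dE/d du1);
    # jnp.vectorize maps the scalar function over every array entry
    return jnp.vectorize(jax.grad(energy, (0, 1)))(du0, du1)


rng = np.random.default_rng(0)
du0 = rng.standard_normal((3, 4))  # hypothetical (elements, quadrature points)
du1 = rng.standard_normal((3, 4))

J0, J1 = jacf(du0, du1)
print(np.allclose(J0, du0), np.allclose(J1, du1))  # True True
```

For this linear problem autodiff therefore adds nothing new; it only becomes useful once the energy is nonlinear in the gradient.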
Top GitHub Comments
Now I was able to implement ex10.py using JAX:
Autodiff adds some overhead, which I think can be reduced once I know JAX better, but it is not huge. The result is correct.
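As a rough illustration of what autodiff has to supply for a nonlinear problem, assume ex10.py is scikit-fem's minimal-surface example with energy density √(1 + |∇u|²). Newton's method then needs the flux (gradient of the energy density with respect to ∇u) and its Jacobian (the Hessian). The sketch below is not the code from this thread: `flux` and `tangent` are hypothetical names, and the (2, nelems, nqp) layout with components first is an assumption about how `u.grad` is stored. It vectorizes over all quadrature points with `jnp.vectorize` and an explicit core-dimension signature, then checks against the closed-form derivatives.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # use float64 like NumPy


def energy(du):
    # minimal-surface energy density; du is one gradient vector, shape (2,)
    return jnp.sqrt(1.0 + jnp.dot(du, du))


# the signature tells vectorize that each call consumes a length-n vector and
# returns a vector (gradient) or an n-by-n matrix (Hessian)
_flux = jnp.vectorize(jax.grad(energy), signature="(n)->(n)")
_tangent = jnp.vectorize(jax.hessian(energy), signature="(n)->(n,n)")


@jax.jit
def flux(dU):
    # dU: (2, nelems, nqp); move components last, then back to the front
    return jnp.moveaxis(_flux(jnp.moveaxis(dU, 0, -1)), -1, 0)


@jax.jit
def tangent(dU):
    # result has shape (2, 2, nelems, nqp)
    H = _tangent(jnp.moveaxis(dU, 0, -1))
    return jnp.moveaxis(H, (-2, -1), (0, 1))


rng = np.random.default_rng(0)
dU = rng.standard_normal((2, 5, 3))  # hypothetical: 5 elements, 3 quadrature points

F = flux(dU)
K = tangent(dU)

# closed form: flux = du / s, Hessian = I / s - du du^T / s^3, s = sqrt(1 + |du|^2)
s = np.sqrt(1.0 + (dU ** 2).sum(axis=0))
K_exact = (np.eye(2)[:, :, None, None] / s
           - np.einsum("i...,j...->ij...", dU, dU) / s ** 3)
print(np.allclose(F, dU / s))    # True
print(np.allclose(K, K_exact))   # True
```

The signature-based `jnp.vectorize` saves writing nested `jax.vmap` calls over the element and quadrature-point axes by hand; the two approaches should give the same result.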
I think this is more of a “how to use JAX” issue. Trying to use vectorize on the Hessian results in
I was trying to run