2nd order derivatives want an insane amount of RAM
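```python
import jax
import jax.numpy as np  # jax.numpy, aliased as np in this report
import numpy as onp

def E_fn(conf):
    # conf: (N, 3) array of particle coordinates
    ri = np.expand_dims(conf, 0)            # (1, N, 3)
    rj = np.expand_dims(conf, 1)            # (N, 1, 3)
    dxdydz = np.power(ri - rj, 2)           # (N, N, 3) squared displacements
    dij = np.sqrt(np.sum(dxdydz, axis=-1))  # (N, N) pairwise distances
    # Note: the i == j diagonal evaluates sqrt(0), whose derivative is infinite,
    # so the derivatives below contain NaNs. This does not affect the memory issue.
    return np.sum(dij)

dE_dx_fn = jax.jacrev(E_fn, argnums=(0,))        # gradient, reverse mode
d2E_dx2_fn = jax.jacfwd(dE_dx_fn, argnums=(0,))  # Hessian, forward-over-reverse
d2E_dx2_fn(onp.random.rand(2483, 3))
```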
Results in:
RuntimeError: Resource exhausted: Out of memory while trying to allocate 551102853132 bytes.
This happens on both CPU and GPU.
There’s no reason this calculation should require 551 GB of RAM. The explicit Hessian is “only” (2483 × 3)² × 4 bytes ≈ 222 MB.
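The reported allocation size can be reproduced exactly with a short calculation. A plausible reading (an inference, not something stated in the thread) is that `jax.jacfwd` pushes one tangent per input coordinate through the gradient, batched with `vmap`, so an (N, N, 3) intermediate such as `dxdydz` gains a leading batch axis of length 3N. The sketch below assumes float32, which is JAX's default dtype.

```python
N = 2483    # particles in the repro
D = 3 * N   # flattened coordinates: 7449
F32 = 4     # bytes per float32 (JAX's default dtype)

# Dense Hessian: D x D float32 entries
hessian_bytes = D * D * F32

# Assumed cause: one (N, N, 3) float32 intermediate per tangent,
# with all D tangents batched together by jacfwd's vmap
batched_intermediate_bytes = D * N * N * 3 * F32

print(hessian_bytes)               # 221950404     (~222 MB)
print(batched_intermediate_bytes)  # 551102853132  (~551 GB, the size in the error)
```

The second number matches the failed allocation byte for byte, which points at the batched intermediates rather than the Hessian itself.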
Issue analytics: created 4 years ago; 8 comments (8 by maintainers).

Actually, under jit we never run things op by op (though a version of JAX from 2017 did things that way). The values propagated through the Python code are abstract ones, each representing a set of possible arrays, and they aren’t backed by any float values. Abstract values store only shape and dtype, so they take very little memory and cause no FLOPs.
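The same abstract tracing can be observed directly with `jax.eval_shape`, which runs a function on shape/dtype placeholders and reports only the output's shape and dtype. This is a minimal sketch, not taken from the thread. It uses `jax.jacfwd(jax.jacrev(...))` without `argnums` so the output is a single array, and assumes float32 input.

```python
import jax
import jax.numpy as jnp

def E_fn(conf):
    # Same energy as the repro: sum of all pairwise distances
    ri = jnp.expand_dims(conf, 0)
    rj = jnp.expand_dims(conf, 1)
    dij = jnp.sqrt(jnp.sum((ri - rj) ** 2, axis=-1))
    return jnp.sum(dij)

hess_fn = jax.jacfwd(jax.jacrev(E_fn))

# Abstract input: only a shape and a dtype, no data behind it
spec = jax.ShapeDtypeStruct((2483, 3), jnp.float32)

# Traces hess_fn on abstract values only; no FLOPs run and no
# device buffers are allocated for the intermediates
out = jax.eval_shape(hess_fn, spec)
print(out.shape, out.dtype)  # (2483, 3, 2483, 3) float32
```

Tracing succeeds almost instantly even at full size; the memory is only requested once XLA actually executes the compiled computation.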
The more I think about it, the more I think XLA should be doing this rewrite optimization for us (and by extension jit should be doing it for you). I’ll raise it with the XLA folks and see what they think. In the worst case, if for some unforeseen reason XLA can’t do this optimization, this is a place where custom ops could help.
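Independently of any XLA-side rewrite, one common workaround is to trade time for memory: build the Hessian one Hessian-vector product at a time with `jax.lax.map` (a sequential loop) instead of `vmap`-ing all tangents at once as `jacfwd` does. This is a sketch under that assumption, not the approach the maintainers proposed; `hessian_by_rows` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def hessian_by_rows(f, x):
    """Dense Hessian of a scalar function f at x, built one row at a time.

    jax.jacfwd(jax.jacrev(f)) pushes all x.size basis tangents through the
    gradient together (via vmap), so intermediates gain a batch axis of
    length x.size. jax.lax.map loops over the basis instead, which is slower
    but only needs one tangent's intermediates live per iteration.
    """
    grad_f = jax.grad(f)
    n = x.size

    def hvp_row(k):
        # k-th standard basis vector, shaped like x
        v = jnp.zeros(n, dtype=x.dtype).at[k].set(1.0).reshape(x.shape)
        # Forward-over-reverse Hessian-vector product H @ e_k; this equals
        # row k of H when f is twice continuously differentiable (H symmetric)
        return jax.jvp(grad_f, (x,), (v,))[1]

    rows = jax.lax.map(hvp_row, jnp.arange(n))  # shape (n,) + x.shape
    return rows.reshape(x.shape + x.shape)

# Small smooth test function (hypothetical, for checking against jax.hessian)
def f(x):
    return jnp.sum(jnp.sin(x) ** 2) + jnp.sum(x[:, 0] * x[:, 1])

x = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))
H_rows = hessian_by_rows(f, x)
H_ref = jax.hessian(f)(x)
```

In the repro's terms this would be `hessian_by_rows(E_fn, conf)`. Each of the 7449 iterations should then need roughly one (2483, 2483, 3) float32 intermediate (about 74 MB) at a time rather than all of them at once, though XLA's actual peak usage may differ.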
This is an extremely simplified (but representative) repro of a much more complicated set of non-truncated potentials that can’t always be reduced using the law of cosines (though it’s a nice trick for computing the Gramian).
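For the simplified repro only, the law-of-cosines trick, |rᵢ − rⱼ|² = |rᵢ|² + |rⱼ|² − 2 rᵢ·rⱼ, computes squared distances from the (N, N) Gramian and avoids the (N, N, 3) displacement array. The sketch below compares both forms; the function names are hypothetical. Under `jacfwd` the (N, N) intermediates still gain the 3N batch axis, so this should shrink the largest batched intermediates by roughly a factor of 3 rather than remove the blowup.

```python
import jax
import jax.numpy as jnp

def sq_dists_broadcast(conf):
    # Builds an (N, N, 3) intermediate, as in the original E_fn
    diff = conf[None, :, :] - conf[:, None, :]
    return jnp.sum(diff ** 2, axis=-1)

def sq_dists_gram(conf):
    # Law of cosines via the Gramian; intermediates are at most (N, N)
    gram = conf @ conf.T           # (N, N): ri . rj
    sq_norms = jnp.diag(gram)      # (N,): |ri|^2
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * gram
    # Rounding can push entries near zero slightly negative; clamp them
    return jnp.maximum(d2, 0.0)

x = jax.random.uniform(jax.random.PRNGKey(0), (50, 3))
d2_gram = sq_dists_gram(x)
d2_bcast = sq_dists_broadcast(x)
```

The clamp matters if a `sqrt` follows, since a tiny negative value would otherwise turn into a NaN.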