
2nd-order derivatives want an insane amount of RAM

import jax
import jax.numpy as np
import numpy as onp

def E_fn(conf):
    # Pairwise displacements via broadcasting: shape (N, N, 3) for N atoms.
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dxdydz = np.power(ri - rj, 2)
    # Pairwise distance matrix, shape (N, N).
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return np.sum(dij)

# Forward-over-reverse composition for the Hessian of E_fn.
dE_dx_fn = jax.jacrev(E_fn, argnums=(0,))
d2E_dx2_fn = jax.jacfwd(dE_dx_fn, argnums=(0,))

d2E_dx2_fn(onp.random.rand(2483, 3))

Results in:

RuntimeError: Resource exhausted: Out of memory while trying to allocate 551102853132 bytes.

This happens on both CPU and GPU.

There’s no reason this calculation should require 551 GB of RAM. The explicit Hessian is “only” (2483 × 3)^2 × 4 bytes ≈ 221 MB.
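
As a quick back-of-the-envelope check (plain NumPy, float32 assumed), the allocation in the traceback turns out to be exactly 2483 dense Hessians’ worth of memory:

import numpy as onp

n_atoms, dim = 2483, 3
n = n_atoms * dim                       # 7449 flattened coordinates

# Dense Hessian: n x n entries at 4 bytes each (float32)
hessian_bytes = n * n * 4
print(hessian_bytes)                    # 221950404, i.e. ~221 MB

# The allocation from the traceback
reported = 551102853132
print(reported // hessian_bytes)        # 2483 -- one Hessian per atom

That factor of N hints that some intermediate of the jacfwd-over-jacrev composition is being materialized once per atom rather than fused away.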

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, May 30, 2019

Actually under a jit we never run things op-by-op (though a version of Jax from 2017 did things that way); the values propagated through the python code are abstract ones, basically representing a set of possible arrays, and those aren’t backed by any float values. The abstract values just store shape and dtype, and so they don’t take much memory or cause any FLOPs.
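
As a small illustration of that abstract evaluation (using today’s jax.eval_shape API, which may postdate this 2019 thread):

import jax
import jax.numpy as np

def E_fn(conf):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = np.sqrt(np.sum(np.power(ri - rj, 2), axis=-1))
    return np.sum(dij)

# Trace E_fn with an abstract value: only shape and dtype propagate, so
# no (2483, 3) array is ever allocated and no FLOPs are executed.
abstract_conf = jax.ShapeDtypeStruct((2483, 3), np.float32)
print(jax.eval_shape(E_fn, abstract_conf))   # ShapeDtypeStruct(shape=(), dtype=float32)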

The more I think about it, the more I think XLA should be doing this rewrite optimization for us (and by extension jit should be doing it for you). I’ll raise it with the XLA folks and see what they think. In the worst case, if for some unforeseen reason XLA can’t do this optimization, this is a place where custom ops could help.
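
In the meantime, one thing worth trying (no guarantee it avoids the 551 GB temporary, since that depends on XLA’s rewrites) is to hand the whole composition to XLA via jit, reusing E_fn from the repro above:

import jax
import numpy as onp

# Compile the forward-over-reverse Hessian as one XLA computation so the
# compiler can, in principle, fuse intermediates instead of materializing
# them op-by-op.
d2E_dx2_jit = jax.jit(jax.jacfwd(jax.jacrev(E_fn)))
hess = d2E_dx2_jit(onp.random.rand(2483, 3))   # shape (2483, 3, 2483, 3)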

0 reactions
proteneer commented, May 30, 2019

This is an extremely simplified (but representative) repro of a much more complicated set of non-truncated potentials that can’t always be reduced using the law of cosines (though it’s a nice trick for computing the Gramian).
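
For reference, here is a hypothetical sketch of that Gram-matrix (law-of-cosines) form; it avoids building the (N, N, 3) displacement tensor, though as the comment notes the real potentials can’t always be rewritten this way:

import jax.numpy as np

def E_fn_gram(conf):
    # d_ij^2 = |r_i|^2 + |r_j|^2 - 2 * (r_i . r_j)
    sq = np.sum(conf ** 2, axis=-1)                # (N,) squared norms
    gram = conf @ conf.T                           # (N, N) Gramian
    d2 = sq[:, None] + sq[None, :] - 2.0 * gram    # squared distances
    # Clamp tiny negatives from floating-point cancellation before the sqrt.
    return np.sum(np.sqrt(np.maximum(d2, 0.0)))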
