
Hessian diagonal computation


Hello,

I’ve been playing around with autograd and I’m having a blast. However, I’m having some difficulty extracting the diagonal of the Hessian.

This is my current code:

from autograd import hessian
import autograd.numpy as np

# one-hot "logits": 7 samples, 5 classes
y_pred = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1]
], dtype=float)

# per-class weights
weights = np.array([1, 1, 1, 1, 1], dtype=float)

def softmax(x, axis=1):
    # note: no max-subtraction, so this can overflow for large inputs
    z = np.exp(x)
    return z / np.sum(z, axis=axis, keepdims=True)

def loss(y_pred):
    y_true = np.array([
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1]
    ], dtype=float)

    # normalize each class column so every class contributes equally
    ys = np.sum(y_true, axis=0)
    y_true = y_true / ys
    # weighted negative log-likelihood per class
    ln_p = np.log(softmax(y_pred))
    wll = np.sum(y_true * ln_p, axis=0)
    return -np.dot(weights, wll)

hess = hessian(loss)(y_pred)

I understand that hessian is simply jacobian applied twice, and that hess is an n × p × n × p tensor (here 7 × 5 × 7 × 5). I can extract the diagonal manually and obtain my expected output, which is:

[[0.24090069 0.12669198 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.24090069 0.12669198 0.12669198 0.12669198]
 [0.12669198 0.12669198 0.12669198 0.24090069 0.12669198]
 [0.12669198 0.12669198 0.24090069 0.12669198 0.12669198]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]
 [0.04223066 0.04223066 0.04223066 0.04223066 0.08030023]]

I’ve checked this numerically and it’s fine. The problem is that this still requires computing the full Hessian before accessing the diagonal part, which is really expensive. Is there any better way to proceed? I think this is a common use case in machine learning optimization that deserves a dedicated convenience function.
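
For reference, here is a sketch of the manual extraction (assuming the n × p × n × p layout described above; the einsum pattern is my own, not from the original post):

# diag[i, j] = hess[i, j, i, j]; the full Hessian is still materialized first
diag = np.einsum('ijij->ij', hess)
print(diag)  # reproduces the 7 x 5 expected output above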

Issue Analytics

  • State: open
  • Created: 5 years ago
  • Comments: 8 (2 by maintainers)

Top GitHub Comments

2 reactions
dougalm commented, Dec 1, 2018

Unfortunately, I don’t think it’s possible to compute the diagonal of the Hessian other than by taking N separate Hessian-vector products (one per input dimension), which costs about the same as instantiating the full Hessian and then taking the diagonal. People resort to all sorts of tricks to estimate the trace of the Hessian (e.g. https://arxiv.org/abs/1802.03451) precisely because it’s expensive to evaluate the diagonal.
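
To make that concrete, here is a minimal sketch of the N-Hessian-vector-product approach; the helper names hvp and hessian_diag are illustrative, not part of autograd's API, and it assumes a scalar-valued f of a single array argument:

from autograd import grad
import autograd.numpy as np

def hessian_diag(f, x):
    # H(x) @ v is the gradient of <grad f(x), v>, so each diagonal entry
    # costs one gradient-of-gradient pass: N passes in total.
    g = grad(f)
    hvp = lambda z, v: grad(lambda z_: np.sum(g(z_) * v))(z)
    flat = np.ravel(x)
    diag = np.zeros(flat.size)
    for i in range(flat.size):
        e = np.zeros(flat.size)
        e[i] = 1.0
        # extract entry i of H @ e_i, i.e. H[i, i]
        diag[i] = np.ravel(hvp(x, np.reshape(e, x.shape)))[i]
    return np.reshape(diag, x.shape)

# e.g. hessian_diag(loss, y_pred) should reproduce the 7 x 5 matrix above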

Autograd’s elementwise_grad has a very important caveat: it only applies to functions for which the Jacobian is diagonal. All it does is take a vector-Jacobian product with a vector of ones, which gives you the sum over the rows of the Jacobian (one entry per input). If the Jacobian is diagonal, that is the same thing as the diagonal of the Jacobian.
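
A quick illustration of the caveat, with a made-up function whose Jacobian is not diagonal:

import autograd.numpy as np
from autograd import elementwise_grad, jacobian

# each output mixes both inputs, so the Jacobian is not diagonal
f = lambda x: np.array([x[0] * x[1], x[0] + x[1]])
x = np.array([2.0, 3.0])

print(jacobian(f)(x))          # [[3. 2.], [1. 1.]]
print(elementwise_grad(f)(x))  # [4. 3.]: the sum of the rows, not the diagonal [3. 1.]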

That caveat was in the docstring of an earlier version of elementwise_grad. But at some point we deleted the function (because it can be misleading!) and then later we reinstated it, without the docstring. I just added the caveat back in. Sorry for the confusion.

1 reaction
jermwatt commented, Dec 2, 2018

How stupid of a work-around (not invoking jacobian or elementwise_grad, so sidestepping the diagonal-Jacobian restriction and never computing the second cross-partials) would it be to loop over the input arguments and build the pure second partials one at a time using grad? Supposing the input arguments are unpacked, the j-th pure second partial is grad(grad(g, j), j).

I’m assuming pretty stupid, but for example:

from autograd import grad

# a test function of two scalar arguments
g = lambda w0, w1: 5*w0**2 + 7*w1**2 + w0*w1

# a test point
w0 = 2.0; w1 = 2.0

# construct the pure second partials one at a time
num_partials = 2  # g takes two arguments
second_pure_partials = [grad(grad(g, j), j) for j in range(num_partials)]

# evaluate the pure second partials one at a time
diag_hess = lambda w0, w1: [part(w0, w1) for part in second_pure_partials]
print('pure second derivatives at test point = ' + str(diag_hess(w0, w1)))
`pure second derivatives at test point = [10.0, 14.0]`

Compare this to the “incorrect” answer produced by composing elementwise_grad in an attempt to get at the pure partials:


from autograd import elementwise_grad

bad_hess_diag = elementwise_grad(elementwise_grad(g, (0, 1)), (0, 1))
print('"incorrect" pure second derivatives at test point = ' + str(bad_hess_diag(w0, w1)))

`"incorrect" pure second derivatives at test point = (array(11.), array(15.))`

The mismatch is exactly the caveat above: the outer elementwise_grad sums the rows of the (non-diagonal) Hessian [[10, 1], [1, 14]], giving 10 + 1 = 11 and 1 + 14 = 15 instead of the pure partials 10 and 14.

Top Results From Across the Web

  • Computing the diagonal elements of a Hessian #3801 - GitHub: “Hi all, I would like to use Jax to compute the diagonal elements of a Hessian matrix, i.e. second partial derivatives \partial y^2 …”
  • Autodiff: calculate just the diagonal of the Hessian: “Given a function f(x) from R^2 → R, is there an efficient way to calculate just the diagonal entries of the Hessian matrix …”
  • HesScale: Scalable Computation of Hessian Diagonals - arXiv: “In this paper, we develop HesScale, a scalable approach to approximating the diagonal of the Hessian matrix, to incorporate second-order …”
  • Applying a Diagonal Hessian Approximation for …: “Here, we apply the inverse of a diagonal Hessian approximation for preconditioning, which is a physically founded approach. However, its calculation is …”
  • Computing the diagonal approximation of the Hessian of the …: “Here we are going to calculate the elements along the diagonal of the Hessian matrix based on the gradient. ∂r(x,w) …”
