
scipy.optimize.minimize with method=SLSQP does not minimize successfully

See original GitHub issue

Hi, hopefully I’m actually doing something wrong with jax or scipy here, but…

import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize


def run(np):
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)  # note: defined but never passed to minimize
    return minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})


print(run(onp).x)
print(run(np).x)

This example has the following output:

Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.969820110544395e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]
/Users/kratsg/.virtualenvs/pyhf/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 848.2199096679688
            Iterations: 1
            Function evaluations: 7
            Gradient evaluations: 1
[1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]

where:

  • numpy: [0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ]
  • jax.numpy: [1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]

I do not understand (1) why the results are different and (2) why there are so few iterations in the jax.numpy case. I suspect the low iteration count is what's causing the difference in the minimization results.

NB: I got this example from the scipy docs: https://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#nelder-mead-simplex-algorithm-method-nelder-mead

/cc @lukasheinrich @matthewfeickert (affects diana-hep/pyhf#377)

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 3
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

7 reactions
mattjj commented, Jun 27, 2019

Thanks for asking this, and the beautiful runnable example!

I think the issue is just 64bit vs 32bit. JAX by default maxes out at 32bit values for ints and floats. We chose that as a default policy because a primary use case of JAX is neural network-based machine learning research. That’s different from ordinary NumPy, though, which is very happy to cast things to 64bit values. In fact, that’s why we made 32bit the system-wide maximum precision by default: because otherwise users might get annoyed by the NumPy API promoting things to 64bit values all the time, when they’re trying to stay in 32bit for neural net training!
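
Here's a minimal sketch of that default at work: even a 64-bit input gets demoted back to 32 bits once it passes through jax.numpy while the flag is off.

import numpy as onp
import jax.numpy as np

x = onp.ones(5)               # float64, like the array SLSQP hands to the objective
print(onp.sum(x ** 2).dtype)  # float64: plain NumPy stays in 64-bit
print(np.sum(x ** 2).dtype)   # float32: jax.numpy demotes on entry when x64 is off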

To enable 64bit values, I ran your script like this:

$ JAX_ENABLE_X64=1 python issue936.py
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]

When that 64bit flag is switched on, jax.numpy follows numpy's precision semantics very closely, and as you can see it makes the numerics agree here.

Another way to set that flag is by doing something like this at the top of your main .py file:

from jax.config import config
config.update("jax_enable_x64", True)
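
Applied to the original script, the important thing is that the update runs before any jax.numpy arrays are built; a minimal sketch:

from jax.config import config
config.update("jax_enable_x64", True)  # do this before creating any jax arrays

import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize
# ... rest of the script unchanged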

You can see a bit more in the short gotchas section of the readme, and in the gotchas notebook.

1 reaction
mattjj commented, Jul 11, 2019

I’m curious how we still managed to get 64-bit precision on the numpy side with 32b arrays?

My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations (namely the evaluation of the objective function) are being done in 32bit precision when using jax.numpy without enabling 64bit values.

This experiment seems like some evidence in that direction:

import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize


def run(np):
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)  # as before, defined but never passed to minimize
    result = minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})
    print(result.x.dtype)
    print(rosen(result.x).dtype)

run(onp)
run(np)

In [1]: run issue936.py
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
float64
float64
jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 848.219909668
            Iterations: 1
            Function evaluations: 7
            Gradient evaluations: 1
float64
float32  # different!
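
One more check along the same lines, as a sketch (assuming SLSQP's usual conversion of x0 to float64): printing the dtype inside the objective should show float64 in both runs, which would mean the demotion happens inside jax.numpy when the objective is evaluated.

    def rosen(x):
        print(x.dtype)  # expect float64 in both runs: SLSQP itself works in 64-bit
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)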