scipy.optimize.minimize with method=SLSQP does not minimize successfully
Hi, hopefully I’m actually doing something wrong with jax or scipy here, but…
import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize

def run(np):
    # `np` is either numpy or jax.numpy, shadowing the module import above
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)
    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)  # defined but never passed to minimize
    return minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})

print(run(onp).x)
print(run(np).x)
This example has the following output:
Optimization terminated successfully. (Exit mode 0)
Current function value: 7.969820110544395e-11
Iterations: 29
Function evaluations: 221
Gradient evaluations: 29
[0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ]
/Users/kratsg/.virtualenvs/pyhf/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully. (Exit mode 0)
Current function value: 848.2199096679688
Iterations: 1
Function evaluations: 7
Gradient evaluations: 1
[1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]
where:
- numpy:
[0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ]
- jax.numpy:
[1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]
I do not understand (1) why the results are different and (2) why there are so few iterations in the jax.numpy case; I suspect (2) is what’s causing the difference in the minimization result.
NB: I got this example from scipy docs https://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#nelder-mead-simplex-algorithm-method-nelder-mead
/cc @lukasheinrich @matthewfeickert (affects diana-hep/pyhf#377)
Thanks for asking this, and for the beautiful runnable example!
I think the issue is just 64bit vs 32bit. JAX by default maxes out at 32bit values for ints and floats. We chose that as the default policy because a primary use case of JAX is neural-network-based machine learning research. That’s different from ordinary NumPy, which is very happy to cast things to 64bit values. In fact, that’s why we made 32bit the system-wide maximum precision by default: otherwise users trying to stay in 32bit for neural net training might get annoyed by the NumPy API promoting things to 64bit values all the time!
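For instance, checking the default dtypes directly (a minimal sketch; it assumes the X64 flag discussed below is left off):

import numpy as onp
import jax.numpy as np

# NumPy promotes Python floats to 64bit by default; JAX caps at 32bit
print(onp.array([1.3, 0.7]).dtype)  # float64
print(np.array([1.3, 0.7]).dtype)   # float32 under the default 32bit policy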
To enable 64bit values, I ran your script with the JAX_ENABLE_X64 environment variable switched on, along these lines (the script name here is a placeholder):
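JAX_ENABLE_X64=1 python your_script.py  # your_script.py is a stand-in for the file above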
When that 64bit flag is switched on, jax.numpy follows numpy’s precision semantics very closely, and as you can see it makes the numerics agree here. Another way to set the flag is to put something along these lines at the top of your main .py file (this is the standard JAX mechanism, and it must run before any JAX operations):
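from jax.config import config  # standard mechanism for flipping the 64bit flag
config.update("jax_enable_x64", True)  # must execute before any other JAX work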
You can see a bit more in the short gotchas section of the readme, and in the gotchas notebook.
My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations are being done in 32bit precision (namely the evaluation of the objective function) when using jax.numpy without enabling 64bit values. This experiment seems like some evidence in that direction:
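A minimal sketch of that kind of dtype check (the xp-parameterized rosen and the float64 input are illustrative assumptions about what scipy feeds the objective internally, not the original snippet):

import numpy as onp
import jax.numpy as np

def rosen(xp, x):
    # same objective as above, parameterized over the array module
    return xp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

# per the guess above, scipy works in float64 internally and passes a float64 x
x64 = onp.array([1.3, 0.7, 0.8, 1.9, 1.2])  # float64

print(rosen(onp, x64).dtype)  # float64: numpy keeps the full precision
print(rosen(np, x64).dtype)   # float32: jax truncates when X64 is off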