Could LBFGS work with pytrees?

Similar to this previous issue (https://github.com/blackjax-devs/blackjax/issues/214), would it be possible to modify minimize_lbfgs so that it also works when the initial guess x0 is a more flexible structure, such as a tuple of arrays? That would be really handy: I could use the same logprob function both to draw samples and to find the (local) maximum of the posterior.
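A minimal sketch of the idea behind the request, relying only on jax.flatten_util.ravel_pytree (a real JAX utility). A tuple of arrays is flattened into one 1-D vector that an array-only optimizer can work with, and the returned unravel function rebuilds the tuple so the same logprob can be evaluated on it. The logprob here is a hypothetical toy density, not the issue author's model.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def logprob(params):
    # Hypothetical toy log-density over a tuple of arrays.
    loc, log_scale = params
    return -0.5 * jnp.sum(loc ** 2) - 0.5 * jnp.sum(log_scale ** 2)


# Initial guess as a pytree: a tuple holding a vector and a matrix.
x0 = (jnp.ones(2), jnp.zeros((2, 3)))

# Flatten to a single 1-D array; unravel_fn maps a flat vector back to the tuple.
x0_flat, unravel_fn = ravel_pytree(x0)


def flat_logprob(x_flat):
    # Objective an array-only optimizer could consume.
    return logprob(unravel_fn(x_flat))


print(x0_flat.shape)  # (8,)
```

Because flat_logprob simply unravels and calls logprob, the same density function can serve both a sampler that works on the pytree and an optimizer that works on the flat vector.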

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

1 reaction
ahwillia commented, Jun 24, 2022

Happy to give it a shot. I am somewhat new to JAX and just learned that jax.flatten_util.ravel_pytree exists. I'm not sure I understand the cases where vmap(unravel_fn)(...) is needed, though.
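One plausible reading of the vmap(unravel_fn)(...) case (an assumption; the thread does not spell it out) is when an optimizer returns a stack of flat vectors rather than a single one, for example a per-iteration history of positions with shape (n_steps, d). unravel_fn expects one flat vector of length d, so mapping it over the leading axis with jax.vmap produces a pytree whose leaves each gain a leading n_steps axis. The history array below is made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Pytree of parameters: a dict with a length-2 vector and a length-3 vector.
params = {"a": jnp.zeros(2), "b": jnp.zeros(3)}
flat, unravel_fn = ravel_pytree(params)  # flat has shape (5,)

# Hypothetical stacked history of 4 flat iterates, shape (4, 5).
history = jnp.stack([flat + i for i in range(4)])

# unravel_fn expects a single length-5 vector, so it cannot take `history` directly.
# vmap applies it row by row, adding a leading axis of size 4 to every leaf.
history_tree = jax.vmap(unravel_fn)(history)

print(history_tree["a"].shape)  # (4, 2)
print(history_tree["b"].shape)  # (4, 3)
```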

Is the following sketch what you two are thinking?

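from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

def minimize_lbfgs(func, x0, **kwargs):
    if isinstance(x0, jnp.ndarray):
        # x0 is already a flat array: use the existing implementation directly
        return _minimize_lbfgs(func, x0, **kwargs)
    else:
        # flatten the pytree and unflatten it inside the objective
        # (any parameters returned by _minimize_lbfgs are in the flattened space)
        x0_raveled, unravel_fn = ravel_pytree(x0)
        return _minimize_lbfgs(lambda x: func(unravel_fn(x)), x0_raveled, **kwargs)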

where the existing minimize_lbfgs is renamed to _minimize_lbfgs.
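A runnable variant of the wrapper idea, under stated assumptions: the array-only optimizer is a hypothetical stand-in (_minimize_flat) built on jax.scipy.optimize.minimize with method="BFGS", not blackjax's L-BFGS, and the objective is a negative log-density to minimize. It also illustrates one possible simplification: ravel_pytree accepts a bare array as a single-leaf pytree, so the isinstance branch may not be needed, and unraveling the solution hands the caller back the structure they passed in.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize


def _minimize_flat(fun, x0_flat):
    # Hypothetical stand-in for an array-only optimizer (BFGS here, not L-BFGS).
    return minimize(fun, x0_flat, method="BFGS")


def minimize_pytree(fun, x0):
    # Works for a bare array or any pytree: a bare array is a single-leaf pytree.
    x0_flat, unravel_fn = ravel_pytree(x0)
    result = _minimize_flat(lambda x: fun(unravel_fn(x)), x0_flat)
    # Map the flat solution back to the caller's structure.
    return unravel_fn(result.x), result


def neg_logprob(params):
    # Toy negative log-density with its minimum at a = 1, b = -2.
    a, b = params
    return 0.5 * jnp.sum((a - 1.0) ** 2) + 0.5 * jnp.sum((b + 2.0) ** 2)


x_opt, result = minimize_pytree(neg_logprob, (jnp.zeros(2), jnp.zeros(3)))
# x_opt is a tuple; its entries should be close to [1, 1] and [-2, -2, -2].
```

Returning the unraveled solution alongside the raw result keeps diagnostics such as result.success available while letting callers keep working with their original pytree.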

0 reactions
rlouf commented, Jun 29, 2022

Closing as the corresponding PR was merged. Thank you!
