0.6.3 introduced a circular dependency with orbax
Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
- Python version:
- GPU/TPU model and memory:
- CUDA version (if applicable):
Problem you have encountered:
Flax 0.6.3 added a dependency on orbax, which itself depends on flax. This is causing https://github.com/tensorflow/tfjs/issues/7159 in the TensorFlow.js repository. TFJS resolves PyPI packages using Bazel, which does not support circular dependencies.
Was this change intentional? If so, I can file a bug with rules_python instead, although the last time this kind of circular dependency issue arose, it was determined to be a bug in the downstream package. I’m not sure if that is true in this case, though.
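To make the cycle concrete, here is a minimal sketch of how a resolver that requires an acyclic graph would trip over it. Only the flax/orbax edges come from this report; the `downstream_app` entry, the `DEPS` map, and the `find_cycle` helper are hypothetical names used for illustration, not part of Bazel, rules_python, or pip.

```python
from typing import Dict, List, Optional

# Illustrative dependency map. Only the flax <-> orbax edges come from the
# issue; "downstream_app" is a hypothetical consumer of flax.
DEPS: Dict[str, List[str]] = {
    "downstream_app": ["flax"],
    "flax": ["orbax"],
    "orbax": ["flax"],
}


def find_cycle(deps: Dict[str, List[str]]) -> Optional[List[str]]:
    """Return one dependency cycle as a list of names, or None if acyclic."""
    path: List[str] = []  # packages on the current depth-first path
    done = set()          # packages whose dependencies are fully explored

    def visit(node: str) -> Optional[List[str]]:
        if node in done:
            return None
        if node in path:
            # Seeing a package that is already on the path closes a cycle.
            return path[path.index(node):] + [node]
        path.append(node)
        for child in deps.get(node, []):
            cycle = visit(child)
            if cycle:
                return cycle
        path.pop()
        done.add(node)
        return None

    for name in deps:
        cycle = visit(name)
        if cycle:
            return cycle
    return None


cycle = find_cycle(DEPS)
if cycle:
    print(" -> ".join(cycle))
```

A resolver that installs packages in topological order has no valid order for such a graph, which is consistent with the failure described above.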
What you expected to happen:
No circular dependency.
Logs, error messages, etc:
Steps to reproduce:
Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.
Top GitHub Comments
For the future, here are some ideas:
- [...] serialization.py, this makes sense for a checkpointing library. On the Flax side we add some wrappers.
- [...] flax-serialization library which both libraries can depend on to avoid the circular dependency.

We have to maintain control of traversals and flattening of nested dicts, struct dataclasses, etc., as we use those for more things than just checkpoints. One hope is to enhance JAX pytree calls with the notion of “paths” such that we could remove our hacky state-dict registry altogether, but this is still in discussion.
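For readers unfamiliar with the "paths" idea, the following is a rough pure-Python sketch of path-aware flattening of a nested dict, the kind of traversal a checkpointing layer relies on. `flatten_with_paths` and `unflatten_from_paths` are hypothetical helpers written for this illustration; they are not the Flax, Orbax, or JAX API, and the example state is made up.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def flatten_with_paths(tree: Dict[str, Any], prefix: Path = ()) -> Dict[Path, Any]:
    """Flatten a nested dict into a flat {path_tuple: leaf} mapping."""
    flat: Dict[Path, Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-dicts, extending the path with this key.
            flat.update(flatten_with_paths(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_from_paths(flat: Dict[Path, Any]) -> Dict[str, Any]:
    """Rebuild the nested dict from a {path_tuple: leaf} mapping."""
    tree: Dict[str, Any] = {}
    for path, leaf in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = leaf
    return tree


# Made-up training state, used only to exercise the helpers.
state = {"params": {"dense": {"kernel": [1.0, 2.0], "bias": [0.0]}}, "step": 3}
flat = flatten_with_paths(state)
restored = unflatten_from_paths(flat)
```

Because every leaf carries its full path, a serializer can store and restore leaves individually without a separate registry describing the container types. Note that this sketch drops empty sub-dicts, which a real implementation would need to handle.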