[rllib] tensorflow models don't support object oriented treatment of variables
Describe the problem
With TF 2.0 on the horizon, tensorflow models in rllib should not exclusively depend on variable scopes to reuse variables. Similarly, graph collections should be avoided in favor of explicitly passing variables instances. Both variable scopes and collections will be deprecated in TF 2.0.
The new (and, imo, better) way of handling variable sharing is to reuse tf.Variable instances and to organize variables as attributes of some container, i.e., modules, similarly to how pytorch handles variables. Currently, a significant portion of the TF 1.0 backend has already shifted to this approach, implementing all the necessary tools to maintain support for variable_scopes and tf.get_variable. The public API, e.g., tf.function/tf.defun and tf.wrap_function, might be sufficient for ray.rllib to support both approaches until support for TF 1.0 is dropped.
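A minimal sketch of this object-oriented style, assuming nothing about rllib's own classes (SharedMLP and its shapes are made up for illustration): variables are created once as attributes of a tf.Module, and "reuse" simply means calling the same Python object again.

```python
import tensorflow as tf

class SharedMLP(tf.Module):
    """Variables live as attributes; "sharing" means reusing the same object."""

    def __init__(self, in_dim=8, hidden=64, out=4):
        super().__init__()
        # Created once in the constructor -- no tf.get_variable, no variable_scope.
        self.w1 = tf.Variable(tf.random.normal([in_dim, hidden]), name="w1")
        self.b1 = tf.Variable(tf.zeros([hidden]), name="b1")
        self.w2 = tf.Variable(tf.random.normal([hidden, out]), name="w2")
        self.b2 = tf.Variable(tf.zeros([out]), name="b2")

    def __call__(self, x):
        h = tf.nn.relu(tf.matmul(x, self.w1) + self.b1)
        return tf.matmul(h, self.w2) + self.b2

mlp = SharedMLP()
policy_out = mlp(tf.zeros([32, 8]))   # first call
target_out = mlp(tf.ones([32, 8]))    # second call reuses the exact same variables
print(len(mlp.trainable_variables))   # 4 -- tf.Module tracks attribute variables automatically
```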
This could probably be hacked together with relatively few lines of code, but, with a proper refactoring, I suspect that a lot of the distinctions/restrictions caused by tf model vs. pytorch model can be avoided.
On top of improving the pytorch/tf compatibility and future-proofing for TF 2.0, properly implementing this will likely lead to simpler and more readable code which would benefit new users and make implementation of future algorithms less prone to bugs.
Things that would have to be done (in no particular order):
- Stop the use of variable scope to reuse variables across algorithms.
- Refactor the low-level API to allow complete abstraction of the TF/pytorch backend, i.e., reduce the number of TF or pytorch specific (sub-)classes to one for each.
- Object-oriented use of rllib.models.model.Model across all implementations.
- (maybe?) Handle TF variable serialization/checkpointing using the tensorflow tracking API instead of the in-house solution. See the relevant TF tools for handling variables, tf.Module and AutoTrackable (formerly AutoCheckpointable); a sketch follows this list.
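For the checkpointing item, a minimal sketch of what relying on the tracking API could look like (TinyModel and the directory path are illustrative placeholders, not rllib code):

```python
import tensorflow as tf

class TinyModel(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([8, 4]), name="w")

    def __call__(self, x):
        return tf.matmul(x, self.w)

# Any tf.Module (or AutoTrackable) exposes its variables to the tracking machinery,
# so checkpointing is just wiring the object into tf.train.Checkpoint rather than
# maintaining an in-house registry keyed by scope names.
model = TinyModel()
optimizer = tf.keras.optimizers.Adam(1e-3)

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, directory="/tmp/oo_vars_demo", max_to_keep=3)

save_path = manager.save()                # writes model + optimizer state
ckpt.restore(manager.latest_checkpoint)   # restores by object structure, not by variable name
```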
If there is any interest in this, I am happy to write up a more detailed design proposal and help with its possible implementation.
Top GitHub Comments
In the long run, I would strongly encourage a full switch to the TF 2.0 way. It will be easy to support TF 1.x with a TF 2.0 design, but I don’t see how the other way around will be possible. That said, I agree with you that a full switch is not necessary yet; now is probably a good time to start thinking about what ray + TF 2.0 will look like, to make sure that 1) new classes play nice together and 2) switching requires as little refactoring as possible. In the long run, I would expect the code base to be simpler to maintain (eager/TF2 encourages better coding practices) and more likely to be compatible with future TF 2.x features.
All that being said, there is no need to make the jump all at once. I am happy focusing on solving this issue (i.e., OO tensorflow variables) before considering broader TF 2.0 support.
I’m thinking along the lines of the following. In a perfect world, we can have all of the tensor operations depend on operator overloading, allowing the agent/policy/loss code to take TF tensors (from a TF model) or PyTorch tensors (from a pytorch model) as input.
Of course, the reality is that the API differences make this impossible, but with a TF 2.0 approach, these differences are going to be minor and mostly limited to different naming conventions and variable treatment. This opens up a lot of opportunities for code re-use that wouldn’t normally be considered. However, the devil is in the details. Better code re-use could be impossible or add an undesirable amount of complexity, so maybe I won’t have anything better to propose after diving in!
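To make the operator-overloading idea concrete, here is a toy sketch (the function name is made up and this is not rllib code): the elementwise part of a loss can be written once, with only the reduction differing by name between backends.

```python
import tensorflow as tf
import torch

def squared_td_error(q_values, targets, weights):
    """Elementwise math uses only overloaded operators, so it is backend-agnostic:
    the arguments may all be tf.Tensor or all be torch.Tensor."""
    delta = q_values - targets
    return weights * delta * delta

# The remaining backend-specific step is the reduction, which differs only in name.
tf_loss = tf.reduce_mean(squared_td_error(tf.ones(4), tf.zeros(4), tf.ones(4)))
torch_loss = torch.mean(squared_td_error(torch.ones(4), torch.zeros(4), torch.ones(4)))
```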
Naturally, I don’t intend to implement anything in that regard until I have a design worked out that you all feel is worth the effort. (As a side note, with rllib growing in complexity and the release of TF 2.0 close by, I feel like this is a good time to make sure rllib’s foundation is as flexible and easy to maintain as possible.)
@ericl That is essentially correct! The only potential issue that comes with dropping variable scopes is ensuring proper checkpointing behavior, but all the necessary code to handle that is already in tensorflow (easiest done by sub-classing the appropriate class).
The concept of a Model that implements a model.forward() API is somewhat tangential to the handling of variables. Until a light API exists for handling models and their variable creation, like an updated (and therefore simplified) sonnet 1.0, the best option might be to define the simplest API that fits rllib’s needs, leaving the custom model sub-class the responsibility of creating variables only when desired. Consider this bad custom model class which improperly saves variables as attributes:
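A minimal sketch of what such a class might look like, assuming a hypothetical CustomFCModel that simply holds its layers (and thus its variables) as attributes; this is illustrative and not rllib's actual Model API:

```python
import tensorflow as tf

class CustomFCModel(tf.Module):
    """The "bad" pattern: variables are simply saved as attributes of the instance."""

    def __init__(self, num_outputs, hidden=256):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(hidden, activation="relu")
        self.logits = tf.keras.layers.Dense(num_outputs)

    def forward(self, obs):
        # No variable_scope, no tf.get_variable: calling forward() repeatedly
        # reuses the layers (and their variables) held by this instance.
        return self.logits(self.hidden(obs))

logits = CustomFCModel(num_outputs=4).forward(tf.zeros([1, 16]))
```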
A TF 1.0 veteran might feel powerless seeing this, but for any other programmer familiar with object-oriented programming, this is completely intuitive behavior. While some helpful super-class which manages this automatically would be nice, I don’t think a feature-complete, complex Model API is required.
I’ve been swamped recently, but I’ll have a bit more time in the coming weeks to write up a concrete design we can iterate on. Hopefully what I just haphazardly tried to say will be clear then!