Save/restore for TensorGraph
In the past, our models have generally had a persistent Session object containing all the variables for the model. They automatically wrote out checkpoints during training, and provided a restore() method to load the variables from the most recent checkpoint into memory.
TensorGraph works differently. It creates a new Session every time you call fit() or predict(), loads the variables from the latest checkpoint, then throws away the Session before returning. I find this behavior has a couple of problems.
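A minimal pure-Python sketch of this per-call pattern makes the cost visible. It does not use TensorFlow: PerCallModel, its JSON "checkpoint" file, and the single linear weight are hypothetical stand-ins for creating a Session and restoring variables, chosen only so the number of disk reads can be counted.

```python
import json
import os
import tempfile


class PerCallModel:
    """Hypothetical toy model mimicking the per-call behavior: every predict()
    builds fresh state, reloads the checkpoint from disk, and discards that
    state before returning. This is not TensorGraph's actual API."""

    def __init__(self, model_dir):
        self.path = os.path.join(model_dir, "checkpoint.json")
        self.loads = 0  # number of times weights were read from disk

    def save(self, weights):
        with open(self.path, "w") as f:
            json.dump(weights, f)

    def predict(self, x):
        # Stand-in for "create a new Session and restore the latest checkpoint".
        with open(self.path) as f:
            weights = json.load(f)
        self.loads += 1
        # `weights` goes out of scope on return, like the discarded Session.
        return weights["w"] * x + weights["b"]


model = PerCallModel(tempfile.mkdtemp())
model.save({"w": 2.0, "b": 1.0})
outputs = [model.predict(x) for x in range(1000)]
print(model.loads)  # one disk read per prediction
```

Running a thousand predictions costs a thousand checkpoint reads, even though the weights never change between calls.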
First, it’s really slow. If you want to run prediction thousands of times, it will load the variables from disk thousands of times. This is why, for both RL and MAML, I’ve basically had to subvert the design of TensorGraph: I use it to build the TensorFlow graph, but when I need to do any calculations I ignore the corresponding parts of TensorGraph, pull out the internal fields I need, and do the calculations directly.
Second, it is inflexible, since it requires that prediction always be based on the variables in the latest checkpoint file and no others. MAML is a good example of where this causes problems. It uses TensorGraph to define the model, but the actual optimization is done with a different optimizer and a different loss function. That produces a model (which gets saved to disk) that is designed to be easy to train but isn’t optimized for any particular task. To use it for prediction, you first do a few steps of gradient descent to produce a fine-tuned version of the model (which does not get saved to disk), and then predict with that. But I can’t use any of TensorGraph’s prediction methods for it, because they would throw away the tuned variables and replace them with the generic ones from disk.
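The MAML conflict can be reduced to a small pure-Python sketch, assuming a one-parameter model y = w * x and a JSON file standing in for the checkpoint. The functions predict_from_checkpoint, fine_tune, and predict_with_weights are hypothetical illustrations, not TensorGraph or MAML code; fine_tune is plain gradient descent on mean squared error, used here only as a stand-in for the MAML adaptation steps.

```python
import json
import os
import tempfile


def predict_from_checkpoint(path, x):
    # Hypothetical per-call predict: always restores weights from disk, so any
    # in-memory (e.g. fine-tuned) weights the caller holds are ignored.
    with open(path) as f:
        weights = json.load(f)
    return weights["w"] * x


def fine_tune(weights, xs, ys, lr=0.1, steps=5):
    # A few steps of gradient descent on mean squared error for y = w * x.
    w = weights["w"]
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
    return {"w": w}


def predict_with_weights(weights, x):
    # What the MAML use case needs: predict with weights held in memory.
    return weights["w"] * x


path = os.path.join(tempfile.mkdtemp(), "checkpoint.json")
generic = {"w": 1.0}  # meta-learned weights, saved to disk
with open(path, "w") as f:
    json.dump(generic, f)

# Task: y = 3x. The tuned weights are never written to disk.
tuned = fine_tune(generic, xs=[1.0, 2.0], ys=[3.0, 6.0])
print(predict_from_checkpoint(path, 2.0))  # 2.0, from the generic w = 1.0
print(predict_with_weights(tuned, 2.0))    # uses the fine-tuned w instead
```

The checkpoint-based call silently answers with the generic model, which is exactly why the tuned variables cannot go through a predict method that always reloads from disk.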
I think we should consider changing this behavior.
- Created 6 years ago
Top GitHub Comments
I’d suggest something roughly identical to what’s currently in A3C. When you first create the object, it creates a persistent Session. There’s a restore() method to load the most recent checkpoint from disk. I also included a restore argument to fit() that you can use to have it automatically load the latest checkpoint and continue training from where it left off.

FWIW, the first thing I did was check whether 1. is possible.
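The A3C-style proposal (a persistent Session, an explicit restore(), and a restore argument to fit()) can be sketched in plain Python without TensorFlow. PersistentModel, its JSON checkpoint, and the one-parameter linear model are hypothetical; only the names restore() and fit(restore=...) come from the comment, and this is not the actual A3C implementation.

```python
import json
import os
import tempfile


class PersistentModel:
    """Hypothetical sketch of the proposed design: state lives as long as the
    object (standing in for a persistent Session), checkpoints are written
    during training, and restore() reloads from disk only when asked."""

    def __init__(self, model_dir):
        self.path = os.path.join(model_dir, "checkpoint.json")
        self.weights = {"w": 0.0}  # created once and kept in memory
        self.loads = 0             # number of checkpoint reads

    def save_checkpoint(self):
        with open(self.path, "w") as f:
            json.dump(self.weights, f)

    def restore(self):
        # Load the most recent checkpoint into memory, explicitly.
        with open(self.path) as f:
            self.weights = json.load(f)
        self.loads += 1

    def fit(self, xs, ys, steps=10, lr=0.05, restore=False):
        # Optionally continue training from the latest checkpoint.
        if restore:
            self.restore()
        w = self.weights["w"]
        for _ in range(steps):
            grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
            w -= lr * grad
        self.weights = {"w": w}
        self.save_checkpoint()

    def predict(self, x):
        # Uses whatever weights are in memory; no disk access.
        return self.weights["w"] * x


model = PersistentModel(tempfile.mkdtemp())
model.fit([1.0, 2.0], [3.0, 6.0], steps=50)
preds = [model.predict(x) for x in range(1000)]
print(model.loads)          # 0: prediction never touches the checkpoint

model.weights = {"w": 10.0}  # e.g. weights adjusted in memory, as in MAML
print(model.predict(1.0))    # 10.0: in-memory weights are respected
model.restore()              # explicitly return to the saved checkpoint
```

Because disk reads happen only in restore(), repeated prediction costs nothing extra, and callers that modify weights in memory keep full control over which weights are used.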