Adding GAN support to TensorGraph
I think our current model with TensorGraph can't nicely support GANs. To train a GAN, you need to train both a discriminator D and a generator G. The training of these two models is interleaved, with each updated for a few minibatches at a time. While updating the generator, the discriminator weights should be frozen; while updating the discriminator, the generator weights need to be frozen.
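For concreteness, here is a minimal sketch of that alternating pattern in raw TensorFlow 1.x (the API TensorGraph wraps). The toy generator, discriminator, losses, and data are purely illustrative, not DeepChem code; the key point is that passing `var_list` to each optimizer is what keeps the other model's weights frozen.

```python
import numpy as np
import tensorflow as tf

noise = tf.placeholder(tf.float32, [None, 8])
real = tf.placeholder(tf.float32, [None, 1])

with tf.variable_scope('generator'):
    fake = tf.layers.dense(noise, 1)                               # toy G
d_real = tf.layers.dense(real, 1, name='discriminator')           # toy D on real data
d_fake = tf.layers.dense(fake, 1, name='discriminator', reuse=True)  # D on G's output

xent = tf.nn.sigmoid_cross_entropy_with_logits
disc_loss = tf.reduce_mean(xent(labels=tf.ones_like(d_real), logits=d_real)
                           + xent(labels=tf.zeros_like(d_fake), logits=d_fake))
gen_loss = tf.reduce_mean(xent(labels=tf.ones_like(d_fake), logits=d_fake))

# var_list restricts each train op to one set of weights, which is how the
# other half of the model stays frozen during that update.
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
gen_op = tf.train.AdamOptimizer(1e-4).minimize(gen_loss, var_list=gen_vars)
disc_op = tf.train.AdamOptimizer(1e-4).minimize(disc_loss, var_list=disc_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        z = np.random.normal(size=(64, 8)).astype(np.float32)
        x = np.random.normal(loc=4.0, size=(64, 1)).astype(np.float32)
        for _ in range(5):                       # a few discriminator minibatches...
            sess.run(disc_op, feed_dict={noise: z, real: x})
        sess.run(gen_op, feed_dict={noise: z})   # ...then a generator update
```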
I have an idea for a new TensorGraph API to support this sort of training. We add a new `layers` kwarg to `fit`:
```python
model.fit(feed_generator, layers=None)
```
The idea is that `fit` can be instructed to train only some layers at a time. Note that there's now a semantic change: `Layer` is responsible for tracking the variables it introduces. `layer.variables` should be a list of all trainable variables created by that `Layer`.
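Under that contract, the GAN loop above could be written purely in terms of the proposed kwarg. A hypothetical usage sketch (none of this exists yet; `disc_layers`, `gen_layers`, and the feed generators are assumed names):

```python
# Hypothetical usage of the proposed `layers` kwarg; nothing here is existing
# DeepChem API. disc_layers/gen_layers are the Layer objects making up D and G.
for step in range(num_outer_steps):
    # Update D for a few minibatches; only variables from disc_layers move.
    model.fit(disc_feed_generator, layers=disc_layers)
    # Update G for a few minibatches; only variables from gen_layers move.
    model.fit(gen_feed_generator, layers=gen_layers)

# fit() could then build its train_op from exactly these variables:
disc_vars = [v for layer in disc_layers for v in layer.variables]
```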
@peastman and @lilleswing Does this basic design make sense?
Hat tip to @enf for explaining GAN training to me in detail!
Top GitHub Comments
How about this as a candidate API:
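Roughly along these lines; this is a sketch only, in which `create_submodel`, the `submodel` kwarg, and the optimizer are assumed names rather than settled API:

```python
# Sketch of the candidate API; all names here are assumptions, not settled API.
gen_submodel = model.create_submodel(layers=gen_layers, loss=gen_loss,
                                     optimizer=Adam(learning_rate=1e-4))
disc_submodel = model.create_submodel(layers=disc_layers, loss=disc_loss,
                                      optimizer=Adam(learning_rate=1e-4))

# Fitting is unchanged except that each submodel trains with its own train_op.
for step in range(num_outer_steps):
    model.fit(disc_feed_generator, submodel=disc_submodel)
    model.fit(gen_feed_generator, submodel=gen_submodel)
```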
So a submodel specifies a list of layers, a loss function, and an optimizer. When you perform fitting, you can optionally specify a submodel to do the fitting on. It then uses a different `train_op`, but everything else works exactly the same.

In the abstract of SeqGAN, they also say the following, which could be very important in my view:
I think the model that should be used really depends on the kind of data we want to generate. Maybe it would be better to focus on a specific chemical application first?