Question about step execution time
Hi,
I was wondering whether it is intended that the diffopt.step(loss)
call takes increasingly long to execute as the number of saved states grows?
The step time should be constant, since we back-propagate through only one state, while computing the meta-gradient at the end of the inner loop should take longer as the number of saved states increases, right?
t = time.process_time()
diffopt.step(loss)  # (opt.zero_grad, loss.backward, opt.step)
print(len(fmodel._fast_params), "step", time.process_time() - t)
(...)
t = time.process_time()
val_loss.backward()
print("meta", time.process_time() - t)
This code gives me:
2 step 0.513532734
3 step 1.2003996619999988
4 step 1.4545214800000004
5 step 1.6974909480000004
6 step 1.9400910080000013
7 step 2.1659202289999975
meta 1.4035290689999975
2 step 0.4082054819999996
3 step 1.236462700999997
4 step 1.4650358509999997
5 step 1.702944763999998
6 step 1.9239114150000027
meta 1.1481397730000005
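For what it's worth, the timings above suggest the per-step time grows roughly linearly with the number of saved states (about 0.23 s per extra state after the first jump), which would make the total inner-loop cost quadratic in the number of inner steps. A quick sanity check on the first run's numbers (copied from the output above):

```python
# Per-step times from the first run above (seconds), in the order the
# steps were executed (len(fmodel._fast_params) going from 2 to 7).
step_times = [0.5135, 1.2004, 1.4545, 1.6975, 1.9401, 2.1659]

# Differences between consecutive steps: roughly constant after the
# first, i.e. the cost grows linearly with the number of saved states.
deltas = [b - a for a, b in zip(step_times, step_times[1:])]
print([round(d, 3) for d in deltas])  # each extra state adds ~0.23 s

# If step k costs roughly c*k, then N inner steps cost
# c * (1 + 2 + ... + N) = c * N * (N + 1) / 2, i.e. O(N^2) overall.
print(round(sum(step_times), 3))  # 8.972 s total for these six steps
```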
EDIT:
After looking into the DifferentiableOptimizer code, it seems that what causes this slowdown is the construction of the derivative graph in
all_grads = _torch.autograd.grad(
loss,
grad_targets,
create_graph=self._track_higher_grads,
allow_unused=True # boo
)
I’m not really familiar with how autograd handles this, but it seems the whole graph is recomputed at each call. Wouldn’t it be possible to keep the previous graph and extend it as the gradient tape expands with the new states?
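As a rough mental model of what this question is getting at (a deliberately simplified toy tape, not how autograd is actually implemented): if every gradient call with create_graph=True re-traverses the whole recorded history, the k-th inner step costs O(k) and the whole loop O(N^2), whereas extending a persistent graph would keep each step O(1).

```python
# Toy model of a gradient tape: each inner step records one node, and a
# backward pass visits every node recorded so far, like walking the
# whole graph from the loss on every call. Hypothetical names throughout.
class ToyTape:
    def __init__(self):
        self.nodes = []
        self.visits = 0  # total nodes visited across all backward calls

    def record_step(self):
        self.nodes.append(object())

    def backward_from_scratch(self):
        # Re-traverses the entire tape on every call.
        self.visits += len(self.nodes)

tape = ToyTape()
per_step_cost = []
for _ in range(6):
    tape.record_step()
    before = tape.visits
    tape.backward_from_scratch()
    per_step_cost.append(tape.visits - before)

print(per_step_cost)  # [1, 2, 3, 4, 5, 6] -- per-step cost grows linearly
print(tape.visits)    # 21 == 6*7/2 -- total work is quadratic in the steps
```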
Issue Analytics
- Created: 4 years ago
- Comments: 7 (5 by maintainers)
Top GitHub Comments
Sorry for the delay in looking into this. I believe this problem is linked to something in core pytorch, as flagged in this issue: https://github.com/pytorch/pytorch/issues/12635. Someone is looking into it AFAIK, so I’ll report back if progress is made there.
I can repro your issue on CPU using the script you provided. Thanks for narrowing it down to the step where the higher-order graph is created. Running your code with
track_higher_grads=False
in the diffopt creation yields a constant step time. I’m not entirely sure why this is happening: the backward graph is bigger, I suppose, but it should already be partly created. I’m not sure what is causing a bigger graph to be built from scratch. I’ll have a think about this later this week, but I think we’d need to narrow it down a bit and/or try to find the simplest repro that doesn’t have deep
higher
dependencies, so I can go to the pytorch team for help if needed.
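To illustrate why disabling tracking flattens the step time (a toy cost model only; the real flag is track_higher_grads in higher, and this sketch uses neither higher nor pytorch): if the per-step gradient call no longer builds a differentiable graph over the full history, each step only touches the newest state — at the price of losing the meta-gradient.

```python
# Toy comparison of "tracked" vs "untracked" inner steps. With tracking
# on, each step walks all states saved so far; with it off, only the
# latest one. Purely illustrative cost model, not real higher behavior.
def inner_loop_cost(num_steps, track_higher_grads):
    saved_states = 0
    costs = []
    for _ in range(num_steps):
        saved_states += 1
        # Traverse the whole history only when building the higher-order graph.
        costs.append(saved_states if track_higher_grads else 1)
    return costs

print(inner_loop_cost(6, track_higher_grads=True))   # [1, 2, 3, 4, 5, 6]
print(inner_loop_cost(6, track_higher_grads=False))  # [1, 1, 1, 1, 1, 1]
```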