GradsScalarHandler logs 0 gradients if default update function is used
🐛 Bug description
Logging the gradients per epoch / iteration is a useful way to debug an under-performing model. Ignite provides an easy-to-use tensorboard_logger handler for this, `ignite.contrib.handlers.tensorboard_logger.GradsScalarHandler`. However, the default `update` function used by Engines generated by `create_supervised_trainer` zeroes the gradients before it returns, causing the handler to log zeroed-out gradients all the time.
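The behavior described above corresponds to an update function shaped roughly like the sketch below. This is a schematic based on the description only, not the actual ignite source; `make_update` and the toy typing are my own names for illustration.

```python
from typing import Any, Callable

import torch
from ignite.engine import Engine


def make_update(model: torch.nn.Module,
                optimizer: torch.optim.Optimizer,
                loss_fn: Callable) -> Callable[[Engine, Any], float]:
    # Schematic only: an update that steps and then zeroes the gradients
    # before returning, which is the pattern the bug description refers to.
    def update(engine: Engine, batch: Any) -> float:
        model.train()
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # gradients are wiped here ...
        return loss.item()     # ... so ITERATION_COMPLETED handlers only ever see zeros
    return update
```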
Steps to reproduce: my code is too complicated at the moment to provide clear insight, and I am too limited by time to provide a minimal (not-)working example, so I will give abstracted steps instead (a code sketch follows the list).
- Generate an `Engine`/`DeterministicEngine` on an arbitrary problem with the `create_supervised_trainer` method.
- Establish a `TensorboardLogger` and a ….
- Attach a `GradsScalarHandler` / your choice of gradient logger. Also log the training loss or some other metric.
- Start the training, check TensorBoard and see the constant-0 gradient norms / gradients, despite the losses/metrics implying that some sort of improvement/learning takes place.
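A minimal sketch of these steps, assuming a toy model, toy data and a placeholder log directory (only the ignite calls themselves are taken from the library's public API):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events, create_supervised_trainer
from ignite.contrib.handlers.tensorboard_logger import GradsScalarHandler, TensorboardLogger

# Toy model and data, only to make the sketch self-contained.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)

# Engine built with the default update function.
trainer = create_supervised_trainer(model, optimizer, loss_fn)

tb_logger = TensorboardLogger(log_dir="./tb-logs")

# Log gradient norms on every iteration ...
tb_logger.attach(
    trainer,
    log_handler=GradsScalarHandler(model, reduction=torch.norm),
    event_name=Events.ITERATION_COMPLETED,
)
# ... and the training loss for comparison.
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED,
    tag="training",
    output_transform=lambda loss: {"loss": loss},
)

trainer.run(loader, max_epochs=2)
tb_logger.close()

# TensorBoard then shows a decreasing loss but gradient norms that stay at 0.
```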
Solution proposal
Preserving the gradients until epoch end is tricky, but not required for my purposes. If we are OK with using `Events.ITERATION_COMPLETED` as a cue to log gradients, then we can simply modify the default `update` function as follows (assuming `engine.state.iteration` counts from 1):
```python
def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
    # Flush the gradients at the start of a new accumulation cycle
    # instead of at the end of the previous one.
    if (engine.state.iteration + 1) % gradient_accumulation_steps == 0:
        optimizer.zero_grad()
    model.train()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps
    loss.backward()
    if engine.state.iteration % gradient_accumulation_steps == 0:
        optimizer.step()
    return output_transform(x, y, y_pred, loss)
```
This way, upon completion of `update` and at the moment `Events.ITERATION_COMPLETED` fires, there will be some non-zero gradients available to be logged.
Environment (latest version of Ignite still has the same bug)
- PyTorch Version (e.g., 1.4): 1.10.1
- Ignite Version (e.g., 0.3.0): 0.4.7
- OS (e.g., Linux):
- How you installed Ignite (`conda`, `pip`, source): conda
- Python version: 3.9.7
- Any other relevant information:
Issue Analytics
- State:
- Created: 2 years ago
- Reactions: 1
- Comments: 5 (2 by maintainers)
I was wrong actually, it should be `(engine.state.iteration - 1)`. I tested two cases, `gradient_accumulation_steps = 1` and `gradient_accumulation_steps = 3`. This way we both accumulate gradients in the desired manner and they are not flushed once the iteration is finished, allowing for logging. The PR is on the way!
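As a quick illustration of the corrected condition (a toy check of my own, not the test runs mentioned above), the zero/step schedule for `gradient_accumulation_steps = 3` works out as follows:

```python
# Illustrative schedule for the corrected condition, assuming
# engine.state.iteration counts from 1.
gradient_accumulation_steps = 3

for iteration in range(1, 7):
    zero = (iteration - 1) % gradient_accumulation_steps == 0
    step = iteration % gradient_accumulation_steps == 0
    print(f"iteration {iteration}: zero_grad={zero}, step={step}")

# Expected pattern:
#   iteration 1: zero_grad=True,  step=False  <- gradients flushed at cycle start
#   iteration 2: zero_grad=False, step=False
#   iteration 3: zero_grad=False, step=True   <- gradients still present when
#                                                 ITERATION_COMPLETED fires
#   iteration 4: zero_grad=True,  step=False  <- flushed only at the next cycle
#   ...
```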
@egaznep in the beginning we implemented it in a similar way; there was a bug in that implementation and we switched to the current one. We have to check your solution carefully to ensure that it works perfectly. I haven't yet checked it in detail, just giving a bit of context.