add gradient clipping to `create_supervised_trainer()`
It would be good to add gradient clipping to the trainers created by `create_supervised_trainer`. This is already provided by `torch.nn.utils.clip_grad_norm_`.
One possible implementation could be:
import math

from torch.nn.utils import clip_grad_norm_

# Engine and the default batch-preparation helper come from ignite.
from ignite.engine import Engine, _prepare_batch


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              gradient_clip=math.inf):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host.
            For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.
        gradient_clip (float, optional): max norm the gradients are clipped to
            (default: `math.inf`, i.e. effectively no clipping).

    Note: `engine.state.output` for this engine is the loss of the processed batch.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        # Clip the total gradient norm before the optimizer step.
        clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        return loss.item()

    return Engine(_update)
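For illustration, here is a minimal usage sketch of the proposed signature. Note that the `gradient_clip` keyword is the suggestion made in this issue, not an existing argument of ignite's `create_supervised_trainer`, and the model, optimizer, and loss below are placeholders:

    import torch.nn as nn
    from torch.optim import SGD

    # Placeholder supervised setup; any model/optimizer/loss would work the same way.
    model = nn.Linear(10, 2)
    optimizer = SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()

    # gradient_clip=1.0 caps the total gradient norm at 1.0 on every update;
    # the default math.inf leaves gradients untouched.
    trainer = create_supervised_trainer(model, optimizer, loss_fn, gradient_clip=1.0)

    # train_loader is assumed to be a standard (x, y) DataLoader:
    # trainer.run(train_loader, max_epochs=5)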
Top GitHub Comments
@AntoinePrv I think it would be simpler to write a custom processing function instead of custom events.
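For reference, a rough sketch of what such a custom processing function could look like; the helper name and wiring here are illustrative, not an ignite-provided API:

    from ignite.engine import Engine
    from torch.nn.utils import clip_grad_norm_

    def make_update_fn(model, optimizer, loss_fn, max_norm, device=None):
        # Plain update function with gradient clipping, passed directly to Engine
        # instead of extending create_supervised_trainer.
        def _update(engine, batch):
            model.train()
            optimizer.zero_grad()
            x, y = batch
            if device is not None:
                x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            return loss.item()
        return _update

    # trainer = Engine(make_update_fn(model, optimizer, loss_fn, max_norm=1.0))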
Sorry, I missed that one. I had the same doubts w.r.t. moving it to `contrib.engines`. My point against doing it is that the code would be so similar to the one in `create_supervised_trainer`. In any case, you are driving here.