
add gradient clipping to `create_supervised_trainer()`

See original GitHub issue

It would be good to add gradient clipping to the trainers created by create_supervised_trainer. The clipping itself is already provided by torch.nn.utils.clip_grad_norm_.

One possible implementation could be:

import math

from torch.nn.utils import clip_grad_norm_

# Assumes ignite's Engine and the default _prepare_batch helper are
# importable from ignite.engine, as used by the existing factory.
from ignite.engine import Engine, _prepare_batch

def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              gradient_clip=math.inf):
    """
    Factory function for creating a trainer for supervised models.
    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        gradient_clip (float, optional): maximum norm of the gradients; gradients are
            clipped to this value (default: math.inf, i.e. no effective clipping).
    Note: `engine.state.output` for this engine is the loss of the processed batch.
    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        # Clip the total gradient norm in-place before the optimizer step;
        # the default of math.inf leaves the gradients unchanged.
        clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        return loss.item()

    return Engine(_update)
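
For context, a hypothetical usage sketch of the proposed factory. Note that `gradient_clip` is the argument suggested in this issue, not part of ignite's released API, and the model, optimizer, and loader names are placeholders:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # placeholder optimizer
loss_fn = nn.MSELoss()

# gradient_clip=1.0 would clip the total gradient norm to 1.0 on every step
trainer = create_supervised_trainer(model, optimizer, loss_fn, gradient_clip=1.0)
# trainer.run(train_loader, max_epochs=10)   # train_loader: your DataLoader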

Issue Analytics

  • State: open
  • Created: 5 years ago
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

2 reactions
vfdev-5 commented, Feb 1, 2019

@AntoinePrv I think it would be simpler to write a custom processing function instead of custom events.
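
A minimal sketch of that suggestion, assuming `model`, `optimizer` and `loss_fn` are defined as in the example above; `train_step` is an illustrative name, not ignite API:

from ignite.engine import Engine
from torch.nn.utils import clip_grad_norm_

def train_step(engine, batch):
    # Custom processing function: the usual supervised update,
    # with gradient clipping added before the optimizer step.
    model.train()
    optimizer.zero_grad()
    x, y = batch
    loss = loss_fn(model(x), y)
    loss.backward()
    clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)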

1 reaction
lmarti commented, Jan 30, 2019

Sorry, I missed that one. I had the same doubts w.r.t. moving it to contrib.engines. My point against doing it is that the code would be almost identical to the one in create_supervised_trainer. In any case, you are driving here.

Read more comments on GitHub >

Top Results From Across the Web

Effective Training Techniques - PyTorch Lightning
Gradient clipping can be enabled to avoid exploding gradients. ... Use default in trainer construction trainer = Trainer() tuner = Tuner(trainer) # Invoke ......
Read more >
Understanding Gradient Clipping (and How It Can Fix ...
Gradient Clipping is a method where the error derivative is changed or clipped to a threshold during backward propagation through the network, and...
Read more >
Gradient Editing On The Fly in Deep Neural Networks
The corresponding output contains the target label, acting as the correct answer to guide the training process. Model training aims to generate ......
Read more >
Introduction to Gradient Clipping Techniques with Tensorflow
It is also the easiest and most popular way to build neural networks. However, you can still apply gradient clipping if you are...
Read more >
Machine Learning - Programming Differential Privacy
Next, let's perform a single step of gradient descent. We can apply the gradient function to a single example from our training data,...
Read more >
