
Model Verification in Trainer

See original GitHub issue

🚀 Feature

Verifies that the provided model code does not mix up data across the batch dimension. We do this by setting the loss to be something trivial (e.g. the sum of all outputs of example i), running the backward pass all the way to the input, and ensuring that we only get a non-zero gradient on the i-th input.
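As a minimal illustration of the idea (a toy example I am adding here, not code from the issue), a correctly written per-example model produces an input gradient that is zero everywhere except at example i:

import torch

# Toy stand-in for a user model; any model that treats batch elements
# independently should behave the same way.
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4, requires_grad=True)

i = 3
model(x)[i].sum().backward()  # trivial loss: the sum of all outputs of example i

assert (x.grad[i] != 0).any()                                    # example i got a gradient
assert (x.grad[:i] == 0).all() and (x.grad[i + 1:] == 0).all()   # no other example did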

Motivation

First of all, I would like to say thank you for the fantastic work being done on this project. Recently, I was working on a side project that has almost the exact same goal as this one, which I used as motivation to learn more about PyTorch and how to make Deep Learning easier. Clearly, this project is a lot more thought-out than mine :^), but I wanted to see if there were any ideas I developed independently that might be useful in this project.

One of the most useful utils I’ve implemented is a verification step before the model runs. In my project, this verification step performs checks such as:

  • ensuring data is not mixed across the batch dimension
  • ensuring the model can overfit a single example
  • ensuring that all layers of the model are training (or selected layers are properly frozen); a sketch of this check is given below

Since I am very new to this project, I thought that the first bullet point might be a good place to start.
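
For reference, here is a rough sketch of the third check in the list above (a hypothetical helper added for illustration, not part of the issue): take a single optimizer step and confirm that every trainable parameter changed while frozen parameters stayed put.

import torch

def check_all_layers_training(model, loader, optimizer, loss_fn):
    # Snapshot every parameter before a single training step.
    before = {name: p.detach().clone() for name, p in model.named_parameters()}

    data, target = next(iter(loader))
    optimizer.zero_grad()
    loss_fn(model(data), target).backward()
    optimizer.step()

    for name, param in model.named_parameters():
        changed = not torch.equal(before[name], param)
        if param.requires_grad:
            # Note: an exactly-zero gradient (e.g. a dead ReLU) could make this fire spuriously.
            assert changed, f"{name} is trainable but did not update"
        else:
            assert not changed, f"{name} is frozen but was updated"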

Pitch

Given the introductory example in the documentation, assume we had written some poor tensor operations in our model like so:

import torch
import pytorch_lightning as pl


class BadModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        ###
        # correct: flatten each image independently
        # x = x.view(batch_size, -1)
        ###
        # wrong: these reshapes recombine pixels from different images in the batch
        x = x.view(-1, 1, 56, 56)
        x = x.permute(1, 0, 3, 2)
        x = x.reshape((batch_size, -1))
        ###

        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

When we start to train our model, everything appears to run smoothly. However, this code is clearly wrong: we are mixing image data from separate datapoints in the batch.

It would be helpful if Lightning warned us when this happens. For example:

def check_batch_dimension(model, loader, optimizer, test_val=2):
    model.eval()
    torch.set_grad_enabled(True)

    data, _ = next(iter(loader))
    optimizer.zero_grad()
    data.requires_grad_()          # track gradients all the way back to the input

    output = model(data)
    loss = output[test_val].sum()  # trivial loss: only example `test_val` contributes
    loss.backward()

    error_msg = "Your model is mixing up data across the batch dimension!"
    assert loss != 0
    assert (data.grad[test_val] != 0).any(), error_msg
    assert (data.grad[:test_val] == 0.).all() and (data.grad[test_val + 1:] == 0.).all(), error_msg

This function verifies that only a single datapoint in the batch (the one selected by test_val) receives a nonzero gradient on its input. This check has saved me countless times from running a poorly written model. 😃
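
For example, running the check against the BadModel above trips the assertion immediately (the dataloader setup below is a hypothetical stand-in using random MNIST-shaped tensors):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Random tensors shaped like MNIST batches, just to drive the check.
dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=32)

model = BadModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Raises AssertionError: "Your model is mixing up data across the batch dimension!"
check_batch_dimension(model, loader, optimizer)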

Implementation-wise, I am looking for advice on whether this is a useful effort, whether it fits the intended goals of Lightning, and what difficulties might arise.

Alternatives

It is clear that the feature as it stands will not work for all models, as some variants of LSTMs and the like use a different dimension as their batch dimension (maybe this could be exposed as a parameter; a sketch follows below). There also might be issues if the batch is split up somewhere; I’m not quite certain how everything in this project works, particularly around gradient accumulation.
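
As a sketch of that parameter (hypothetical, and assuming the model’s input and output share the same batch dimension), the hard-coded dimension-0 indexing could be replaced with select/index_select:

import torch

def check_batch_dimension_dim(model, loader, optimizer, test_val=2, batch_dim=0):
    model.eval()
    torch.set_grad_enabled(True)
    data, _ = next(iter(loader))
    optimizer.zero_grad()
    data.requires_grad_()

    output = model(data)
    loss = output.select(batch_dim, test_val).sum()  # trivial loss on one example
    loss.backward()

    grad = data.grad
    others = torch.tensor([i for i in range(grad.size(batch_dim)) if i != test_val])
    error_msg = "Your model is mixing up data across the batch dimension!"
    assert (grad.select(batch_dim, test_val) != 0).any(), error_msg
    assert (grad.index_select(batch_dim, others) == 0).all(), error_msg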

However, I would expect this to be useful for almost all models. I advocate making this a default warning, while allowing users who know what they are doing to pass a flag that disables the verification step.

I also realize there needs to be some cleanup after this step to reset the model to its previous state. Any insights here would be great as well.
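
One possible way to handle that cleanup (an assumption on my part, not something the issue specifies) is to wrap the check so the training mode and gradients are restored afterwards:

import torch

def run_batch_dimension_check(model, loader, optimizer):
    was_training = model.training
    prev_grad_enabled = torch.is_grad_enabled()
    try:
        check_batch_dimension(model, loader, optimizer)
    finally:
        model.train(was_training)                  # restore train/eval mode
        torch.set_grad_enabled(prev_grad_enabled)  # restore the global grad mode
        optimizer.zero_grad()                      # drop gradients produced by the probe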

Additional context

None

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 19 (17 by maintainers)

Top GitHub Comments

4 reactions
williamFalcon commented, Aug 17, 2020

this is prime for a callback

2 reactions
awaelchli commented, Aug 17, 2020

I have actually implemented this in a separate class myself to verify my models and used it many times. It is a great sanity test. Maybe I can send a PR or Google Colab and @TylerYep can help me test it. We can also come up with more verification tests.
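
Following up on the callback suggestion above, here is a rough sketch of how the check could be packaged as a Callback (the hook name comes from the pytorch_lightning.Callback API; passing the example batch in explicitly is a simplification for this sketch, not necessarily how a final implementation would fetch data):

import torch
import pytorch_lightning as pl

class BatchMixingCheck(pl.Callback):
    def __init__(self, example_batch, test_val=2):
        self.example_batch = example_batch
        self.test_val = test_val

    def on_fit_start(self, trainer, pl_module):
        data, _ = self.example_batch
        data = data.clone().to(pl_module.device).requires_grad_()

        with torch.enable_grad():
            output = pl_module(data)
            output[self.test_val].sum().backward()

        error_msg = "Your model is mixing up data across the batch dimension!"
        others = [i for i in range(data.size(0)) if i != self.test_val]
        assert (data.grad[self.test_val] != 0).any(), error_msg
        assert (data.grad[others] == 0).all(), error_msg

        pl_module.zero_grad()  # clean up gradients left behind by the probe

It could then be attached with Trainer(callbacks=[BatchMixingCheck(next(iter(train_loader)))]).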

