
Enable gradients in validation_step


🚀 Feature

Enable gradient calculations during evaluation loops.

Motivation

Some loss functions require the gradients of the outputs with respect to the inputs. For example, a physics-informed neural network (PINN) uses these gradients in a differential-equation residual as its loss function.

Pitch

Add a set_grad_enabled flag to validation_step to make something like the following possible (here, learning the cosine function):

def validation_step(
    self,
    batch: tuple[torch.Tensor, torch.Tensor],
    batch_idx: int,
    set_grad_enabled: bool = True,
) -> dict[str, torch.Tensor]:

    x, t = batch
    x.requires_grad_()  # gradients w.r.t. the input are needed below

    u_hat = self.forward(x, t)

    # torch.autograd.grad returns a tuple; unpack the single gradient
    (dydx,) = torch.autograd.grad(
        u_hat,
        x,
        grad_outputs=torch.ones_like(u_hat),
        create_graph=True,
    )

    # residual of the ODE du/dx = sin(x), whose solution is u(x) = -cos(x) + C
    physics_loss = ((torch.sin(x) - dydx) ** 2).mean()

    return {"loss": physics_loss}

Alternatives

I tried simply wrapping my code in with torch.set_grad_enabled(True):, but of course that didn’t work.

Additional context

Physics-informed neural networks (PINNs) are a recently introduced architecture for learning differential equations by embedding them in the loss function of a neural network. The key innovation is repurposing automatic-differentiation machinery to obtain derivatives of the network’s outputs with respect to its inputs, then plugging those into the residual form of a differential equation. The network learns accurate derivatives by minimizing this residual.
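
Concretely, for a first-order ODE du/dx = f(x), the training objective is the mean squared residual over a batch of collocation points. A sketch in my own notation, not copied from the paper:

    \mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{d u_\theta}{dx}(x_i) - f(x_i) \right)^2

In the pitch above, f(x) = sin(x), so minimizing this loss drives u_theta toward -cos(x) + C.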

PINNs paper: https://faculty.sites.iastate.edu/hliu/files/inline-files/PINN_RPK_2019_1.pdf



Issue Analytics

  • State: closed
  • Created: 10 months ago
  • Comments: 7 (2 by maintainers)

Top GitHub Comments

1 reaction
jharris1679 commented, Nov 23, 2022

Ahh, in addition to .requires_grad_(), I had to put everything in my validation step under with torch.enable_grad():, and now it works! Thanks again.
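
For anyone landing here later, a minimal sketch of that combination, assuming the same (x, t) batch as in the pitch above (this is not the exact code from the thread):

def validation_step(self, batch, batch_idx):
    x, t = batch
    # Lightning runs evaluation under no_grad, so re-enable autograd locally
    with torch.enable_grad():
        x = x.requires_grad_()  # the input must be part of the autograd graph
        u_hat = self.forward(x, t)
        (dydx,) = torch.autograd.grad(
            u_hat,
            x,
            grad_outputs=torch.ones_like(u_hat),
            create_graph=True,
        )
        physics_loss = ((torch.sin(x) - dydx) ** 2).mean()
    return {"loss": physics_loss}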

0 reactions
GeoffNN commented, Dec 9, 2022

Have you tried using functorch to compute the partials? It’s much easier to use.

I have an example script here: https://github.com/GeoffNN/deeponet-fno/blob/main/src/burgers/pytorch_deeponet.py
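
To illustrate the idea, a minimal toy sketch of computing du/dx with functorch (my own example, not the linked script; the model and shapes are assumptions):

import torch
from functorch import grad, vmap  # merged into torch.func in newer PyTorch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)

def u(x):
    # x has shape (1,); grad() requires the function to return a scalar
    return model(x).squeeze()

x = torch.linspace(0.0, 3.14, 50).unsqueeze(-1)  # collocation points, shape (50, 1)
dudx = vmap(grad(u))(x)                          # per-sample du/dx, shape (50, 1)
physics_loss = ((dudx - torch.sin(x)) ** 2).mean()

One nicety of this style is that the derivative is built functionally rather than by bookkeeping with requires_grad_ and enable_grad on individual tensors.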
