Enable gradients in validation_step
🚀 Feature
Enable gradient calculations during evaluation loops.
Motivation
Some loss functions require the gradients of the outputs with respect to the inputs. For example, a physics-informed neural network plugs these gradients into a differential equation that serves as its loss function.
Pitch
Add a set_grad_enabled flag to validation_step to make the following possible, for example for learning the cosine function:
def validation_step(
    self,
    batch: torch.Tensor,
    batch_idx: int,
    set_grad_enabled=True,
) -> dict[str, torch.Tensor]:
    x, t = batch
    u_hat = self.forward(x, t)
    # torch.autograd.grad returns a tuple, so unpack the single gradient
    (dydx,) = torch.autograd.grad(
        u_hat,
        x,
        grad_outputs=torch.ones_like(u_hat),
        create_graph=True,
    )
    physics_loss = torch.sin(x) - dydx
    return {"loss": physics_loss}
Alternatives
I tried simply adding with torch.set_grad_enabled(True) to my code, but of course that didn’t work.
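The likely reason is that the batch tensors never had requires_grad set, so no graph is built back to the inputs even once grad mode is re-enabled. A minimal standalone sketch of that failure mode (the shapes and the stand-in for the forward pass are illustrative, not from the issue):

```python
import torch

x = torch.rand(8, 1)                      # like a dataloader batch: requires_grad is False
t = torch.rand(8, 1)

with torch.no_grad():                     # roughly what the eval loop does around validation_step
    with torch.set_grad_enabled(True):    # re-enables grad mode...
        u = (x * t).sin()                 # stand-in for self.forward(x, t)
        # ...but this still raises a RuntimeError, because u carries no
        # graph back to x (x never had requires_grad set).
        torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u))
```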
Additional context
Physics-informed neural networks (PINNs) are a recently introduced architecture for learning differential equations by embedding them in the loss function of a neural network. The key innovation is repurposing the auto-differentiation machinery to obtain derivatives of the network's outputs with respect to its inputs, then plugging those into the residual form of a differential equation. The network learns accurate derivatives by minimizing this residual.
PINNs paper: https://faculty.sites.iastate.edu/hliu/files/inline-files/PINN_RPK_2019_1.pdf
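To make the residual idea concrete, here is a small self-contained sketch (a toy example of my own, not taken from the paper): the network learns u(x) = -cos(x) by minimizing the squared residual of du/dx = sin(x) together with a boundary-condition term.

```python
import torch
from torch import nn

# Toy problem: du/dx = sin(x) with u(0) = -1, whose solution is u(x) = -cos(x).
net = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))

def pinn_loss(x_collocation: torch.Tensor) -> torch.Tensor:
    x = x_collocation.requires_grad_(True)
    u = net(x)
    (dudx,) = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True
    )
    residual = dudx - torch.sin(x)   # differential-equation residual at the collocation points
    x0 = torch.zeros(1, 1)
    bc = net(x0) + 1.0               # boundary-condition mismatch, since u(0) should be -1
    return residual.pow(2).mean() + bc.pow(2).mean()

# One optimisation step over random collocation points in [0, 2*pi]
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
opt.zero_grad()
loss = pinn_loss(torch.rand(256, 1) * 2 * torch.pi)
loss.backward()
opt.step()
```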
If you enjoy Lightning, check out our other projects! ⚡

- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
Top GitHub Comments
Ahh, in addition to .requires_grad_() I had to put everything in my validation step under with torch.enable_grad():, and now it works! Thanks again.
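Put together, that fix might look roughly like the following sketch (the module name and network are placeholders, the squared-mean reduction of the pitch's residual is an added assumption, and the training-related methods are omitted). On recent Lightning releases, which wrap evaluation in torch.inference_mode rather than torch.no_grad, you may also need to construct the Trainer with inference_mode=False:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class CosinePINN(pl.LightningModule):   # hypothetical name, not from the issue
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))

    def validation_step(self, batch, batch_idx):
        x, t = batch
        # Both pieces from the comment above: mark the input as requiring
        # grad, and re-enable grad mode inside the eval loop.
        x = x.requires_grad_(True)
        with torch.enable_grad():
            u_hat = self.forward(x, t)
            (dydx,) = torch.autograd.grad(
                u_hat,
                x,
                grad_outputs=torch.ones_like(u_hat),
                create_graph=True,
            )
            # Same residual as in the pitch; the squared mean is an added
            # assumption so the logged value is a scalar.
            physics_loss = (torch.sin(x) - dydx).pow(2).mean()
        self.log("val_physics_loss", physics_loss)
        return {"loss": physics_loss}
```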
Have you tried using functorch to compute the partials? It’s much easier to use. I have an example script here: https://github.com/GeoffNN/deeponet-fno/blob/main/src/burgers/pytorch_deeponet.py
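For comparison, a minimal sketch of what the functorch route might look like for the same per-sample partial derivative (the network and shapes are illustrative, not taken from the linked script):

```python
import torch
from torch import nn
from functorch import grad, vmap   # on PyTorch >= 2.0 the same API is available as torch.func

net = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))

def u(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Scalar output for a single (x, t) pair, as functorch.grad requires.
    return net(torch.cat([x, t], dim=-1)).squeeze()

# du/dx for every sample in a batch, without calling torch.autograd.grad by hand.
dudx_fn = vmap(grad(u, argnums=0))

x = torch.rand(128, 1)
t = torch.rand(128, 1)
dudx = dudx_fn(x, t)   # same shape as x
```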