Add Inference support to PyTorch Lightning
🚀 Feature
From my current investigation, it seems TorchServe has matured and it is now the default Server used within Sagemaker AWS PyTorch Image.
Until now, PyTorch Lightning has offered very limited support for serving and the transition from training to production has been quite challenging for many users.
This issue is meant to start a conversation on how PyTorch Lightning can smooth the transition between those worlds.
My proposition is the following:
Add a new hook to the LightningModule called configure_serve_handler
where users can return a BaseHandler subclass to accommodate their models' serving requirements.
Here is how it could look:
class MNISTClassifier(LightningModule):
    def configure_serve_handler(self):
        return MNISTDigitClassifierHandler(...)
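For reference, a handler such as MNISTDigitClassifierHandler could be a thin TorchServe BaseHandler subclass. The sketch below is only illustrative: the class name comes from the snippet above, and the assumption that clients post a JSON array of 784 pixel values is mine, not part of any existing API.

```python
import json

import torch
from ts.torch_handler.base_handler import BaseHandler


class MNISTDigitClassifierHandler(BaseHandler):
    """Illustrative TorchServe handler for an MNIST digit classifier."""

    def preprocess(self, data):
        # Assumption: each request body is a JSON array of 784 pixel values.
        rows = []
        for request in data:
            payload = request.get("data") or request.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = json.loads(payload)
            rows.append(torch.as_tensor(payload, dtype=torch.float32))
        return torch.stack(rows).view(-1, 1, 28, 28)

    def postprocess(self, inference_output):
        # One predicted digit per request, returned as plain Python ints.
        return inference_output.argmax(dim=1).tolist()
```

The default BaseHandler takes care of loading the model and running inference, so only the pre- and post-processing specific to the model needs to be written.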
Here are some of the opportunities such a hook would open up:
- Sanity serve checking when starting training (see the sketch after this list).
- Sanity TorchScripting when starting training.
- Auto-generate a handler.py file for users, without the pytorch_lightning dependency.
- Auto-generate a .mar file with the TorchScripted model, handlers, etc.
- Provide data types to help users write the serialization and deserialization layers for their models, or contribute them upstream to TorchServe.
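As a rough illustration of the first two bullets, the sanity checks could live in a small callback that exercises the proposed hook and attempts TorchScripting before training starts. Everything below is a sketch of the proposal, not current Lightning API: `configure_serve_handler` does not exist today, and `SanityServeCheck` is a hypothetical name.

```python
from pytorch_lightning import Callback


class SanityServeCheck(Callback):
    """Sketch: fail fast if the model has no serve handler or cannot be scripted."""

    def on_fit_start(self, trainer, pl_module):
        # `configure_serve_handler` is the hook proposed in this issue;
        # it is not part of PyTorch Lightning today.
        handler_fn = getattr(pl_module, "configure_serve_handler", None)
        if handler_fn is None:
            raise RuntimeError("LightningModule defines no serve handler")
        handler = handler_fn()

        # Sanity TorchScripting: surface scripting errors before the first epoch.
        scripted = pl_module.to_torchscript(method="script")
        print(f"Serve sanity check passed with {type(handler).__name__}, "
              f"scripted module: {type(scripted).__name__}")
```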
Furthermore, using the newly introduced Lightning App, it would be possible to offer even more to PyTorch Lightning users by automating all those steps and then serving their model on any service of their choice (roughly as sketched below).
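To make that idea concrete, the outline below is only an assumption of how such an automation could be wired as a Lightning App: the component names (ExportWork, ServeWork, TrainToServeFlow), the checkpoint location, and the packaging step are all hypothetical, and the import path may differ between Lightning versions.

```python
from lightning_app import LightningApp, LightningFlow, LightningWork


class ExportWork(LightningWork):
    """Hypothetical: TorchScript the trained model and package a .mar archive."""

    def __init__(self):
        super().__init__()
        self.exported_path = ""

    def run(self, checkpoint_path: str):
        # Placeholder for torch.jit export + `torch-model-archiver` packaging.
        self.exported_path = checkpoint_path + ".mar"


class ServeWork(LightningWork):
    """Hypothetical: point TorchServe (or another backend) at the archive."""

    def run(self, mar_path: str):
        print(f"serving {mar_path}")


class TrainToServeFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.export = ExportWork()
        self.serve = ServeWork()

    def run(self):
        self.export.run("checkpoints/last.ckpt")  # assumed checkpoint location
        if self.export.exported_path:
            self.serve.run(self.export.exported_path)


app = LightningApp(TrainToServeFlow())
```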
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @borda @tchaton @justusschock @awaelchli @carmocca @ananthsub @ninginthecloud @jjenniferdai @rohitgr7
I agree with @justusschock in that going into serialization of data types etc. is a big can of worms that should stay out of PL.
To me, the key here is evaluating whether there's a direct interaction between the Trainer and serving. The top post proposes sanity checking steps which could be useful in specific cases but are very niche. How often does serving performance inform the training procedure, sort of like a validation mechanism? If it's just for performance tracking, then this can be done in parallel to the training procedure by evaluating checkpoints along the way.

My key concern is, what do we gain from an API like:

vs
There can be further iterations to this, like a mixin that would tie together the PL module with the serving. But this still keeps it out of the LightningModule interface, which atm is meant to act as just a contract with the Trainer.

What I think is very important is that we ease the transition OUT of a trained PL module into nn.Modules, which is what most of the ecosystem projects understand.
(note: these are my thoughts after discussing with Thomas offline for a while 🙂)
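For context, that transition is already fairly direct with today's API. A minimal sketch, assuming the MNISTClassifier from the top post is importable and that a checkpoint exists at the placeholder path:

```python
import torch

# Assumes `MNISTClassifier` (the LightningModule from the top post) is importable
# and that a checkpoint exists at the placeholder path below.
model = MNISTClassifier.load_from_checkpoint("checkpoints/last.ckpt")
model.eval()

# Export to TorchScript so that only `torch` is needed at inference time.
scripted = model.to_torchscript(method="script")
torch.jit.save(scripted, "mnist_scripted.pt")

# A LightningModule is itself an nn.Module, so it can also be handed directly
# to anything that expects a plain PyTorch module.
assert isinstance(model, torch.nn.Module)
```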
TBH, I don't think we should add this.
The main code lives in the handler class (all the pre- and post-processing) and people would have to write this class anyway, so we are not saving any boilerplate there.
Also, it is not so easy to eliminate the PL dependency in most cases: people often intermingle the responsibilities of their LightningModule and their model, and for it to work properly (also in serving) you often need both.
TL;DR: if you really want to provide a benefit there, IMO you need to be opinionated based on the task, which is what Flash is doing and which is also why I would leave it there.
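For comparison, the opinionated, task-level route the comment points to looks roughly like the following in Flash. The checkpoint path is a placeholder, and the exact serve() entry point should be treated as an assumption rather than a confirmed API.

```python
from flash.image import ImageClassifier

# ImageClassifier is a Flash task, i.e. an opinionated LightningModule that
# already knows its pre/post-processing. The checkpoint path is a placeholder.
model = ImageClassifier.load_from_checkpoint("image_classifier.pt")

# Assumption: Flash tasks expose a serve() entry point (Flash Serve) that
# spins up an HTTP endpoint without a hand-written handler class.
model.serve()
```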