Refactor PyTorch `model.generate` method to work on TPU
See original GitHub issue

Feature request
Refactor the PyTorch version of the `model.generate` method for text-generating models to make it compatible with XLA and speed up inference on TPU.
Motivation
Right now, `model.generate` in PyTorch is extremely slow on TPU compared to CPU and GPU. This is likely because some operations in the PyTorch implementation of `model.generate` are not XLA-compatible, so the generation process falls back to CPU, which makes inference on TPU infeasible. A major refactoring has already been done on its TensorFlow counterpart, so it would be nice to have the PyTorch version working as well.
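A common reason `generate`-style loops are slow under XLA (beyond unsupported ops) is dynamic tensor shapes: XLA compiles one graph per shape, and a naive decode loop that grows the sequence by one token per step forces a recompilation at every step. The sketch below is a hypothetical, dependency-free illustration of that shape behavior, not the actual `transformers` code; the function names and shape accounting are assumptions for illustration only.

```python
# Hypothetical illustration: count the distinct (batch, seq_len) shapes an
# XLA compiler would see. Each distinct shape means one compilation.

def dynamic_shapes(prompt_len, max_new_tokens):
    """Shapes seen by a naive decode loop that grows the sequence each step."""
    shapes = set()
    length = prompt_len
    for _ in range(max_new_tokens):
        shapes.add((1, length))  # seq_len grows by one every iteration
        length += 1
    return shapes

def static_shapes(prompt_len, max_length):
    """XLA-friendly variant: pad to a fixed max_length up front."""
    shapes = set()
    for _ in range(max_length - prompt_len):
        shapes.add((1, max_length))  # shape never changes -> compile once
    return shapes

print(len(dynamic_shapes(8, 20)))  # → 20 distinct shapes, ~20 compilations
print(len(static_shapes(8, 28)))   # → 1 shape, compiled once and reused
```

This is essentially the trick the TensorFlow refactor relies on: padding inputs to fixed lengths so the traced generation graph can be compiled once and reused.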
A more in-depth discussion with @gante took place in #12322 and on this huggingface discussion.
Your contribution
If there is some interest from the HF team, I can definitely assist during the work.
Issue Analytics
- State:
- Created: a year ago
- Reactions: 4
- Comments: 9 (5 by maintainers)
Top Results From Across the Web
Running PyTorch on TPU: a bag of tricks | by Zahar Chikishev
1) DataParallel holds copies of the model object (one per TPU device), which are kept synchronized with identical weights. · 2) DataParallel ...
PyTorch TPU starter - DeBERTa-v3-large (training) | Kaggle
Let's start training! To do so, we start by initializing the model. We use the xmp.MpModelWrapper provided by PyTorch XLA to save memory...
Tensor Processing Unit (TPU) - PyTorch Lightning
All TPU VMs in a Pod setup are required to access the model code and data. One easy way to achieve this is...
Getting Started with PyTorch on Cloud TPUs - Google Colab
PyTorch /XLA is a package that lets PyTorch connect to Cloud TPUs and use TPU cores as devices. Colab provides a free Cloud...
huggingface/transformers: Model versioning, TensorFlow ...
... new scripts, refactor of the generate method Model versioning We host ... examples/docs: caveat that PL examples don't work on TPU #8309 ......
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Added to my `generate` task queue 👍

@divyanshuaggarwal it would be part of `transformers`! This is not a prioritized feature, as you can already use TPUs for generation in Flax and TensorFlow. Since you can easily convert a model from one framework to the other, there is an easy workaround 😃
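The workaround above might be sketched as follows. This is only an illustration, not code from the issue: `"gpt2"` is a placeholder checkpoint, and the snippet assumes `transformers` with its Flax/JAX dependencies (plus `torch` for the conversion) is installed; running it on TPU additionally requires a JAX TPU runtime.

```python
# Hedged sketch of the suggested workaround: load a PyTorch checkpoint
# into the Flax port, whose generate() is XLA-friendly, and decode there.
# "gpt2" is only a placeholder model name.
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# from_pt=True converts the PyTorch weights to Flax at load time
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2", from_pt=True)

inputs = tokenizer("TPU generation works in", return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```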