
Refactor PyTorch `model.generate` method to work on TPU

See original GitHub issue

Feature request

Refactor the PyTorch version of `model.generate` for text-generation models to make it XLA-compatible and speed up inference on TPU.

Motivation

Right now, `model.generate` in PyTorch is extremely slow on TPU compared to CPU and GPU. This is most likely because some operations performed inside the PyTorch version of `model.generate` are not XLA-compatible, so the generation process falls back to CPU. This makes inference on TPU infeasible. A major refactoring has already been done on its TF counterpart, so it would be nice to have the PyTorch version working as well.

A more in-depth discussion with @gante took place in #12322 and on this huggingface discussion.
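To illustrate the kind of XLA incompatibility the motivation refers to: XLA compiles one program per tensor shape, so a decoding loop that grows its output by one token per step forces a recompilation at every new length, while a fixed-length buffer keeps shapes static. The following is a minimal pure-Python sketch of the two patterns; `next_token`, `MAX_LEN`, `PAD_ID`, and `EOS_ID` are hypothetical stand-ins, not transformers APIs.

```python
MAX_LEN = 8
PAD_ID = 0
EOS_ID = 99

def next_token(prefix):
    # Stand-in for a model forward pass: emits incrementing ids, then EOS
    # after four non-pad tokens have been seen.
    return EOS_ID if len([t for t in prefix if t != PAD_ID]) >= 4 else len(prefix)

def generate_dynamic(prompt):
    """Dynamic shapes: the output grows one token per step.

    Under XLA this shape change would trigger a recompilation at
    every sequence length, which is the slow pattern on TPU.
    """
    out = list(prompt)
    while len(out) < MAX_LEN:
        tok = next_token(out)
        out.append(tok)
        if tok == EOS_ID:
            break
    return out

def generate_static(prompt):
    """Static shapes: preallocate a fixed-length buffer and write in place.

    The buffer shape never changes, so a compiled XLA program can be
    reused across all steps; this is the pattern the TF refactor uses.
    """
    out = list(prompt) + [PAD_ID] * (MAX_LEN - len(prompt))
    finished = False
    for step in range(len(prompt), MAX_LEN):
        tok = next_token(out[:step])
        # After EOS, keep writing PAD so the loop runs a fixed trip count.
        out[step] = PAD_ID if finished else tok
        if tok == EOS_ID:
            finished = True
    return out
```

Both variants produce the same tokens; the static version just pads to `MAX_LEN` instead of stopping early, trading a few wasted steps for a single compiled program.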

Your contribution

If there is some interest from the HF team, I can definitely assist during the work.

Issue Analytics

  • State: open
  • Created: a year ago
  • Reactions: 4
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

4 reactions
gante commented, Sep 28, 2022

Added to my generate task queue 👍

@divyanshuaggarwal it would be part of transformers!

0 reactions
sgugger commented, Dec 12, 2022

This is not a prioritized feature as you can already use TPUs for generation in Flax and TensorFlow. Since you can easily convert a model from one framework to the other, there is an easy workaround 😃
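The workaround sgugger describes can be sketched as follows: load the PyTorch checkpoint into the TensorFlow architecture with `from_pt=True`, then compile TF's already XLA-compatible `generate` with `tf.function(..., jit_compile=True)`. This is a sketch, not an official recipe; `t5-small` is just an example checkpoint, and padding to a fixed length is what lets XLA reuse one compiled program across inputs.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
# from_pt=True converts the PyTorch weights into the TF model class.
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small", from_pt=True)

# jit_compile=True traces generate() through XLA. The first call compiles;
# later calls with the same input shape reuse the compiled program.
xla_generate = tf.function(model.generate, jit_compile=True)

# Pad to a fixed length so the input shape (and compiled program) is stable.
inputs = tokenizer(
    "translate English to German: hello world",
    padding="max_length", max_length=32, return_tensors="tf",
)
output_ids = xla_generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

On an actual TPU the same code runs under a TPU strategy scope; the key point is that every call with a new input length would otherwise trigger a fresh XLA compilation, which is why `padding="max_length"` matters.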


Top Results From Across the Web

Running PyTorch on TPU: a bag of tricks | by Zahar Chikishev
1) DataParallel holds copies of the model object (one per TPU device), which are kept synchronized with identical weights. · 2) DataParallel ...

PyTorch TPU starter - DeBERTa-v3-large (training) | Kaggle
Let's start training! To do so, we start by initializing the model. We use the xmp.MpModelWrapper provided by PyTorch XLA to save memory...

Tensor Processing Unit (TPU) - PyTorch Lightning
All TPU VMs in a Pod setup are required to access the model code and data. One easy way to achieve this is...

Getting Started with PyTorch on Cloud TPUs - Google Colab
PyTorch/XLA is a package that lets PyTorch connect to Cloud TPUs and use TPU cores as devices. Colab provides a free Cloud...

huggingface/transformers: Model versioning, TensorFlow ...
... new scripts, refactor of the generate method Model versioning We host ... examples/docs: caveat that PL examples don't work on TPU #8309 ...
