Flax models should allow `inputs_embeds`
Feature request
Currently, non-Flax models allow `inputs_embeds` instead of `input_ids` (e.g., GPT-2):

```python
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    ...
    inputs_embeds: Optional[torch.FloatTensor] = None,
    ...
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    ...
```
However, Flax models have no such option; they accept `input_ids` only. It would be great if Flax models also supported this, so that instead of passing `input_ids` you could optionally pass an embedded representation directly.
Motivation
This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix. (from the docs)
Additionally, this can be useful for things like tuning “soft-prompts” (e.g., https://aclanthology.org/2021.emnlp-main.243/)
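For concreteness, soft-prompt tuning boils down to prepending learned vectors to the token embeddings and feeding the result as `inputs_embeds`. The sketch below uses made-up shapes and plain JAX arrays; it is not the paper's implementation:

```python
import jax
import jax.numpy as jnp

vocab_size, hidden_size, prompt_len = 32, 8, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
embedding_table = jax.random.normal(k1, (vocab_size, hidden_size))  # frozen model embeddings
soft_prompt = jax.random.normal(k2, (prompt_len, hidden_size))      # the only trainable tensor

input_ids = jnp.array([[5, 7, 9]])                                  # (batch=1, seq=3)
token_embeds = embedding_table[input_ids]                           # (1, 3, hidden)
prompt = jnp.broadcast_to(soft_prompt, (input_ids.shape[0],) + soft_prompt.shape)
inputs_embeds = jnp.concatenate([prompt, token_embeds], axis=1)     # (1, 7, hidden)
# inputs_embeds would then be passed to the model instead of input_ids
```

Since the prompt vectors never correspond to vocabulary indices, there is no way to express this through `input_ids` alone.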
Your contribution
I will try to implement this myself, but I haven’t yet found a solution.
Issue Analytics
- Created a year ago
- Reactions: 2
- Comments: 16 (9 by maintainers)
Top GitHub Comments
@sanchit-gandhi here’s the PR you requested. I was actually able to get it to work with minimal modifications to `generation_flax_utils.py`. @mattf1n a similar solution might work for GPT-2 as well?
I would like to second @mattf1n 's feature request. This would be super useful for vision-language modeling where we often want to feed a concatenation of image and text features into a sequence-to-sequence model. This approach has become quite popular recently. See for example - VL-T5, GPV-1, GPV-2, UnifiedIO. And given that non-Flax models already support this, would be great to have this implemented for Flax models as well for consistency!
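The vision-language case described above can be sketched the same way: project image features to the text hidden size and concatenate them with token embeddings before the encoder. All names and dimensions here are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

hidden_size, img_feat_dim = 8, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
embedding_table = jax.random.normal(k1, (32, hidden_size))
# Hypothetical linear projection from image-feature space to text space.
proj = jax.random.normal(k2, (img_feat_dim, hidden_size)) / jnp.sqrt(img_feat_dim)

image_feats = jax.random.normal(k3, (1, 5, img_feat_dim))  # e.g. 5 region features
input_ids = jnp.array([[1, 2, 3]])

image_embeds = image_feats @ proj                          # (1, 5, hidden)
text_embeds = embedding_table[input_ids]                   # (1, 3, hidden)
inputs_embeds = jnp.concatenate([image_embeds, text_embeds], axis=1)  # (1, 8, hidden)
# A seq2seq model accepting inputs_embeds could consume this sequence directly.
```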