
Flax models should allow `inputs_embeds`

See original GitHub issue

Feature request

Currently, non-Flax models accept `inputs_embeds` as an alternative to `input_ids` (e.g., GPT-2):

```python
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    ...
    inputs_embeds: Optional[torch.FloatTensor] = None,
    ...
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    ...
```

However, Flax models have no such option; they accept `input_ids` only.

It would be great if Flax models also had this option, so that:

> Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
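A minimal sketch of the requested dispatch, written in plain Python/NumPy rather than the actual transformers or Flax API (all names, shapes, and the `encode` helper are illustrative assumptions):

```python
import numpy as np

# Toy embedding table; in a real model this would be the learned
# embedding matrix inside the module.
VOCAB_SIZE, HIDDEN = 10, 4
embedding_matrix = np.arange(VOCAB_SIZE * HIDDEN, dtype=np.float32).reshape(
    VOCAB_SIZE, HIDDEN
)

def encode(input_ids=None, inputs_embeds=None):
    # Accept exactly one of the two inputs, mirroring the PyTorch behavior.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Pass exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        # Standard path: look the ids up in the embedding matrix.
        inputs_embeds = embedding_matrix[input_ids]
    # Downstream layers only ever see embeddings, so callers may
    # supply their own vectors directly.
    return inputs_embeds

ids = np.array([[1, 2, 3]])
out_from_ids = encode(input_ids=ids)
out_from_embeds = encode(inputs_embeds=embedding_matrix[ids])
```

The key design point is that everything after the embedding lookup is unchanged, which is why the feature is cheap to add: the model just skips the lookup when embeddings are passed in.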

Motivation

This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix. (from the docs)

Additionally, this can be useful for things like tuning “soft-prompts” (e.g., https://aclanthology.org/2021.emnlp-main.243/)
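To illustrate why soft prompts need this, here is a hedged NumPy sketch (shapes and names are assumptions, not the paper's code): the learned prompt lives directly in embedding space, so there is no token id that could represent it, and it can only reach the model through an `inputs_embeds`-style argument.

```python
import numpy as np

HIDDEN, PROMPT_LEN, SEQ_LEN = 4, 2, 3
rng = np.random.default_rng(0)

# The soft prompt is a trainable matrix of embedding-space vectors;
# during prompt tuning, only these parameters receive gradients.
soft_prompt = rng.normal(size=(PROMPT_LEN, HIDDEN)).astype(np.float32)

# Ordinary token embeddings for one example (batch size 1).
token_embeds = rng.normal(size=(1, SEQ_LEN, HIDDEN)).astype(np.float32)

# Prepend the prompt to every example in the batch.
batch_prompt = np.broadcast_to(soft_prompt, (1, PROMPT_LEN, HIDDEN))
inputs_embeds = np.concatenate([batch_prompt, token_embeds], axis=1)
# inputs_embeds would then be passed to the model in place of input_ids.
```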

Your contribution

I will try to implement this myself, but I haven’t yet found a solution.

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 2
  • Comments: 16 (9 by maintainers)

Top GitHub Comments

2 reactions
BigRedT commented, Aug 13, 2022

@sanchit-gandhi here’s the PR you requested. I was actually able to get it to work with minimal modifications to generation_flax_utils.py.

@mattf1n a similar solution might work for GPT-2 as well?

1 reaction
BigRedT commented, Aug 10, 2022

I would like to second @mattf1n’s feature request. This would be super useful for vision-language modeling, where we often want to feed a concatenation of image and text features into a sequence-to-sequence model. This approach has become quite popular recently; see for example VL-T5, GPV-1, GPV-2, and UnifiedIO. And given that non-Flax models already support this, it would be great to have it implemented for Flax models as well for consistency!
