`FlaxBartForConditionalGeneration` should not require `input_ids` when `encoder_output` is provided
See original GitHub issue.

Environment info
- `transformers` version: 4.18.0
- Platform: Linux-5.11.0-1018-gcp-x86_64-with-glibc2.31
- Python version: 3.10.4
- Huggingface_hub version: 0.5.1
- PyTorch version (GPU?): 1.11.0+cu102 (False)
- Tensorflow version (GPU?): 2.8.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.4.1 (cpu)
- Jax version: 0.3.5
- JaxLib version: 0.3.5
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
Information
Model I am using: BART
To reproduce
from transformers import BartTokenizer, BartForConditionalGeneration, FlaxBartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model_flax = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-base')
inputs = tokenizer('Travelers wait about an hour and a half to cross the Tower.', return_tensors='jax')
outputs_flax = model_flax.encode(**inputs)
generate_ids_flax = model_flax.generate(attention_mask=inputs.attention_mask, encoder_output=outputs_flax) # TypeError: FlaxGenerationMixin.generate() missing 1 required positional argument: 'input_ids'
The equivalent PyTorch call, by contrast, succeeds:

import numpy as onp
import torch
from transformers.modeling_outputs import BaseModelOutput
def jax2pt(a):
    # Convert a JAX array to a PyTorch tensor via NumPy.
    return torch.from_numpy(onp.asarray(a))
model_pt = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
outputs_pt = BaseModelOutput(last_hidden_state=jax2pt(outputs_flax.last_hidden_state))
generate_ids_pt = model_pt.generate(attention_mask=jax2pt(inputs.attention_mask), encoder_outputs=outputs_pt)
print(generate_ids_pt) # OK
Expected behavior
The Flax model should behave like the PyTorch model: when precomputed encoder outputs are passed to `generate`, `input_ids` should not be required (see the sketch below).
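For reference, the call that ought to succeed would mirror the PyTorch API. This is a hypothetical sketch: the kwarg name `encoder_outputs` (plural, matching PyTorch and the maintainer's suggestion further down) is an assumption, not the current Flax API.

```python
# Hypothetical call once supported; encoder_outputs is assumed to match the
# PyTorch kwarg name, so input_ids would no longer be required.
generate_ids_flax = model_flax.generate(
    attention_mask=inputs.attention_mask,
    encoder_outputs=outputs_flax,
)
```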
Actual behavior
`FlaxBartForConditionalGeneration` requires `input_ids`, even if `encoder_output` is provided.
I am going to make a PR.
We can support passing `encoder_outputs` in Flax `generate`. Would you like to open a PR for this? Happy to help with it. We'll need to modify this method to skip calling `.encode` if the `encoder_outputs` are passed as a kwarg: https://github.com/huggingface/transformers/blob/bae9b6458cb4aebaf3a2eea1ab5d47904062f30f/src/transformers/generation_flax_utils.py#L142-L149
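A minimal sketch of that change, assuming a guard around the encoder-preparation step; the method name and surrounding structure here are illustrative, not the actual `generation_flax_utils.py` source:

```python
# Hypothetical sketch of the proposed guard (not the actual transformers code):
# reuse precomputed encoder outputs instead of unconditionally calling .encode.
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
    if model_kwargs.get("encoder_outputs") is None:
        # No precomputed outputs were supplied, so run the encoder as before.
        model_kwargs["encoder_outputs"] = self.encode(
            input_ids,
            attention_mask=model_kwargs.get("attention_mask"),
            params=params,
            return_dict=True,
        )
    # Otherwise, keep the caller-supplied encoder_outputs and skip re-encoding.
    return model_kwargs
```

For the title's request to work end to end, `generate` would also need to treat `input_ids` as optional when `encoder_outputs` is supplied, since it is currently a required positional argument (hence the TypeError in the reproduction above).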