
`FlaxBartForConditionalGeneration` should not require `input_ids` when `encoder_output` is provided

See original GitHub issue

Environment info

  • transformers version: 4.18.0
  • Platform: Linux-5.11.0-1018-gcp-x86_64-with-glibc2.31
  • Python version: 3.10.4
  • Huggingface_hub version: 0.5.1
  • PyTorch version (GPU?): 1.11.0+cu102 (False)
  • Tensorflow version (GPU?): 2.8.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.4.1 (cpu)
  • Jax version: 0.3.5
  • JaxLib version: 0.3.5
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@patil-suraj

Information

Model I am using: BART

To reproduce

from transformers import BartTokenizer, BartForConditionalGeneration, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Flax: encode once, then try to generate from the precomputed encoder output.
model_flax = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-base')
inputs = tokenizer('Travelers wait about an hour and a half to cross the Tower.', return_tensors='jax')
outputs_flax = model_flax.encode(**inputs)
generate_ids_flax = model_flax.generate(attention_mask=inputs.attention_mask, encoder_output=outputs_flax)
# TypeError: FlaxGenerationMixin.generate() missing 1 required positional argument: 'input_ids'

# PyTorch: the same pattern works, even when reusing the encoder output
# computed by the Flax model above.
import numpy as onp
import torch
from transformers.modeling_outputs import BaseModelOutput

def jax2pt(a):
    # Convert a JAX array to a PyTorch tensor via NumPy.
    return torch.from_numpy(onp.asarray(a))

model_pt = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
outputs_pt = BaseModelOutput(last_hidden_state=jax2pt(outputs_flax.last_hidden_state))
generate_ids_pt = model_pt.generate(attention_mask=jax2pt(inputs.attention_mask), encoder_outputs=outputs_pt)
print(generate_ids_pt)  # OK
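
As a quick sanity check (not part of the original report), one can confirm that the PyTorch path really consumed the precomputed encoder states by comparing against generation straight from `input_ids`. This assumes the repro script above has already run, and allows for small numerical differences between the Flax and PyTorch encoders:

# Hypothetical sanity check, reusing the objects defined in the repro above.
# Generation here is deterministic, so both calls should normally agree,
# modulo tiny Flax-vs-PyTorch numerical differences in the encoder states.
generate_ids_direct = model_pt.generate(
    input_ids=jax2pt(inputs.input_ids),
    attention_mask=jax2pt(inputs.attention_mask),
)
print(torch.equal(generate_ids_pt, generate_ids_direct))  # typically True
print(tokenizer.batch_decode(generate_ids_pt, skip_special_tokens=True))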

Expected behavior

The Flax model should behave like the PyTorch one: `generate` should accept precomputed encoder outputs and not require `input_ids`.

Actual behavior

`FlaxBartForConditionalGeneration.generate` still requires `input_ids`, even when `encoder_output` is provided, and raises a `TypeError`.

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 9 (7 by maintainers)

Top GitHub Comments

1 reaction
ayaka14732 commented, May 15, 2022

I am going to make a PR

1 reaction
patil-suraj commented, Apr 21, 2022

We can support passing `encoder_outputs` in Flax `generate`. Would you like to open a PR for this? Happy to help with it.

We’ll need to modify this method to skip calling `.encode` when `encoder_outputs` is passed as a kwarg: https://github.com/huggingface/transformers/blob/bae9b6458cb4aebaf3a2eea1ab5d47904062f30f/src/transformers/generation_flax_utils.py#L142-L149
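
For illustration, a minimal sketch of that guard, assuming the structure of `FlaxGenerationMixin.generate()` around the linked lines (helper and variable names are paraphrased from `generation_flax_utils.py` and are not verified against that exact revision):

# Illustrative only -- names approximate generation_flax_utils.py at the
# linked commit and may not match the exact revision.
if self.config.is_encoder_decoder:
    # Proposed guard: only run the encoder when the caller has not
    # already supplied encoder_outputs through **model_kwargs.
    if model_kwargs.get("encoder_outputs") is None:
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

With such a guard in place, `input_ids` would only be needed to seed the decoder start tokens (or to infer the batch size), which is what makes the PyTorch-style `generate(attention_mask=..., encoder_outputs=...)` call possible.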
