T5 model seq2seq text generation using word embeddings instead of token_ids does not work
Hi there,
I trained an MT5ForConditionalGeneration model. During training, I used my own embeddings on the encoder side (but the default embeddings for decoding). However, when I try to produce output with the generate function, I get an error. I will post the code and the error message below:
Here is the code for model training:
outputs = self.encoder2(inputs_embeds=context, attention_mask=input_mask, labels=padded_labels)
Here, context plays the same role as a batch of token_ids, except that it contains embeddings, and padded_labels holds the target-sequence token_ids. Training works fine without any issues.
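For context, here is a self-contained sketch of what such a training step can look like. The model name (google/mt5-small), the random context tensor, and the example labels are placeholders I am assuming for illustration, not the original setup:

```python
import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

# Illustrative sketch only; in the real code, `context` comes from custom embeddings.
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")

batch_size, src_len, hidden = 2, 16, model.config.d_model
context = torch.randn(batch_size, src_len, hidden)           # stand-in for custom encoder embeddings
input_mask = torch.ones(batch_size, src_len, dtype=torch.long)
padded_labels = tokenizer(["an example", "another one"],
                          padding=True, return_tensors="pt").input_ids

# Teacher-forced training step: the labels are shifted internally to form decoder inputs.
outputs = model(inputs_embeds=context, attention_mask=input_mask, labels=padded_labels)
outputs.loss.backward()
```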
And here is the line where I try to generate with the model:
outputs = self.encoder2.generate(input_ids=None, inputs_embeds=context, attention_mask=input_mask, bos_token_id=0, pad_token_id=0, eos_token_id=1)
Once the program hits the above line, I get the following error message:
outputs = self.encoder2.generate(input_ids=None, inputs_embeds=context, attention_mask=input_mask, bos_token_id=0, pad_token_id=0, eos_token_id=1)
  File "/scratch/jerryc/jerryc/venv_py3.7/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/jerryc/jerryc/venv_py3.7/lib/python3.7/site-packages/transformers/generation_utils.py", line 913, in generate
    input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
  File "/scratch/jerryc/jerryc/venv_py3.7/lib/python3.7/site-packages/transformers/generation_utils.py", line 422, in _prepare_decoder_input_ids_for_generation
    torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id
AttributeError: 'NoneType' object has no attribute 'shape'
It seems the model is not handling this case properly: judging from the traceback, _prepare_decoder_input_ids_for_generation reads input_ids.shape to build the decoder start tokens, which fails when only inputs_embeds is provided and input_ids is None. Any help would be appreciated. Thanks.
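Until generate handles this case itself, one possible workaround is to supply explicit decoder_input_ids so that generate never has to infer the batch size from input_ids. This is only a sketch that reuses the variables from the snippets above; whether your installed transformers release accepts decoder_input_ids alongside inputs_embeds here is an assumption you would need to verify against that version:

```python
import torch

# Hedged workaround sketch, not the official fix.
batch_size = context.shape[0]
decoder_start = self.encoder2.config.decoder_start_token_id  # 0 (the pad token) for mT5
decoder_input_ids = torch.full(
    (batch_size, 1), decoder_start, dtype=torch.long, device=context.device
)

outputs = self.encoder2.generate(
    input_ids=None,
    inputs_embeds=context,
    attention_mask=input_mask,
    decoder_input_ids=decoder_input_ids,  # supplies the batch size that generate would
                                          # otherwise try to read from input_ids.shape
    eos_token_id=1,
    pad_token_id=0,
)
```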
Top GitHub Comments
@ichiroex,
Thanks for the nicely reproducible code snippet - this is indeed a bug and should be fixed.
@patrickvonplaten Thank you!!