past_key_values not accepted in generate with GPTNeoX
System Info
Python 3.7.13, transformers 4.22.2
Who can help?
@LysandreJik @patrickvonplaten
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
The past_key_values kwarg is not accepted when calling model.generate(..., past_key_values=pkv) on a GPTNeoXForCausalLM, even though model.forward does accept this kwarg. It works fine with other model classes such as GPT2.
Minimal example to reproduce error:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers

model_id = "NinedayWang/PolyCoder-160M"  # small model with GPTNeoXForCausalLM class
model = AutoModelForCausalLM.from_pretrained(model_id)
tok = AutoTokenizer.from_pretrained(model_id)
assert isinstance(model, transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM)

pkv = torch.rand(
    (
        1,  # batch size
        10,  # number of tokens
        2 * model.config.num_hidden_layers,
        model.config.num_attention_heads,
        model.config.hidden_size // model.config.num_attention_heads,
    )
)
out = model.generate(**tok("Hello world", return_tensors="pt"), past_key_values=pkv)
Error message:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/transformers/generation_utils.py", line 1146, in generate
    self._validate_model_kwargs(model_kwargs.copy())
  File "/home/st/st_us-052400/st_st175337/conda/envs/thesis/lib/python3.7/site-packages/transformers/generation_utils.py", line 862, in _validate_model_kwargs
    f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
ValueError: The following `model_kwargs` are not used by the model: ['past_key_values'] (note: typos in the generate arguments will also show up in this list)
I checked the error location and found the cause ("transformers/generation_utils.py", line 862, in _validate_model_kwargs):
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
    model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
    if value is not None and key not in model_args:
        unused_model_args.append(key)
if unused_model_args:
    raise ValueError(
        f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
        " generate arguments will also show up in this list)"
    )
It first collects the args of prepare_inputs_for_generation and only adds the args of forward to the accepted set if "kwargs" appears among those args. However, unlike GPT2, the GPTNeoX signature uses **model_kwargs instead of **kwargs, so the forward args are never added.
So either the GPTNeoX class should be adapted, or the _validate_model_kwargs method in generation_utils.py.
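The effect of that check can be sketched with plain inspect on two stand-in signatures (the function names and parameter lists below are illustrative, not the actual transformers code):

```python
import inspect

# Hypothetical stand-ins for the two models' prepare_inputs_for_generation
# signatures (illustrative only, not the real transformers methods).
def prepare_gpt2(input_ids, past=None, **kwargs):
    ...

def prepare_gpt_neox(input_ids, past=None, **model_kwargs):
    ...

def accepted_args(prepare_fn, forward_params):
    """Mimic _validate_model_kwargs: forward's params are only added
    when the prepare function has a parameter literally named `kwargs`."""
    model_args = set(inspect.signature(prepare_fn).parameters)
    if "kwargs" in model_args:
        model_args |= forward_params
    return model_args

forward_params = {"input_ids", "attention_mask", "past_key_values"}

print("past_key_values" in accepted_args(prepare_gpt2, forward_params))      # True
print("past_key_values" in accepted_args(prepare_gpt_neox, forward_params))  # False
```

Because inspect.signature reports the **-parameter by its literal name, `**model_kwargs` does not match the `"kwargs" in model_args` test, and past_key_values ends up in unused_model_args.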
Expected behavior
generate should be able to pass along all valid model_kwargs.
Issue Analytics
- Created: 10 months ago
- Comments: 6 (4 by maintainers)
@gante @ArthurZucker I think we should rename all occurrences of "past" to "past_key_values" in prepare_inputs_for_generation and deprecate "past" if necessary. "past" was simply the name for the past key values states before we renamed everything to past_key_values, so this is just a left-over.

Agreed
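A minimal sketch of what that rename-with-deprecation could look like (hypothetical signature and dict layout, not the actual transformers patch):

```python
import warnings

# Hypothetical prepare_inputs_for_generation using `past_key_values` as the
# canonical name, while still accepting the legacy `past` kwarg with a
# deprecation warning. Illustrative only.
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
    if "past" in kwargs:
        warnings.warn(
            "`past` is deprecated, use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = kwargs.pop("past")
    model_inputs = {"input_ids": input_ids, **kwargs}
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
    return model_inputs
```

With the canonical name in the signature, inspect.signature would see past_key_values directly, so _validate_model_kwargs would accept it even without the `**kwargs` fallback.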