tuple index out of range for FlaxMBartForConditionalGeneration
Environment info
- transformers version: 4.9.0.dev0 (installed from source)
- Platform: Google colab
- Python version: 3.7.10
- Using TPU in script?: Yes
- Dependencies were installed following this colab notebook: https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/causal_language_modeling_flax.ipynb#scrollTo=Sj1mJNJa6PPS
Who can help
@patil-suraj @patrickvonplaten
Information
Model I am using: FlaxMBartForConditionalGeneration
The problem arises when loading the model itself
To reproduce
Steps to reproduce the behavior:
from transformers import FlaxMBartForConditionalGeneration
model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-5-f8556949d896> in <module>()
1 from transformers import FlaxMBartForConditionalGeneration, MBart50TokenizerFast
2
----> 3 model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", from_pt=True)
4 tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")
15 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
336
337 # init random models
--> 338 model = cls(config, *model_args, **model_kwargs)
339
340 if from_pt:
/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
948 ):
949 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 950 super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
951
952 def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
103
104 # randomly initialized parameters
--> 105 random_params = self.init_weights(self.key, input_shape)
106
107 # save required_params as set
/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in init_weights(self, rng, input_shape)
973 decoder_attention_mask,
974 position_ids,
--> 975 decoder_position_ids,
976 )["params"]
977
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
998 _, v_out = self.init_with_output(
999 rngs, *args,
-> 1000 method=method, mutable=mutable, **kwargs)
1001 return v_out
1002
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
967 rngs = {'params': rngs}
968 return self.apply(
--> 969 {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
970
971 def init(self,
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
937 method, self,
938 mutable=mutable, capture_intermediates=capture_intermediates
--> 939 )(variables, *args, **kwargs, rngs=rngs)
940
941 def init_with_output(self,
/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
685 **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
686 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 687 y = fn(root, *args, **kwargs)
688 if mutable is not False:
689 return y, root.mutable_variables()
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
1176 _context.capture_stack.append(capture_intermediates)
1177 try:
-> 1178 return fn(module.clone(parent=scope), *args, **kwargs)
1179 finally:
1180 _context.capture_stack.pop()
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
273 _context.module_stack.append(self)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
277 filter_fn = _context.capture_stack[-1]
/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
1310 output_hidden_states=output_hidden_states,
1311 return_dict=return_dict,
-> 1312 deterministic=deterministic,
1313 )
1314
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
273 _context.module_stack.append(self)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
277 filter_fn = _context.capture_stack[-1]
/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
905 output_hidden_states=output_hidden_states,
906 return_dict=return_dict,
--> 907 deterministic=deterministic,
908 )
909
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
273 _context.module_stack.append(self)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
277 filter_fn = _context.capture_stack[-1]
/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
763 )
764
--> 765 last_hidden_states = outputs[0]
766 last_hidden_states = self.layer_norm(last_hidden_states)
767
/usr/local/lib/python3.7/dist-packages/transformers/file_utils.py in __getitem__(self, k)
1810 return inner_dict[k]
1811 else:
-> 1812 return self.to_tuple()[k]
1813
1814 def __setattr__(self, name, value):
IndexError: tuple index out of range
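For readers following the last two frames: the error is raised by `self.to_tuple()[k]`, which suggests the encoder's `outputs` object converted to an empty tuple. The sketch below is a toy illustration of that mechanism only. `ToyOutput` is a hypothetical stand-in, not the transformers class, and it assumes that fields left as `None` are dropped from the tuple. It does not explain why the outputs were empty in this run.

```python
class ToyOutput:
    """Hypothetical stand-in for a model output container (not the transformers class)."""

    def __init__(self, **named_fields):
        # Assumption: fields that are None are not stored, so they never reach the tuple.
        self._fields = {k: v for k, v in named_fields.items() if v is not None}

    def to_tuple(self):
        return tuple(self._fields.values())

    def __getitem__(self, k):
        if isinstance(k, str):
            return self._fields[k]
        # Integer access goes through the tuple view, as in the last traceback frame.
        return self.to_tuple()[k]


# With at least one populated field, outputs[0] works.
ok = ToyOutput(last_hidden_state=[1.0, 2.0], hidden_states=None)
first = ok[0]

# With every field None, the tuple is empty and index 0 fails.
empty = ToyOutput(last_hidden_state=None, hidden_states=None)
try:
    empty[0]
    error_message = None
except IndexError as err:
    error_message = str(err)

print(first)
print(error_message)
```

If this is what happens here, checking which fields of `outputs` are populated just before `outputs[0]` in the encoder's `__call__` would be a reasonable first debugging step.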
Issue Analytics
- State:
- Created 2 years ago
- Comments:8 (7 by maintainers)
Top GitHub Comments
This colab should help with how to use generate on TPU
Not sure, I will try to see what’s the issue with colab. But it should work just fine on TPU VM.