tuple index out of range for FlaxMBartForConditionalGeneration


Environment info
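This report leaves the environment section empty. The maintainer's later comments suggest the behaviour may differ between Colab and a TPU VM, so the installed library versions matter. The sketch below is a minimal way to collect them. It assumes Python 3.8+ for importlib.metadata, which is not available on the Python 3.7 runtime shown in the traceback. The package list is an assumption: the Flax/JAX stack seen in the traceback, plus torch, because from_pt=True reads a PyTorch checkpoint. Running transformers-cli env prints similar information.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version

# Assumed package list: the Flax/JAX stack from the traceback, plus torch,
# which from_pt=True needs in order to read the PyTorch checkpoint.
PACKAGES = ("transformers", "flax", "jax", "jaxlib", "torch")


def collect_env(packages=PACKAGES):
    """Return interpreter and package versions as a dict for a bug report."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in this environment.
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```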

Who can help

@patil-suraj @patrickvonplaten

Information

Model I am using: FlaxMBartForConditionalGeneration

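The problem arises when loading the model itself. The error is raised while from_pretrained builds the randomly initialized model (cls(config, ...)), before any checkpoint weights are loaded.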

To reproduce

Steps to reproduce the behavior:

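from transformers import FlaxMBartForConditionalGeneration
model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")

The traceback below was produced by a slightly longer version of this snippet, which also passes from_pt=True and loads MBart50TokenizerFast: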
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-5-f8556949d896> in <module>()
      1 from transformers import FlaxMBartForConditionalGeneration, MBart50TokenizerFast
      2 
----> 3 model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", from_pt=True)
      4 tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX")

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    336 
    337         # init random models
--> 338         model = cls(config, *model_args, **model_kwargs)
    339 
    340         if from_pt:

/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __init__(self, config, input_shape, seed, dtype, **kwargs)
    948     ):
    949         module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 950         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
    951 
    952     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in __init__(self, config, module, input_shape, seed, dtype)
    103 
    104         # randomly initialized parameters
--> 105         random_params = self.init_weights(self.key, input_shape)
    106 
    107         # save required_params as set

/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in init_weights(self, rng, input_shape)
    973             decoder_attention_mask,
    974             position_ids,
--> 975             decoder_position_ids,
    976         )["params"]
    977 

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
    998     _, v_out = self.init_with_output(
    999         rngs, *args,
-> 1000         method=method, mutable=mutable, **kwargs)
   1001     return v_out
   1002 

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
    967       rngs = {'params': rngs}
    968     return self.apply(
--> 969         {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
    970 
    971   def init(self,

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
    937         method, self,
    938         mutable=mutable, capture_intermediates=capture_intermediates
--> 939     )(variables, *args, **kwargs, rngs=rngs)
    940 
    941   def init_with_output(self,

/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
    685               **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    686     with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 687       y = fn(root, *args, **kwargs)
    688     if mutable is not False:
    689       return y, root.mutable_variables()

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
   1176     _context.capture_stack.append(capture_intermediates)
   1177     try:
-> 1178       return fn(module.clone(parent=scope), *args, **kwargs)
   1179     finally:
   1180       _context.capture_stack.pop()

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
   1310             output_hidden_states=output_hidden_states,
   1311             return_dict=return_dict,
-> 1312             deterministic=deterministic,
   1313         )
   1314 

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
    905             output_hidden_states=output_hidden_states,
    906             return_dict=return_dict,
--> 907             deterministic=deterministic,
    908         )
    909 

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

/usr/local/lib/python3.7/dist-packages/transformers/models/mbart/modeling_flax_mbart.py in __call__(self, input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict, deterministic)
    763         )
    764 
--> 765         last_hidden_states = outputs[0]
    766         last_hidden_states = self.layer_norm(last_hidden_states)
    767 

/usr/local/lib/python3.7/dist-packages/transformers/file_utils.py in __getitem__(self, k)
   1810             return inner_dict[k]
   1811         else:
-> 1812             return self.to_tuple()[k]
   1813 
   1814     def __setattr__(self, name, value):

IndexError: tuple index out of range
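The last two frames explain the form of the error. In the encoder's __call__, outputs[0] is evaluated on a ModelOutput object, and its __getitem__ in file_utils.py handles an integer key with self.to_tuple()[k]. In transformers, to_tuple() contains only the fields that are not None. An IndexError at index 0 therefore means every field of outputs, including last_hidden_state, was None. The class below is a hypothetical, simplified imitation of that indexing behaviour in plain Python, not the real transformers class. It shows how the message arises, but not why the encoder produced no hidden state here.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional, Tuple


@dataclass
class MiniModelOutput:
    """Hypothetical, simplified stand-in for a transformers ModelOutput.

    Integer indexing goes through to_tuple(), which keeps only the fields
    that are not None, mirroring the self.to_tuple()[k] line in the traceback.
    """

    last_hidden_state: Optional[Any] = None
    hidden_states: Optional[Tuple[Any, ...]] = None
    attentions: Optional[Tuple[Any, ...]] = None

    def to_tuple(self) -> Tuple[Any, ...]:
        # Drop every field that is None, preserving declaration order.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)

    def __getitem__(self, k: int) -> Any:
        # Only integer indexing is modelled here.
        return self.to_tuple()[k]


ok = MiniModelOutput(last_hidden_state=[[0.1, 0.2]])
print(ok[0])  # [[0.1, 0.2]]

# Every field is None, so to_tuple() is () and index 0 does not exist.
empty = MiniModelOutput()
try:
    empty[0]
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: tuple index out of range
```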

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

1 reaction
patil-suraj commented, Jul 8, 2021

This Colab notebook should help with how to use generate on TPU.

1 reaction
patil-suraj commented, Jul 8, 2021

Not sure; I will try to see what the issue is with Colab. But it should work fine on a TPU VM.


