
PreTrainedModel's tie_weights invocation needs to be configurable


PreTrainedModel defines a tie_weights method, and then in one place its documentation suggests:

Takes care of tying weights embeddings afterwards if the model class has a :obj:tie_weights() method.

But since the super-class has it defined, it’s always there.

So the only way for a sub-class to avoid this “tying” is to override it with:

     def tie_weights(self): pass

If nothing else, that comment needs to be edited to suggest a no-op override in the sub-class.

But it took some hunting to get there, so a better solution is needed.

Most likely, at the moment most (all?) encoder/decoder models in transformers share token embedding weights, hence the issue hasn’t come up. I’m working on porting a fairseq transformer, and there the encoder/decoder token embeddings aren’t shared.
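
For context, here is a rough sketch of what the base-class tie_weights typically does. This is a simplified stand-in, not the exact modeling_utils.py code, but it shows why an unconditional call hurts a model whose encoder/decoder embeddings are meant to stay separate:

    import torch.nn as nn

    class PreTrainedModelSketch(nn.Module):
        """Simplified stand-in for transformers.PreTrainedModel (illustration only)."""

        def get_input_embeddings(self):
            # Real models return their token embedding module here.
            raise NotImplementedError

        def get_output_embeddings(self):
            # Sub-classes with an LM head return it here; others return None.
            return None

        def tie_weights(self):
            # Invoked unconditionally by the base class (e.g. after weight init),
            # so any model exposing output embeddings gets them tied to the
            # input embeddings whether it wants that or not.
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                output_embeddings.weight = self.get_input_embeddings().weight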

I propose a solution that adds a new param to PretrainedConfig, say is_enc_dec_sharing_embeds=True, lets the subclass override it, and then adds at the start of tie_weights in modeling_utils.py:

     def tie_weights(self):
         if not self.config.is_enc_dec_sharing_embeds:
             return

That way it’s easy to quickly become aware that an action needs to be taken, and to set the desired behavior from within the subclass.
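
To make the proposal concrete, here is a minimal sketch (the config class name is made up, and is_enc_dec_sharing_embeds is just the name proposed above) of how a sub-class with separate encoder/decoder embeddings would opt out via its config alone, with no tie_weights override needed once the guard above is in place:

    from transformers import PretrainedConfig

    class PortedFairseqConfig(PretrainedConfig):
        """Hypothetical config for a ported model whose enc/dec embeddings stay separate."""

        def __init__(self, is_enc_dec_sharing_embeds=False, **kwargs):
            super().__init__(**kwargs)
            # PretrainedConfig would default this to True to preserve current behaviour;
            # this particular model explicitly turns the tying off.
            self.is_enc_dec_sharing_embeds = is_enc_dec_sharing_embeds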

Thoughts?

If the proposed solution is agreeable, please let me know which config param name it should be (is_enc_dec_sharing_embeds or something different) and I will submit a PR.

Thank you.

edit:

OK, having had a closer look:

grep -r -A 2  'def tie_weights' src/transformers | grep pass | wc -l

we have 5 sub-classes that override it with a no-op, so only some rely on the default. Bad superclass, no cookies for you.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Aug 24, 2020

Awesome, I will open a PR. I actually need this feature for the EncoderDecoderModel as well.

1 reaction
LysandreJik commented, Aug 24, 2020

This sounds like a good idea. I would advocate for a tie_word_embeddings parameter in the configuration as @patrickvonplaten suggested, but I would keep tie_weights as the method that does the weight tying rather than renaming that method as well. Just a quick glance at the configuration tells you which weights it’s about to tie, and it will be able to handle other cases of weight tying that we might encounter in the future without the need of adding additional new methods.
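
For reference, a hedged sketch of the direction described here: a tie_word_embeddings flag in the config, checked inside the existing tie_weights method. The names follow the comment above, and the actual code that later landed in transformers may differ in detail:

    def tie_weights(self):
        # Keep the method name, but let the config decide whether the
        # input and output word embeddings actually get tied.
        if getattr(self.config, "tie_word_embeddings", True):
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())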
