
TFBertForMaskedLM won't reload from saved checkpoint, shape mismatch issue

Environment info

  • transformers version: 4.5.1-4.7
  • Platform: Debian GNU/Linux 10 (buster)
  • Python version: 3.9.2
  • PyTorch version (GPU?): N/A
  • Tensorflow version (GPU?): 2.5.0
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@Rocketknight1, @LysandreJik, @sgugger

Information

Model I am using: TFBertForMaskedLM

The problem arises when using:

  • the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details below)

I believe this issue also affects the official TFTrainer implementation, since the checkpoint-restore snippet in the reproduction was adapted from it; a sketch of that pattern follows.
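
Roughly, the pattern in question looks like this (a paraphrase of TFTrainer's checkpointing from memory, not the exact library code; checkpoint_dir and the optimizer settings are illustrative):

import tensorflow as tf
from transformers import AutoConfig, TFAutoModelForMaskedLM

# Wrap the model (and, in TFTrainer's case, the optimizer) in a
# tf.train.Checkpoint and restore the latest checkpoint through a manager.
checkpoint_dir = './checkpoints'
model = TFAutoModelForMaskedLM.from_config(AutoConfig.from_pretrained('bert-base-uncased'))
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=1)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint).expect_partial()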

To reproduce

Steps to reproduce the behavior:

  1. Generate a masked batch
  2. Initialize a TF model and assign a CheckpointManager
  3. Save a model checkpoint
  4. Initialize a new TF model and assign a CheckpointManager
  5. Restore from the checkpoint

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForMaskedLM, AutoConfig

random_sentences = ["You'll see the rainbow bridge after it rains cats and dogs.",
                    "They looked up at the sky and saw a million stars.",
                    "The bullet pierced the window shattering it before missing Danny's head by mere millimeters.",
                    "He was willing to find the depths of the rabbit hole in order to be with her."]

tok = AutoTokenizer.from_pretrained('bert-base-uncased')
input_ids = tok.batch_encode_plus(random_sentences, return_tensors='np', padding=True)['input_ids']

# Labels are -100 (ignored by the loss) everywhere except at ~20% of the
# non-padding positions, which is enough to make the MLM loss computable.
labels = np.ones_like(input_ids) * -100
mask = (np.random.uniform(size=input_ids.shape) <= 0.2) & (input_ids != 0)
labels[mask] = tok.mask_token_id

batch = {'input_ids': tf.convert_to_tensor(input_ids),
         'labels': tf.convert_to_tensor(labels)}

# -- Run the model and save a checkpoint --

model = TFAutoModelForMaskedLM.from_pretrained('bert-base-uncased')
checkpoint = tf.train.Checkpoint(model=model)
model.ckpt_manager = tf.train.CheckpointManager(checkpoint, './', max_to_keep=1)
out = model(**batch)
print(out.loss.numpy())
model.ckpt_manager.save()

# -- Re-initialize from the config alone and load the existing checkpoint --

cfg = AutoConfig.from_pretrained('bert-base-uncased')
model2 = TFAutoModelForMaskedLM.from_config(cfg)
checkpoint2 = tf.train.Checkpoint(model=model2)
model2.ckpt_manager = tf.train.CheckpointManager(checkpoint2, './', max_to_keep=1)
latest_ckpt = tf.train.latest_checkpoint('./')
status = checkpoint2.restore(latest_ckpt)
status.assert_existing_objects_matched()

out = model2(**batch)  # raises the ValueError shown below
print(out.loss.numpy())

Expected behavior

The second model should restore fully from the checkpoint, and the subsequent forward pass should run without error.

Current behavior (error output)

ValueError                                Traceback (most recent call last)
<ipython-input-12-5ec2de12ee44> in <module>()
----> 1 out = model2(**batch)
      2 out.loss

19 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
   1238       raise ValueError(
   1239           "Tensor's shape %s is not compatible with supplied shape %s" %
-> 1240           (self.shape, shape))
   1241 
   1242   # Methods not supported / implemented for Eager Tensors.

ValueError: Tensor's shape (512, 768) is not compatible with supplied shape [2, 768]
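
For what it's worth, the two shapes in the error line up exactly with bert-base-uncased config values, which suggests (my reading, not confirmed) that the restore matched the model's (512, 768) position embeddings against the (2, 768) token type embeddings stored in the checkpoint:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained('bert-base-uncased')
# (512, 768) == (max_position_embeddings, hidden_size): the position embeddings
print((cfg.max_position_embeddings, cfg.hidden_size))  # (512, 768)
# (2, 768) == (type_vocab_size, hidden_size): the token type embeddings
print((cfg.type_vocab_size, cfg.hidden_size))          # (2, 768)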

Link to Colab

https://colab.research.google.com/drive/12pwo4WSueOT523hh1INw5J_SLpkK0IgB?usp=sharing
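
As a stopgap, the standard save_pretrained / from_pretrained round trip appears to avoid the mismatch, since it serializes weights by layer name rather than by object graph (a sketch only; unlike CheckpointManager it does not capture optimizer state):

from transformers import TFAutoModelForMaskedLM

# Save with transformers' own serialization (writes tf_model.h5 + config.json)...
model = TFAutoModelForMaskedLM.from_pretrained('bert-base-uncased')
model.save_pretrained('./mlm-checkpoint')

# ...and reload without tf.train.Checkpoint being involved at all.
model2 = TFAutoModelForMaskedLM.from_pretrained('./mlm-checkpoint')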

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
Rocketknight1 commented on Jun 21, 2021

Hey, thank you for that very helpful bit of diagnostic info! That links this with #11202, another issue we have caused by the same underlying problem. This is helpful because I’ll probably need to make some breaking changes to fix that issue, and the fact that it’s causing multiple downstream problems will increase the urgency there.

0 reactions
github-actions[bot] commented on Jul 18, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
