Checkpoint doesn't match
self.model = DalleBart.from_pretrained("model", dtype=dtype, abstract_init=True)
I got the following error message for the line above; it looks like the checkpoint doesn't match?
ValueError: Trying to load the pretrained weight for ('model', 'decoder', 'embed_positions', 'embedding') failed: checkpoint has shape (257, 1024) which is incompatible with the model shape (256, 1024). Using ignore_mismatched_sizes=True if you really want to load this checkpoint inside this model.
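For context, this error comes from a shape check performed while copying checkpoint weights into the model. A minimal sketch of that check (hypothetical `load_params` helper; plain shape tuples stand in for the real weight arrays, this is not the library's actual code):

```python
def load_params(model_shapes, checkpoint_shapes, ignore_mismatched_sizes=False):
    """Sketch of the per-weight shape check behind the ValueError above."""
    loaded = {}
    for key, expected in model_shapes.items():
        found = checkpoint_shapes[key]
        if found != expected:
            if not ignore_mismatched_sizes:
                raise ValueError(
                    f"Trying to load the pretrained weight for {key} failed: "
                    f"checkpoint has shape {found} which is incompatible "
                    f"with the model shape {expected}."
                )
            # With ignore_mismatched_sizes=True the mismatched weight is
            # simply skipped and left to the model's random initialization.
            continue
        loaded[key] = found
    return loaded

# The decoder positional-embedding table from the error message:
model_shapes = {("model", "decoder", "embed_positions", "embedding"): (256, 1024)}
checkpoint_shapes = {("model", "decoder", "embed_positions", "embedding"): (257, 1024)}
```

So `ignore_mismatched_sizes=True` does not reconcile the shapes; it drops the mismatched weight entirely, which is usually a sign the library version does not match the checkpoint.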
After updating the call to
self.model = DalleBart.from_pretrained("model", dtype=dtype, abstract_init=True, ignore_mismatched_sizes=True)
(adding ignore_mismatched_sizes=True), I got another error:
TypeError: Value ShapeDtypeStruct(shape=(256, 1024), dtype=float32) with type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type
Both errors look like they come from JAX. I installed the latest JAX with pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Any suggestions?
Issue Analytics
- Created 2 years ago
- Comments: 6 (3 by maintainers)
Top GitHub Comments
Awesome!
I updated transformers from 14.0 to 14.16, and now it works well. Thank you very much! Best
Did you try to update your installation? Make sure your transformers repo is updated too!
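Since the fix here was simply upgrading transformers, a quick sanity check before loading the model can save a debugging round-trip. A small sketch (hypothetical `check_min_version` helper; the "14.16" target just mirrors the version string quoted in the comment above):

```python
from importlib.metadata import PackageNotFoundError, version

def version_tuple(v):
    """Turn a dotted version string like '14.16.0' into (14, 16, 0)
    for a naive numeric comparison; fine for plain x.y.z releases."""
    return tuple(int(part) for part in v.split(".") if part.isdigit())

def check_min_version(package, minimum):
    """Return True if `package` is installed at or above `minimum`."""
    try:
        return version_tuple(version(package)) >= version_tuple(minimum)
    except PackageNotFoundError:
        return False
```

For example, calling `check_min_version("transformers", "14.16")` before `DalleBart.from_pretrained(...)` would flag a stale installation up front instead of failing later with a shape mismatch.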