Recommended way to partially load weights into a model?
Hi,
Is there a recommended way to partially load weights into a new model?
At the moment I have an inept solution, which relies on the fact that the fc attribute is the only eqx.nn.Linear instance in my model:
is_leaf = lambda x: isinstance(x, eqx.nn.Linear)
new_model = tree_deserialise_leaves('model.eqx', CNN(num_classes=26, key=key), is_leaf=is_leaf)
This approach won’t work when there are multiple Linear layers. Is there a recommended way to perform such loading?
A name parameter for modules might be useful here.
Top GitHub Comments
Yep, I could see some documentation being a good idea. (Including a version of my above example?)
Right, so the answer is that whilst tree_map just so happens to be ordered, this is actually an implementation detail. That is, JAX reserves the right to change this ordering behaviour in the future.

In contrast, tree_flatten does guarantee that it will return things in order, and JAX actually includes a test that this is the case. This is from here: https://github.com/google/jax/pull/11658 (which, as you can see, is very recent. The context is that I raised this question internally with the JAX team, and it was agreed to guarantee that tree_flatten will preserve ordering, but not to offer the same guarantee for tree_map.)

The current implementation of tree_map is a lot neater than my approach though! I suggest we copy-paste that for our ordered_tree_map, instead of the one I wrote.
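
As a rough illustration of what such an ordered_tree_map could look like, here is a minimal sketch built only on tree_flatten, whose ordering is the documented guarantee discussed above. It follows the flatten / apply / unflatten pattern and is not Equinox's actual code. The function is called explicitly in leaf order via a list comprehension, which presumably matters when each call has a side effect such as reading the next array from a file.

```python
import jax.tree_util as jtu


def ordered_tree_map(f, tree, *rest, is_leaf=None):
    # tree_flatten returns the leaves in a guaranteed, deterministic order.
    leaves, treedef = jtu.tree_flatten(tree, is_leaf=is_leaf)
    # Flatten any extra trees against the structure of the first one.
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    # Call f once per leaf position, strictly in flattened order.
    results = [f(*xs) for xs in zip(*all_leaves)]
    return jtu.tree_unflatten(treedef, results)


# Record the order in which f sees the leaves.
seen = []


def double(x):
    seen.append(x)
    return x * 2


example = {"a": 1, "b": (2, 3), "c": [4]}
doubled = ordered_tree_map(double, example)
```

Here seen ends up matching jtu.tree_leaves(example), so the call order is tied to tree_flatten rather than to an implementation detail of tree_map.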