[Question] How to convert pkl to pt file?
Thanks for your excellent work!
Describe the problem
I'd like to learn how you convert a pkl file to a pt file. I use the pt file you provide to generate images. The code is as follows:
def get_random_image(generator: Generator, truncation_psi: float, seed):
    with torch.no_grad():
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w
This works well. But when I convert the pkl to a pt file myself, several errors appear. The conversion code I used is as follows:
import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional
import torch
checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
with open(checkpoint_path, "rb") as f:
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')
state_dict = decoder.state_dict()
torch.save(state_dict, "pretrained_models/stylegan3-t-ffhq-1024x1024.pt")
print('Converting done!')
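A script like the one above stores only the weights, so whatever loads the pt file has to rebuild a generator with exactly the same layer widths. One possible workaround, sketched here under assumptions, is to store the constructor arguments next to the state dict. `TinyGenerator`, `export_checkpoint`, `load_checkpoint` and the `init_kwargs` attribute on the toy class are hypothetical names for illustration; whether the real `G_ema` object exposes comparable constructor arguments, and whether the repository's loader could consume such a file, would need to be checked.

```python
import io

import torch
from torch import nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the real generator; its width depends on `channels`."""

    def __init__(self, channels: int):
        super().__init__()
        # Remember the constructor arguments so they can be exported later.
        self.init_kwargs = {"channels": channels}
        self.layer = nn.Linear(channels, channels)


def export_checkpoint(model: nn.Module, target) -> None:
    """Save the weights together with the arguments that determined their shapes."""
    payload = {
        "init_kwargs": dict(model.init_kwargs),
        "state_dict": model.state_dict(),
    }
    torch.save(payload, target)


def load_checkpoint(source) -> nn.Module:
    """Rebuild the model from the stored arguments, then load the weights strictly."""
    ckpt = torch.load(source, map_location="cpu")
    model = TinyGenerator(**ckpt["init_kwargs"])
    model.load_state_dict(ckpt["state_dict"], strict=True)
    return model


# Round trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
export_checkpoint(TinyGenerator(channels=8), buffer)
buffer.seek(0)
restored = load_checkpoint(buffer)
```

Because the loader builds the model from the stored arguments, a strict load cannot hit a width mismatch in this toy setup.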
Then I use stylegan3-t-ffhq-1024x1024.pt to generate images, and the errors are as follows:
Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "gen_images_using_pt.py", line 79, in <module>
main()
File "gen_images_using_pt.py", line 47, in main
generator = SG3Generator(checkpoint_path=args.generator_path).decoder
File "/sam/models/stylegan3/model.py", line 56, in __init__
self._load_checkpoint(checkpoint_path)
File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
self.decoder.load_state_dict(ckpt, strict=False)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
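The log above can be condensed with a small helper. The following is a minimal sketch, assuming the layer keys follow the pattern `synthesis.L<index>_<size>_<channels>.<parameter>` (inferred only from the key names in the error message). `layer_channels` and `shape_mismatches` are hypothetical helper names, and the sample keys and shapes are copied from the traceback.

```python
import re
from typing import Dict, Iterable, List, Mapping, Tuple

Shape = Tuple[int, ...]

# Assumed naming scheme, inferred from the keys in the error message:
# synthesis.L<index>_<size>_<channels>.<parameter>
LAYER_KEY = re.compile(r"^synthesis\.L(\d+)_(\d+)_(\d+)\.")


def layer_channels(keys: Iterable[str]) -> Dict[int, int]:
    """Map layer index -> channel count parsed from the key names."""
    found: Dict[int, int] = {}
    for key in keys:
        match = LAYER_KEY.match(key)
        if match:
            found[int(match.group(1))] = int(match.group(3))
    return found


def shape_mismatches(ckpt: Mapping[str, Shape], model: Mapping[str, Shape]) -> List[str]:
    """List the keys present on both sides whose shapes differ."""
    shared = sorted(ckpt.keys() & model.keys())
    return [
        f"{key}: checkpoint {ckpt[key]} vs model {model[key]}"
        for key in shared
        if ckpt[key] != model[key]
    ]


# Key names taken from the "Unexpected" (checkpoint) and "Missing" (model) lists.
ckpt_widths = layer_channels(["synthesis.L7_276_323.weight", "synthesis.L13_1024_32.bias"])
model_widths = layer_channels(["synthesis.L7_276_645.weight", "synthesis.L13_1024_64.bias"])

# Shapes taken from the "size mismatch" lines.
ckpt_shapes = {"synthesis.input.weight": (512, 512), "synthesis.L14_1024_3.affine.bias": (32,)}
model_shapes = {"synthesis.input.weight": (1024, 1024), "synthesis.L14_1024_3.affine.bias": (64,)}
report = shape_mismatches(ckpt_shapes, model_shapes)
```

With real state dicts, the shape mappings can be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`. In this log every width in the checkpoint is roughly half of what the model expects, which suggests the Generator was constructed with a different channel configuration than the network stored in the pkl, rather than the conversion itself being faulty. That reading is an inference from the log and is not confirmed in the thread.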
Top GitHub Comments
Thank you very much! You saved my day!
Hi, I will share my code here.
Hope it helps.