Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging 3rd-party libraries. It collects links to all the places you might be looking at while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

[Question] How to convert pkl to pt file?

See original GitHub issue

Thanks for your excellent work!

Describe the problem

I’d like to learn how you convert the .pkl file to a .pt file. I use the .pt file you provide to generate images. The code is as follows:

# Generator, make_transform and tensor2im come from the repo's
# models.stylegan3 / utils.common modules (see the full script in the comments).
def get_random_image(generator: Generator, truncation_psi: float, seed):
    with torch.no_grad():
        # Sample a latent z from the given seed
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            # Reset the StyleGAN3 input transform to the identity (no translation or rotation)
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        # Map z to W with truncation, then synthesize with constant noise
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w
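
For context, a typical call to this helper looks roughly like the sketch below. SG3Generator and the checkpoint path follow the full script shared in the comments; truncation_psi=0.7 and the seed are just example values:

# Sketch: load a .pt checkpoint through the repo's SG3Generator wrapper and
# generate one image (path and values are examples, not the exact ones used).
from models.stylegan3.model import SG3Generator

generator = SG3Generator(checkpoint_path="pretrained_models/stylegan3-t-ffhq-1024x1024.pt").decoder
image, latent = get_random_image(generator, truncation_psi=0.7, seed=0)
image.save("seed0000.png")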

And it works well. But when I convert pkl to pt by myself, it appears several errors. The converting code I used is as follow:

import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# Unpickle the official network pickle and take the EMA generator;
# this requires the StyleGAN3 code (dnnlib, torch_utils) to be importable.
with open(checkpoint_path, "rb") as f:
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')

# Save only the weights (state_dict), not the pickled Generator object
state_dict = decoder.state_dict()
torch.save(state_dict, "pretrained_models/stylegan3-t-ffhq-1024x1024.pt")
print('Converting done!')
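
For reference, a minimal sketch of the same conversion using the official StyleGAN3 loader. This assumes a local clone of NVIDIA's stylegan3 repo is on the Python path so that legacy (and its dnnlib/torch_utils dependencies) can be imported; "/path/to/stylegan3" is a placeholder:

import sys

import torch

# Assumption: path to a local clone of https://github.com/NVlabs/stylegan3
sys.path.insert(0, "/path/to/stylegan3")

import legacy  # noqa: E402 (imported after adjusting sys.path)

pkl_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
pt_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pt"

# load_network_pkl also handles older TensorFlow-based pickles
with open(pkl_path, "rb") as f:
    G = legacy.load_network_pkl(f)['G_ema']

torch.save(G.state_dict(), pt_path)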

Then I use stylegan3-t-ffhq-1024x1024.pt to generate images, and the errors are as follows:

Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
  File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
    self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", "synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
	Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", "synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
	size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
	size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
	size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gen_images_using_pt.py", line 79, in <module>
    main()
  File "gen_images_using_pt.py", line 47, in main
    generator = SG3Generator(checkpoint_path=args.generator_path).decoder
  File "/sam/models/stylegan3/model.py", line 56, in __init__
    self._load_checkpoint(checkpoint_path)
  File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
    self.decoder.load_state_dict(ckpt, strict=False)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
	size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
	size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
	size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
	size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
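
Reading the error: the self-converted checkpoint contains keys such as synthesis.L0_36_512 while the Generator built by SG3Generator expects synthesis.L0_36_1024, so the layer names themselves appear to encode different channel widths. A quick way to see exactly where two .pt files disagree is to diff their keys and shapes directly (the file names below are placeholders):

import torch

# Placeholder file names: the .pt provided with the repo vs. the self-converted one
provided = torch.load("pretrained_models/provided_sg3.pt", map_location="cpu")
converted = torch.load("pretrained_models/stylegan3-t-ffhq-1024x1024.pt", map_location="cpu")

# Show a few keys that exist only on one side
print("only in provided: ", sorted(set(provided) - set(converted))[:5])
print("only in converted:", sorted(set(converted) - set(provided))[:5])

# For keys present in both, report any shape mismatches
for k in sorted(set(provided) & set(converted)):
    if provided[k].shape != converted[k].shape:
        print(k, tuple(provided[k].shape), "vs", tuple(converted[k].shape))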

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

1 reaction
uselessai commented, Jun 1, 2022

Thank you very much!!! You saved my day!

1 reaction
HuaZheLei commented, Jun 1, 2022

Hi HuaZheLei, I am trying to generate images from a .pt model, but I am not sure how to load the model. How can I load the .pt model? Thanks!!


Hi, I will share my code here.

import os
import argparse
from typing import Tuple, List, Union

import numpy as np
import torch

from models.stylegan3.model import SG3Generator
from models.stylegan3.networks_stylegan3 import Generator
from utils.common import tensor2im

def make_transform(translate: Tuple[float, float], angle: float):
    m = np.eye(3)
    s = np.sin(angle / 360.0 * np.pi * 2)
    c = np.cos(angle / 360.0 * np.pi * 2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m

def main():
    args = parse_args()
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    generator = SG3Generator(checkpoint_path=args.generator_path).decoder

    for i in range(args.image_numbers):
        print('Generating image for seed %d (%d/%d) ...' % (i, i, args.image_numbers))
        image, latent = get_random_image(generator, truncation_psi=args.truncation_psi, seed=i)
        image.save(os.path.join(save_dir, 'seed' + str(i).zfill(4) + '.png'))


def get_random_image(generator: Generator, truncation_psi: float, seed):
    with torch.no_grad():
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
        return res_image, w

Hope it helps.
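
Note that the script above references parse_args(), which isn't included in the comment. A minimal, hypothetical stand-in consistent with the arguments used in main() might look like this:

import argparse


def parse_args():
    # Hypothetical argument parser matching the attributes used in main() above
    parser = argparse.ArgumentParser(description="Generate images from a StyleGAN3 .pt checkpoint")
    parser.add_argument("--generator_path", type=str, required=True,
                        help="Path to the converted .pt generator checkpoint")
    parser.add_argument("--save_dir", type=str, default="outputs",
                        help="Directory in which to save the generated images")
    parser.add_argument("--image_numbers", type=int, default=10,
                        help="Number of images (seeds) to generate")
    parser.add_argument("--truncation_psi", type=float, default=0.7,
                        help="Truncation psi used in the mapping network")
    return parser.parse_args()


if __name__ == "__main__":
    main()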

Read more comments on GitHub >

Top Results From Across the Web

Convert Tensorflow StyleGAN2 model .pkl files to PyTorch .pt ...
To use GANspace or the Network bending repos, you'll need to convert your Tensorflow model files (.pkl) to PyTorch files (....
Read more >
error in converting .pkl to .pt · Issue #250 - GitHub
I used google colab to convert from a model trained from a custom dataset with stylegan2-ADA-PyTorch with all default settings and got this ......
Read more >
Convert StyleGAN2 .pkl model to .pt - Google Colab
This requires we convert a .pkl model file to a .pt file. This notebook shows you the steps to do so. I also...
Read more >
Using custom StyleGAN2-ada network in GANSpace (.pkl to ...
Looking at the source code for converting to Pytorch, GANspace requires the pt file to be a dict with keys: ['g', 'g_ema', 'd',...
Read more >
Convert StyleGAN2 TensorFlow weights to Pytorch | by MLBoy
Convert the original checkpoint file pkl to pytorch format. There are various pre-trained weights such as anime. The Pytorch weight pt file ......
Read more >
