How to retrain unconditional Imagen model

I trained a model that is not conditioned on text, following this description on the README:

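from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for unconditional generation (no text cross-attention)

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)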

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    # imagen_checkpoint_path = "./ckpt.pt"
).cuda()

To resume training, I passed the imagen_checkpoint_path argument to ImagenTrainer, pointing at my previously trained model. This fails with an error stating that the checkpoint must contain a config. I traced it and found the cause: I built the model with Imagen rather than ImagenConfig, which is understandable.

So now that I have a model trained with Imagen that is not conditioned on text, is there a way I can load the model and retrain?

I have tried the following:

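from imagen_pytorch.data import Dataset

# load the previously trained checkpoint into the trainer
trainer.load("./ckpt.pt")

# image-only dataset for unconditional training
dataset = Dataset('./image_data/', image_size = 128)

trainer.add_train_dataset(dataset, batch_size = 16)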

That worked, except that I cannot get the loss for each timestep in this new training run, because whenever I run:

for i in range(20000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

It gives me the following error:

RuntimeError: output with shape [128, 1, 1] doesn’t match the broadcast shape [1, 128, 1, 1]
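
For context on the message itself: PyTorch reports this when an in-place operation would have to grow its output tensor to the broadcast shape. Below is a minimal pure-Python sketch of the broadcasting rule that reproduces the two shapes from the error. The helper names `broadcast_shape` and `can_update_in_place` are hypothetical (they are not part of PyTorch or imagen-pytorch), and the second operand's shape is an assumption, since the error only reports the output and the broadcast result. The sketch does not identify which tensor inside the trainer is mismatched.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the shape that shapes `a` and `b` broadcast to.

    Shapes are aligned from the trailing dimension; a missing or size-1
    dimension stretches to match the other. Hypothetical helper.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} cannot be broadcast")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def can_update_in_place(output, other):
    """An in-place op cannot resize its output, so the output's shape
    must already equal the broadcast shape."""
    return tuple(output) == broadcast_shape(output, other)


output_shape = (128, 1, 1)    # shape reported as "output" in the error
other_shape = (1, 128, 1, 1)  # assumed operand with an extra leading dimension

print(broadcast_shape(output_shape, other_shape))
print(can_update_in_place(output_shape, other_shape))
```

Under these assumptions the broadcast result has four dimensions while the output has three, so the in-place update is rejected. One tensor carrying an extra leading dimension would produce this pattern.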

Is there a way to retrain an Imagen model that is not conditioned on text, and also get the loss values in each timestep?

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 10 (6 by maintainers)

Top GitHub Comments

1 reaction
sharonibejih commented, Aug 25, 2022

Great! I will try these. Thank you, @lucidrains.

1 reaction
Nodja commented, Aug 25, 2022

I’ve just upgraded from a lower version (1.9.3) to test SPD-Conv and hit a similar error. I had to pass strict=False, as lucid said, but also only_model=True. For my unet, passing lr=1e-5 to the trainer was also good enough that the loss was initially only a bit higher than usual.
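
To illustrate why strict=False can help after an upgrade that changes the network's layers, here is a minimal pure-Python sketch of non-strict loading, with plain dicts standing in for real state dicts. The function `load_non_strict` and all parameter names are made up for illustration. It mimics the general behavior of PyTorch's load_state_dict(strict=False), which skips missing and unexpected keys, and is not the imagen-pytorch implementation.

```python
def load_non_strict(model_state, checkpoint_state):
    """Copy matching entries from the checkpoint into the model state
    and report which keys were skipped. Hypothetical helper."""
    missing = [k for k in model_state if k not in checkpoint_state]
    unexpected = [k for k in checkpoint_state if k not in model_state]
    for key in model_state:
        if key in checkpoint_state:
            model_state[key] = checkpoint_state[key]
    return missing, unexpected


# Made-up parameter names: the newer model has a layer the old checkpoint
# lacks, and the checkpoint has a layer the newer model dropped.
model_state = {"init_conv.weight": 0.0, "downs.0.weight": 0.0, "spd_conv.weight": 0.0}
checkpoint_state = {"init_conv.weight": 0.5, "downs.0.weight": -1.2, "old_pool.weight": 0.3}

missing, unexpected = load_non_strict(model_state, checkpoint_state)
print("missing:", missing)
print("unexpected:", unexpected)
```

Parameters reported as missing keep their fresh initialization, which may be why the loss starts slightly higher than usual after such a load.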


Top Results From Across the Web

Unconditional Image-Generation - Hugging Face
In this section, we explain how one can train an unconditional image generation diffusion model. “Unconditional” because the model is not conditioned on...

Unconditional Image Generation Using Hugging Face Diffusers
Training examples to show how to train the most popular diffusion model tasks. For more information, check out the official docs for Training....

Training Your Own Unconditional Diffusion Model ... - PeakD
Step 1: Gathering your dataset ... This section is more or less a direct port from my 2020 piece on training GANS, since...

How Imagen Actually Works - AssemblyAI
Diffusion Models are a method of creating data that is similar to a set of training data. They train by destroying the training...

Inference and train with existing models and standard datasets
Since unconditional models only need real images for training and testing, all you need to do is link your dataset to the data...
