How to retrain unconditional Imagen model
I trained a model that is not conditioned on text, following this description in the README:
from imagen_pytorch import Unet, Imagen, ImagenTrainer

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)
# imagen, which contains the unet above
imagen = Imagen(
    condition_on_text = False, # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)
trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True, # whether to split the validation dataset from the training
    # imagen_checkpoint_path = "./ckpt.pt"
).cuda()
But to retrain, I added the imagen_checkpoint_path arg to the ImagenTrainer, with the path to my previously trained model. But I get an error message stating that the checkpoint must contain a config. I traced this and discovered it is because I used Imagen and not ImagenConfig, which is understandable.
So now that I have a model trained with Imagen that is not conditioned on text, is there a way I can load the model and retrain?
I have tried the following:
from imagen_pytorch.data import Dataset

# load checkpoint
trainer.load("./ckpt.pt")
dataset = Dataset('./image_data/', image_size = 128)
trainer.add_train_dataset(dataset, batch_size = 16)
And that worked, except that I could not get the loss of each timestep in this new training because whenever I run:
for i in range(20000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')
It gives me the following error:
RuntimeError: output with shape [128, 1, 1] doesn't match the broadcast shape [1, 128, 1, 1]
Is there a way to retrain a model with Imagen that is not conditioned with text, and also get the loss values in each timestep?
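For readers unfamiliar with this message: PyTorch raises it when an in-place operation would need to give the tensor being written to a different shape. The sketch below reproduces only the broadcasting rule in plain Python, so it runs without PyTorch. The helper names broadcast_shape and can_update_in_place are hypothetical, and nothing in this thread establishes which tensor in the trainer triggers the error.

```python
def broadcast_shape(a, b):
    """Return the broadcast of two shapes, aligning dimensions from the right."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} cannot be broadcast")
        # A size-1 dimension stretches to match the other one.
        out.append(y if x == 1 else x)
    return tuple(out)


def can_update_in_place(out, other):
    """An in-place op cannot reshape its output, so the broadcast
    result must equal the output's shape exactly."""
    return broadcast_shape(out, other) == tuple(out)


stored = (128, 1, 1)       # shape reported as "output" in the error
incoming = (1, 128, 1, 1)  # shape reported as "broadcast shape"

print(broadcast_shape(stored, incoming))      # (1, 128, 1, 1)
print(can_update_in_place(stored, incoming))  # False
print(can_update_in_place(incoming, stored))  # True
```

The two shapes hold the same number of elements, but the 3-D tensor cannot absorb the extra leading dimension in place, which is why the operation fails instead of broadcasting.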
Issue Analytics
- Created a year ago
- Comments: 10 (6 by maintainers)
Top GitHub Comments
Great! I will try these. Thank you, @lucidrains.
I’ve just upgraded from a lower version (1.9.3) to test SPD-Conv and had a similar error. I had to pass strict=False like lucid said, but also only_model=True. For my unet, passing lr=1e-5 to the trainer was also good enough that the loss was initially only a bit higher than usual.
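The suggestion above can be sketched as a small helper. This is a minimal sketch, not the maintainer's code: resume_and_train is a hypothetical name, and it assumes trainer.load accepts the only_model and strict keyword arguments named in the comment, and that trainer.train_step returns a value convertible to float, as in the loop from the question. The trainer is expected to be an ImagenTrainer that already has a training dataset added.

```python
def resume_and_train(trainer, checkpoint_path, num_steps,
                     unet_number=1, max_batch_size=4, log_every=1):
    """Load model weights from a checkpoint, train, and return per-step losses."""
    # Assumption (from the comment above): only_model=True restores just the
    # model weights, and strict=False tolerates mismatched state-dict keys.
    trainer.load(checkpoint_path, only_model=True, strict=False)

    losses = []
    for step in range(num_steps):
        loss = trainer.train_step(unet_number=unet_number,
                                  max_batch_size=max_batch_size)
        losses.append(float(loss))
        if step % log_every == 0:
            print(f"step {step}: loss {losses[-1]:.4f}")
    return losses
```

With the objects from the question, a call might look like resume_and_train(trainer, "./ckpt.pt", 20000) after trainer.add_train_dataset(dataset, batch_size = 16). Since the optimizer presumably starts fresh when only the model is loaded, building the trainer with a lower learning rate such as lr=1e-5, as the comment reports, may keep the first steps stable.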