Unable to train due to recent changes.

After the latest changes, I tweaked my training code to save and then reload a checkpoint between training each unet. But I keep getting errors no matter how I configure my training loop. The basic layout of my code looks something like this:

# Omitted some code for simplicity
# Using a custom dataloader
from imagen_pytorch import Unet, Imagen, ImagenTrainer

unet1 = Unet(
    dim = 192,
    memory_efficient = False,
)

unet2 = Unet(
    dim = 64,
    memory_efficient = True,
)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = "t5-large", 
    image_sizes = (64, 256), 
    cond_drop_prob = 0.1,
    timesteps = 1000,
).cuda()

for epoch in range(1, 51):
    for i in (1, 2):
        # Need to instantiate the trainer here, otherwise I get the error "AssertionError: you can only train on one unet at a time ..."
        trainer = ImagenTrainer(
            imagen, 
            lr = 1e-5, 
            fp16 = True
        )

        # load the latest checkpoint (unless at the very beginning of training, i.e. epoch=1 and i=1)
        if not (epoch == 1 and i == 1):
            trainer.load('./path/to/checkpoint.pt')

        for step, batch in enumerate(dataloader): 
            images, texts = batch
            images = images.to(device)

            text_embeds, text_masks = get_emb_tensor(cfg, texts, device)

            # Error occurs when i=2
            loss = trainer(
                images, 
                text_embeds = text_embeds, 
                text_masks = text_masks, 
                unet_number = i, 
                max_batch_size = 32
            )
            
            trainer.update(unet_number = i)
    
        trainer.save('./path/to/checkpoint.pt')

When running the code above, I get the following error:

checkpoint saved to checkpoint\imagen_large_checkpoint.pt
checkpoint loaded from checkpoint\imagen_large_checkpoint.pt
Traceback (most recent call last):
  File "C:\Users\camla\PythonProjects\imagen_pytorch\imagen_train.py", line 192, in <module>
    run_train_loop(cfg, imagen, coco_dataloader, device)
  File "C:\Users\camla\PythonProjects\imagen_pytorch\imagen_train.py", line 104, in run_train_loop
    trainer, loss_arr = train(cfg, dataloader, trainer, epoch, i, device)
  File "C:\Users\camla\PythonProjects\imagen_pytorch\imagen_train.py", line 55, in train
    loss = trainer(
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 116, in inner
    out = fn(model, *args, **kwargs)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 712, in forward
    loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\torch\nn\modules\module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ImagenTrainer' object has no attribute 'unet_being_trained'

As a possible fix, I tried setting the unet number when instantiating the ImagenTrainer rather than when calling loss = trainer(…). So my code looked like this instead:

for epoch in range(1, 51):
    for i in (1, 2):
        # Need to instantiate the trainer here, otherwise I get the error "AssertionError: you can only train on one unet at a time ..."
        # Error occurs here
        trainer = ImagenTrainer(
            imagen, 
            lr = 1e-5, 
            fp16 = True,
            only_train_unet_number = i
        )

        # load the latest checkpoint (unless at the very beginning of training, i.e. epoch=1 and i=1)
        if not (epoch == 1 and i == 1):
            trainer.load('./path/to/checkpoint.pt')

        loss_arr = []
        for step, batch in enumerate(dataloader): 
            images, texts = batch
            images = images.to(device)

            text_embeds, text_masks = get_emb_tensor(cfg, texts, device)

            loss = trainer(
                images, 
                text_embeds = text_embeds, 
                text_masks = text_masks, 
                max_batch_size = 32
            )
            
            trainer.update(unet_number = i)
            loss_arr.append(loss)
            
        trainer.save('./path/to/checkpoint.pt')

But this threw a different error:

Epoch 1/50
Traceback (most recent call last):
  File "C:\Users\camla\PythonProjects\imagen_pytorch\imagen_train.py", line 191, in <module>
    run_train_loop(cfg, imagen, coco_dataloader, device)
  File "C:\Users\camla\PythonProjects\imagen_pytorch\imagen_train.py", line 81, in run_train_loop
    trainer = ImagenTrainer(
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 247, in __init__
    self.validate_and_set_unet_being_trained(only_train_unet_number)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 347, in validate_and_set_unet_being_trained
    self.wrap_unet(unet_number)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 82, in inner
    out = fn(*args, **kwargs)
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\imagen_pytorch\trainer.py", line 355, in wrap_unet
    optimizer = getattr(self, f'optim{unet_index}')
  File "C:\Users\camla\anaconda3\envs\torch-coco\lib\site-packages\torch\nn\modules\module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ImagenTrainer' object has no attribute 'optim0'

Is this a bug? Or am I doing something totally wrong?
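
For reference, the workaround the maintainer's replies below point toward is training a single unet per run: construct the trainer once, outside the epoch loop, pinned to one unet via only_train_unet_number, and re-run the script for the other unet. Below is a minimal sketch of that workflow, assuming a version of imagen-pytorch where constructing ImagenTrainer with only_train_unet_number succeeds (the second traceback above suggests that path was broken at the time); the --unet flag is a hypothetical stand-in for however the run is parameterized, and whether a checkpoint written by the unet-1 run loads cleanly into the unet-2 run depends on the library version.

import argparse
from imagen_pytorch import Unet, Imagen, ImagenTrainer

parser = argparse.ArgumentParser()
parser.add_argument('--unet', type = int, choices = (1, 2), required = True)
args = parser.parse_args()

# unet1, unet2, imagen, dataloader, cfg, device, get_emb_tensor: same as in the snippets above

# Construct the trainer once, pinned to the unet selected for this run
trainer = ImagenTrainer(
    imagen,
    lr = 1e-5,
    fp16 = True,
    only_train_unet_number = args.unet
)

checkpoint_path = './path/to/checkpoint.pt'
if args.unet != 1:
    # resume from the checkpoint written by the unet-1 run
    trainer.load(checkpoint_path)

for epoch in range(1, 51):
    for step, batch in enumerate(dataloader):
        images, texts = batch
        images = images.to(device)
        text_embeds, text_masks = get_emb_tensor(cfg, texts, device)

        loss = trainer(
            images,
            text_embeds = text_embeds,
            text_masks = text_masks,
            max_batch_size = 32
        )
        trainer.update(unet_number = args.unet)

    trainer.save(checkpoint_path)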

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 8 (5 by maintainers)

Top GitHub Comments

1 reaction
camlaedtke commented, Jul 13, 2022

@lucidrains Great, thank you, I’ll try it out. Or I guess I could just go out and buy a second GPU, we’ll see lol.

1 reaction
camlaedtke commented, Jul 13, 2022

@lucidrains Thanks. So if I wanted to train for 50 epochs, I would need to train only unet 1 for 50 epochs, and then run the program again to train only unet 2 for 50 epochs? Is there any way to train each unet once per epoch? That would be useful because it’s nice to see the evolution of sample images over the course of training.
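
If a version of the library does allow a single ImagenTrainer to drive both unets (the assertion quoted in the snippets above forbade this at the time, so this sketch only applies if that restriction was lifted in a later release), alternating once per epoch would look roughly like the following: one trainer constructed a single time before any loops, with the unet selected per call.

# unet1, unet2, imagen, dataloader, cfg, device, get_emb_tensor: same as in the snippets above
trainer = ImagenTrainer(imagen, lr = 1e-5, fp16 = True)  # constructed once, outside all loops

for epoch in range(1, 51):
    for i in (1, 2):  # alternate unets within each epoch
        for step, batch in enumerate(dataloader):
            images, texts = batch
            images = images.to(device)
            text_embeds, text_masks = get_emb_tensor(cfg, texts, device)

            loss = trainer(
                images,
                text_embeds = text_embeds,
                text_masks = text_masks,
                unet_number = i,
                max_batch_size = 32
            )
            trainer.update(unet_number = i)

    # one checkpoint covering both unets, written once per epoch
    trainer.save('./path/to/checkpoint.pt')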
