Lr-finder with multiple inputs, outputs and losses

Hello,

Firstly, thank you for this wonderful library. I have a model which expects 2 inputs. I am working with 2 kinds of images, one of size (512, 1536) and the other of size (128, 384). Each batch from my train_loader therefore contains 2 inputs and one target of shape (128, 384, 16). My model has 4 prediction heads and hence is trained using 4 losses, one per head.

So my collate_fn for the data loader looks like this:

import torch


def detection_collate(batch):
    """Custom collate fn that stacks both inputs and the target of each sample.

    Arguments:
        batch: (list) Samples of the form (img, dep, target), each a tensor
    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (tensor) batch of second inputs (dep) stacked on their 0 dim
            3) (tensor) batch of targets stacked on their 0 dim
    """
    imgs = []
    deps = []
    targets = []
    for sample in batch:
        imgs.append(sample[0])
        deps.append(sample[1])
        targets.append(sample[2])
    # torch.stack requires every tensor in a list to have the same shape.
    return torch.stack(imgs, 0), torch.stack(deps, 0), torch.stack(targets, 0)

As mentioned, there are 4 different losses: Custom Heatmap (Focal) loss, SmoothL1, SmoothL1, BCE loss.

The forward method of the model expects 2 inputs. A small snippet is shown below:

def forward(self, x, dep=None, target=None):
    # Backbone: ResNet18, x is image size: (512, 1536)
    ...

Here, target holds the labels.

In this case, how do I go about finding the best learning rate using lr-finder? Notably, I can only use batch_size=2 because of computational limitations.
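
One possible starting point, assuming the LR finder consumes each batch as an (inputs, labels) pair: regroup the 3-tuple produced by detection_collate so that both images travel together as a single "inputs" object. The helper name to_inputs_labels is hypothetical. In torch-lr-finder, this kind of logic is, as far as I know, meant to live in a subclass of TrainDataLoaderIter that overrides inputs_labels_from_batch; that is worth checking against the installed version.

```python
def to_inputs_labels(batch):
    """Regroup (imgs, deps, targets) into ((imgs, deps), targets).

    `batch` is assumed to be the 3-tuple returned by detection_collate.
    """
    imgs, deps, targets = batch
    # Both model inputs are bundled into one tuple; the target stays separate.
    return (imgs, deps), targets
```

With this grouping, the model receives one tuple argument, so its forward method (or a thin wrapper around it) has to unpack the tuple into x and dep.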

Issue Analytics

  • State: open
  • Created 2 years ago
  • Reactions: 1
  • Comments: 21 (11 by maintainers)

Top GitHub Comments

YashRunwal commented, Aug 20, 2021 (1 reaction)

Hi,

Okay, I will try to create the wrapper. I think I have to make some changes to the model as well. Currently, the model returns the losses; for the wrapper to work, it should instead return the prediction heads, which can then be passed to the wrapper to calculate the losses. I will do this and get back to you, probably tomorrow.
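
A minimal sketch of what such wrappers might look like, not the maintainers' code. It assumes the model has been changed to return a tuple of four prediction heads, that the finder calls model(inputs) and criterion(outputs, labels), and that each loss function can be applied to its head and the full target. The class names TwoInputWrapper and MultiHeadCriterion are hypothetical.

```python
import torch
from torch import nn


class TwoInputWrapper(nn.Module):
    """Accepts the (x, dep) pair as one argument and unpacks it for the model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs):
        x, dep = inputs
        # Assumption: the wrapped model returns its prediction heads, not losses.
        return self.model(x, dep)


class MultiHeadCriterion(nn.Module):
    """Applies one loss function per prediction head and sums the results."""

    def __init__(self, loss_fns):
        super().__init__()
        self.loss_fns = nn.ModuleList(loss_fns)

    def forward(self, heads, target):
        if len(heads) != len(self.loss_fns):
            raise ValueError("expected one loss function per prediction head")
        total = 0.0
        for loss_fn, head in zip(self.loss_fns, heads):
            # Assumption: every loss can consume the full target tensor;
            # real code would slice the target channels each head needs.
            total = total + loss_fn(head, target)
        return total
```

The wrapped model and the combined criterion can then be handed to the finder in place of the original model and a single loss.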

You are right, one loss (txty loss) dominates. How do I add some weights to the losses though?
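
One common way to do this is a weighted sum, where each loss term is multiplied by a scalar before the terms are added. The sketch below assumes the per-head losses are available by name; apart from txty, the names and the weight values are made up for illustration, and the function works with plain floats as well as with tensors.

```python
def combine_losses(losses, weights):
    """Weighted sum of named loss terms; terms without a weight count as 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Hypothetical loss values: txty dominates, so it gets a weight below 1.
example_losses = {"heatmap": 2.0, "txty": 8.0, "size": 1.0, "bce": 1.0}
example_total = combine_losses(example_losses, {"txty": 0.25})
```

Suitable weights are usually found empirically, for example by logging each term separately and scaling the dominant one down until the terms are of similar magnitude.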

YashRunwal commented, Aug 20, 2021 (1 reaction)

I have used gradient accumulation. I backpropagated the gradients after 64 steps (simulating a batch size of 64). But let me check out how to use lr-finder with this. I will get back to you in case I need any help. Thanks for replying promptly. I really appreciate it.
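
To illustrate how an LR range test and gradient accumulation fit together, here is a hand-rolled sketch, not the library's implementation. It sweeps the learning rate exponentially and performs one optimizer step per group of accumulated mini-batches. The function name lr_range_test and its parameters are assumptions, and batches is assumed to be a list of (inputs, labels) pairs. torch-lr-finder's range_test appears to accept an accumulation_steps argument for the same purpose, which is worth checking against the installed version.

```python
import itertools

import torch
from torch import nn


def lr_range_test(model, criterion, optimizer, batches,
                  start_lr=1e-7, end_lr=1.0, num_iter=20, accumulation_steps=4):
    """Return (lrs, losses) from an exponential LR sweep with accumulation."""
    # Cycle so that a short list of batches can serve many iterations.
    stream = itertools.cycle(batches)
    lrs, losses = [], []
    for i in range(num_iter):
        # Exponential interpolation between start_lr and end_lr.
        lr = start_lr * (end_lr / start_lr) ** (i / max(num_iter - 1, 1))
        for group in optimizer.param_groups:
            group["lr"] = lr

        optimizer.zero_grad()
        running = 0.0
        for _ in range(accumulation_steps):
            inputs, labels = next(stream)
            # Divide so the summed gradients equal the average over the group.
            loss = criterion(model(inputs), labels) / accumulation_steps
            loss.backward()
            running += loss.item()
        optimizer.step()  # one update per accumulated group

        lrs.append(lr)
        losses.append(running)
    return lrs, losses
```

Unlike a full LR finder, this sketch neither smooths the loss curve, stops on divergence, nor restores the model and optimizer state afterwards, so it should be run on a throwaway copy of the model.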
