
`norm='batch'` in UNet causes int64 weights


Describe the bug
When the norm of UNet is set to `batch`, it causes "model.0.conv.unit0.conv.weight" to be int64 data, while instance norm is fine.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

1 reaction
holgerroth commented, May 7, 2021

Batch norm seems to create LongTensors, which causes the issue. See this simple test script comparing UNet with instance norm vs. batch norm:

import numpy as np
from monai.networks.nets import UNet


def count_long_tensors(net):
    """Count and print all int64 (LongTensor) entries in the network's state_dict."""
    state_dict = net.state_dict()

    n_long_tensors = 0
    for key in state_dict.keys():
        if state_dict[key].type() == "torch.LongTensor":
            print(f"{key}: {state_dict[key].type()}, {np.shape(state_dict[key])}")
            n_long_tensors += 1

    return n_long_tensors


unet_instance = UNet(dimensions=3,
                     in_channels=1,
                     out_channels=2,
                     channels=[16, 32, 64, 128, 256],
                     strides=[2, 2, 2, 2],
                     num_res_units=2)
unet_batch = UNet(dimensions=3,
                   in_channels=1,
                   out_channels=2,
                   channels=[16, 32, 64, 128, 256],
                   strides=[2, 2, 2, 2],
                   num_res_units=2,
                   norm="batch")

print("Long tensors:")
print("Instance Norm Unet", count_long_tensors(unet_instance))
print("Batch Norm Unet", count_long_tensors(unet_batch))

Output:

Long tensors:
Instance Norm Unet 0
model.0.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.0.conv.unit1.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.0.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.0.conv.unit1.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.2.0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.2.0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.1.submodule.2.1.conv.unit0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
model.2.0.adn.N.num_batches_tracked: torch.LongTensor, torch.Size([])
Batch Norm Unet 17
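
As a side note (not part of the original comment), the int64 entries in the output above are the num_batches_tracked step counters that PyTorch's BatchNorm registers when track_running_stats is enabled. A minimal PyTorch-only sketch illustrating where they come from:

import torch.nn as nn

# BatchNorm3d registers an int64 "num_batches_tracked" buffer by default
# (track_running_stats=True); InstanceNorm3d does not track running stats
# by default, so its state_dict contains no such buffer.
print(nn.BatchNorm3d(16).num_batches_tracked.dtype)                 # torch.int64
print("num_batches_tracked" in nn.BatchNorm3d(16).state_dict())     # True
print("num_batches_tracked" in nn.InstanceNorm3d(16).state_dict())  # False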
0 reactions
Nic-Ma commented, May 8, 2021

Hi @holgerroth ,

Actually, this int64 variable comes from the PyTorch batch norm source code: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L54 You can remove it with the settings below in your UNet args:

unet_batch = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=[16, 32, 64, 128, 256],
    strides=[2, 2, 2, 2],
    num_res_units=2,
    norm=("batch", {"track_running_stats": False}),
)

I already verified it locally and will also enhance our unit tests to cover it. Could you please help double-confirm?

Thanks.
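
For completeness (an editor's sketch, not part of the original thread), the suggested setting can be sanity-checked by building the batch-norm UNet with the same arguments as above and confirming that its state_dict no longer contains any int64 entries:

import torch
from monai.networks.nets import UNet

# Same UNet arguments as in the thread, with running stats disabled so
# batch norm no longer registers the int64 num_batches_tracked buffers.
unet_batch = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=[16, 32, 64, 128, 256],
    strides=[2, 2, 2, 2],
    num_res_units=2,
    norm=("batch", {"track_running_stats": False}),
)

long_keys = [k for k, v in unet_batch.state_dict().items() if v.dtype == torch.int64]
print(long_keys)  # expected: []

Note that with track_running_stats=False, batch norm normalizes with the statistics of the current batch even in eval mode, so inference behavior differs from the default setting.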

