
Output of extract_features is a NaN tensor in eval mode

Hello, I’m implementing an image captioning model using EfficientNet (efficientnet-b5). When I tried to extract features with the EfficientNet model in eval mode, I got a NaN tensor.

The following code builds the model:

image_size = EfficientNet.get_image_size(args.efficientnet)
efficientnet = EfficientNet.from_pretrained(args.efficientnet, advprop=True)
feature_dim, feature_size = efficientnet.extract_features(
    torch.FloatTensor(1, 3, image_size, image_size)).shape[1:3]

model = TransformerCaptioning(
    voca_num, pad_idx=args.pad_idx, bos_idx=args.bos_idx, eos_idx=args.eos_idx,
    max_len=args.max_len, feature_dim=feature_dim, feature_size=feature_size,
    d_model=args.d_model, n_head=args.n_head, dim_feedforward=args.dim_feedforward,
    num_encoder_layer=args.num_encoder_layer, num_decoder_layer=args.num_decoder_layer,
    dropout=args.dropout, embedding_dropout=args.embedding_dropout)
model = model.to(device)
model.extractor = efficientnet.to(device)

When I run the following code

model.extractor.eval() # extractor is EfficientNet
imgs, caps = next(iter(train_loader))
imgs = imgs.to(device)

with torch.no_grad():
    test = model.extractor.extract_features(imgs)
    
print(test)

then a NaN tensor is returned:

[Screenshot: the extract_features output tensor is filled with nan values]

However, when I switched the extractor from eval mode to train mode, a normal tensor was returned:

model.extractor.train() # extractor is EfficientNet
imgs, caps = next(iter(train_loader))
imgs = imgs.to(device)

with torch.no_grad():
    test = model.extractor.extract_features(imgs)
    
print(test)
[Screenshot: the extract_features output tensor contains normal finite values]

Do you know why the EfficientNet model returns a NaN tensor in eval mode?

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 6

Top GitHub Comments

3 reactions
Johann-Huber commented, Oct 8, 2020

Hi @rudvlf0413, thank you for the quick answer!

I’ve finally figured out where the problem comes from. It is due to the BatchNorm2d layers and the way PyTorch handles them in train() and eval() mode.

In train() mode, PyTorch computes the BN statistics from each batch, as described in the paper. At each iteration, the running statistics are updated using the current batch mean and std as well as those of the previous batches (via an exponential moving average).

In eval() mode, there should be no per-batch mean & std computation, since inference shouldn’t require a full batch to make a prediction. Therefore, PyTorch uses the running mean & std accumulated during training.

The problem occurs if the distribution of the data used for inference is not close enough to that of the data seen during the training iterations. In that case, the stored BN statistics do not reflect the actual batch distribution, leading to overestimated activations. The feature map values then grow exponentially through the network, eventually producing NaN values.
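As a minimal sketch of this effect (not from the original issue; the toy layer and input below are illustrative assumptions), the same BatchNorm2d layer normalizes with the stored running statistics in eval() mode but with the current batch’s own statistics in train() mode, so an input far from the running statistics comes out very differently:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                    # fresh layer: running_mean=0, running_var=1
x = torch.randn(4, 3, 8, 8) * 100 + 50    # input far from those running statistics

with torch.no_grad():
    bn.eval()
    print(bn(x).mean().item())   # ~50: normalized with the stale running statistics
    bn.train()
    print(bn(x).mean().item())   # ~0: normalized with this batch's own mean/std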

So in a sense, it is related to initialization (the dataset and the first training iterations in particular). Unfortunately, I was not able to fix the issue by playing with the model initialization or with the way it is loaded onto the training device. If you remember anything about it, I would be interested to know!

An easy way to fix the problem is described in this discussion, and basically consists of setting track_running_stats=False in the BatchNorm2d layers.
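A minimal sketch of applying that workaround to an already-built extractor (the helper below is hypothetical, not from the linked discussion; it assumes a reasonably recent PyTorch in which clearing the running buffers makes BatchNorm fall back to per-batch statistics in eval() mode):

import torch.nn as nn

def use_batch_stats_in_eval(module):
    # Hypothetical helper: make every BatchNorm2d layer normalize with per-batch
    # statistics even in eval() mode by disabling stats tracking and clearing
    # the stored running buffers.
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None

# e.g. use_batch_stats_in_eval(model.extractor)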

As I said above, this is effective at avoiding the validation loss explosion, but only in a partially satisfying manner. The other solutions proposed in the mentioned discussion did not work for me; I will investigate a bit further to figure out how this problem is handled in other frameworks.

Thank you again for your help!

1 reaction
Kyeongpil commented, Oct 6, 2020

@Johann-Huber After all this time I’m not sure exactly how I solved this problem. However, I might have solved it by changing the model initialization code.

For example, you can change this code

model = TransformerCaptioning(...)
model = model.to(device)
model.extractor = efficientnet.to(device)

to

model = TransformerCaptioning(...)
model.extractor = efficientnet
model = model.to(device)

With some initialization orders, such as the first one above, I think the NaN problem can occur.
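
If it helps, a quick sanity check (not from the original comment) is to confirm that every parameter and buffer of the extractor, including the BatchNorm running statistics, ended up on the intended device after initialization:

devices = {p.device for p in model.extractor.parameters()}
devices |= {b.device for b in model.extractor.buffers()}
print(devices)   # expect a single entry matching `device`, e.g. {device(type='cuda', index=0)}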
