Predicted ESM2 logits depend on other elements within a batch


Thank you so much for all the work on ESM and ESM2. I ran into some surprising behaviour:

Bug description: ESM2 predicts slightly different logits, even in eval mode, depending on the other elements in the batch.

Reproduction steps

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

model_path = "facebookresearch/esm:main"
model_name = "esm2_t36_3B_UR50D"
esm_model, alphabet = torch.hub.load(model_path, model_name)

# Move the model to the selected device instead of hard-coding .cuda()
esm_model = esm_model.eval().to(device)
batch_converter = alphabet.get_batch_converter()

# Those are arbitrary sequences, doesn't matter which ones are used
sequences = [
    'A' * 255,
    'Y' * 310
]

# batch_converter returns (labels, strings, tokens); keep the padded token tensor
model_input = batch_converter([(None, seq) for seq in sequences])[2]
model_input = model_input.to(device)

# Here is the surprising part:
logits1 = esm_model(model_input[[0]])['logits']  # first sequence alone
logits2 = esm_model(model_input)['logits']       # both sequences as one batch

print(torch.linalg.norm(logits1 - logits2[0]))

tensor(0.3426, device='cuda:0')

This gives roughly 0.3426, with many values significantly different from zero. I would have expected a difference like this to come from some batch-norm-like functionality, but the model is in eval mode.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 11 (3 by maintainers)

Top GitHub Comments

BernhoferM commented, Sep 3, 2022 (1 reaction)

This might be caused by non-deterministic behavior of PyTorch and cuBLAS on GPU. There are a few things you can configure to get as much determinism as possible, but even then you might get different results on different hardware.

https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
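One plausible mechanism behind this (an assumption here, not something confirmed in this thread) is that a GPU library may sum the same products in a different order when the input shape changes. Floating-point addition is not associative, so a different order can give slightly different results. A minimal sketch in plain Python, with no GPU involved:

```python
# Floating-point addition is not associative: the grouping changes the rounding.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# The two groupings disagree in the last bits of the result.
print(left_to_right == right_to_left)
print(abs(left_to_right - right_to_left))
```

The discrepancy per operation is tiny, but a large model performs many such reductions in lower precision, so small differences can accumulate across layers.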

Below is a minimal example (with a single linear layer) to test your system. You can uncomment the lines regarding cuDNN and deterministic algorithms to see how the results change. In my personal experience, setting CUBLAS_WORKSPACE_CONFIG to :16:8 gave the most stable results.

import torch
import numpy
import random


seed = 101

random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

# Set the environment variable CUBLAS_WORKSPACE_CONFIG
# to either ':16:8' or ':4096:8' before running.
# torch.use_deterministic_algorithms(True)

# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

B, N, F = 64, 200, 512

f = torch.nn.Linear(F, 64).cuda()
x = torch.randn(B, N, F).float().cuda()

print('### TEST B DIMENSION ###')
for n in range(1, B+1):
    y = f(x)[:n]
    z = f(x[:n])

    print(n, torch.equal(y, z), torch.abs(y - z).max().item())

print('### TEST N DIMENSION ###')
for n in range(1, N+1):
    y = f(x)[:, :n, :]
    z = f(x[:, :n, :])

    print(n, torch.equal(y, z), torch.abs(y - z).max().item())
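
The settings that are commented out in the example above could be collected as in the following sketch. The helper name `enable_determinism` is hypothetical, and the choice of ':4096:8' is arbitrary (the comment above reports ':16:8' as the most stable on the commenter's system). The environment variable is set from Python here on the assumption that this runs before any CUDA work has started.

```python
import os

# Set before the first CUDA/cuBLAS call; ':16:8' is the other value mentioned above.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def enable_determinism():
    """Hypothetical helper: turn on the PyTorch determinism switches."""
    import torch  # imported here so the environment variable is set first

    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
```

Calling `enable_determinism()` once at the start of a script, before building the model, should be enough. As noted above, results may still differ between different hardware.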

nikitos9000 commented, Sep 19, 2022 (0 reactions)

@FedericoV Have you got any new results after dropping out pad logits and setting deterministic=True?
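
Dropping the pad logits before comparing could look like the sketch below. The helper name `max_diff_without_padding` is hypothetical, and the tensors are synthetic stand-ins rather than real ESM2 outputs. The length of 257 assumes one start token and one end token around the 255 residues, which should be checked against the actual tokenized input.

```python
import torch


def max_diff_without_padding(single, batched, length):
    """Largest absolute difference over the first `length` positions.

    `single` and `batched` hold logits of shape (positions, vocab) for the
    same sequence; positions at index >= length are treated as padding.
    """
    return (single[:length] - batched[:length]).abs().max().item()


# Synthetic stand-ins for model outputs (not real ESM2 logits).
torch.manual_seed(0)
batched = torch.randn(312, 33)
single = batched.clone()
single[257:] += 1.0  # pretend that only the padded positions differ

full_diff = (single - batched).abs().max().item()
masked_diff = max_diff_without_padding(single, batched, length=257)
print(full_diff, masked_diff)
```

If the masked difference on real outputs is still clearly above zero, padding alone does not explain the discrepancy.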
