Predicted ESM2 logits depend on other elements within a batch
Thank you so much for all the work on ESM and ESM2. I ran into some surprising behaviour:
Bug description
ESM2 predicts slightly different logits for the same sequence depending on the other elements in the batch, even in eval mode.
Reproduction steps
```python
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

model_path = "facebookresearch/esm:main"
model_name = "esm2_t36_3B_UR50D"
esm_model, alphabet = torch.hub.load(model_path, model_name)
esm_model = esm_model.eval().to(device)  # eval mode: dropout disabled
batch_converter = alphabet.get_batch_converter()

# These are arbitrary sequences; it doesn't matter which ones are used
sequences = [
    'A' * 255,
    'Y' * 310,
]
# batch_converter returns (labels, strings, tokens); keep the padded token tensor
model_input = batch_converter([(None, seq) for seq in sequences])[2]
model_input = model_input.to(device)

# Here is the surprising part:
with torch.no_grad():
    logits1 = esm_model(model_input[[0]])['logits']  # first sequence alone (still padded)
    logits2 = esm_model(model_input)['logits']       # same sequence inside a batch of two
print(torch.linalg.norm(logits1 - logits2[0]))
# tensor(0.3426, device='cuda:0')
```
This gives a norm of roughly 0.3426, with many individual elements of the difference significantly different from zero. I first suspected some batch-norm-like behaviour, but the model is in eval mode.
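When comparing the two runs, it can help to exclude padding positions (whose logits carry no meaning) and to compare with a tolerance rather than expecting bitwise equality. The following is a minimal PyTorch sketch, assuming per-sequence logits of shape (L, V) and ESM's `alphabet.padding_idx` as the pad id; the helper name `max_abs_diff_unpadded` and the toy tensors are hypothetical.

```python
import torch


def max_abs_diff_unpadded(tokens, logits_a, logits_b, padding_idx):
    """Largest absolute logit difference over non-padding positions only.

    tokens:      (L,) token ids for one sequence, possibly padded
    logits_a/b:  (L, V) logits for that sequence from two different runs
    padding_idx: id of the padding token (alphabet.padding_idx in ESM)
    """
    keep = tokens != padding_idx  # True at real (non-pad) positions
    diff = (logits_a[keep] - logits_b[keep]).abs()
    return diff.max().item()


# Toy check (hypothetical values): pad id 1, vocabulary of 4, last two positions are padding
tokens = torch.tensor([0, 5, 6, 2, 1, 1])
single = torch.zeros(6, 4)
batched = single.clone()
batched[4:] += 100.0  # large differences at pad positions are ignored
batched[1] += 1e-3    # small difference at a real position is reported

print(max_abs_diff_unpadded(tokens, single, batched, padding_idx=1))
keep = tokens != 1
print(torch.allclose(single[keep], batched[keep], atol=1e-2))  # tolerance-based comparison
```

With the reproduction above, the call would be `max_abs_diff_unpadded(model_input[0], logits1[0], logits2[0], alphabet.padding_idx)`.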
Issue Analytics
- Created a year ago
- Comments: 11 (3 by maintainers)
Top GitHub Comments
This might be caused by non-deterministic behavior of PyTorch and cuBLAS on GPU. There are a few things you can configure to get as much determinism as possible, but even then you might get different results on different hardware.
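One mechanism that can shift results without any randomness at all: floating-point addition is not associative, so a kernel that partitions a reduction differently (GPU libraries may pick different kernels or tilings for different matrix shapes, such as different batch sizes) can change the last bits of a result. The pure-Python sketch below illustrates only that arithmetic effect; `ordered_sum` and `blocked_sum` are hypothetical helpers, and `block_size` loosely stands in for a kernel's tiling choice, not an actual cuBLAS parameter.

```python
def ordered_sum(values):
    """Accumulate strictly left to right, like a single-threaded reduction."""
    total = 0.0
    for v in values:
        total += v
    return total


def blocked_sum(values, block_size):
    """Sum fixed-size blocks first, then sum the partial results.

    Loosely mimics a parallel kernel that splits a reduction into tiles.
    """
    partials = [
        ordered_sum(values[i:i + block_size])
        for i in range(0, len(values), block_size)
    ]
    return ordered_sum(partials)


values = [0.0, 0.1, 0.2, 0.3]
a = ordered_sum(values)     # ((0.0 + 0.1) + 0.2) + 0.3
b = blocked_sum(values, 2)  # (0.0 + 0.1) + (0.2 + 0.3)
print(a, b, a == b)         # 0.6000000000000001 0.6 False
```

The same inputs give results that differ by about 1e-16; across thousands of accumulations in a deep model, such differences can add up to visible changes in the logits.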
cuBLAS reproducibility: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
PyTorch reproducibility notes: https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
Below is a minimal example (with a single linear layer) to test your system. You can uncomment the lines regarding cuDNN and deterministic algorithms to see how the results change. In my personal experience, setting CUBLAS_WORKSPACE_CONFIG to :16:8 gave the most stable results.
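The commenter's original snippet is not preserved in this copy. The following is an independent sketch in the same spirit, not that code: it applies the determinism settings from the PyTorch reproducibility notes, then compares one row computed alone against the same row computed in a batch with a single linear layer. The layer and input sizes are arbitrary assumptions. As far as I know, these flags target run-to-run reproducibility for identical inputs and shapes; they are not documented to make a row's output independent of its batch, so `batch_gap` may remain non-zero even when `run_gap` is zero.

```python
import os

import torch

# Read when cuBLAS initializes, so set it before any CUDA work in this process.
# ":16:8" is the value suggested above; PyTorch's notes also list ":4096:8".
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")

torch.manual_seed(0)
torch.backends.cudnn.benchmark = False     # don't auto-tune (and possibly switch) cuDNN kernels
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN algorithms
torch.use_deterministic_algorithms(True)   # raise on ops lacking a deterministic implementation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single linear layer standing in for one projection inside the model
layer = torch.nn.Linear(1280, 1280).to(device).eval()
x = torch.randn(8, 1280, device=device)

with torch.no_grad():
    alone = layer(x[:1])     # first row on its own
    in_batch = layer(x)[:1]  # first row computed inside a batch of 8
    repeat = layer(x)[:1]    # identical batched call, repeated

batch_gap = (alone - in_batch).abs().max().item()
run_gap = (in_batch - repeat).abs().max().item()
print(f"single vs batched:      {batch_gap:.3e}")
print(f"batched run 1 vs run 2: {run_gap:.3e}")
```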
@FedericoV Have you got any new results after dropping the pad logits and setting deterministic=True?