FlashAttention returns all zeros when device is 'cuda:1'
Testing code:
import torch
from flash_attn.flash_attention import FlashAttention

def test(device):
    flash = FlashAttention()
    d_head = 64
    n_heads = 32
    flash.softmax_scale = 1  # / (d_head ** 0.5)
    batch_size = 4
    seq_len = 16
    # All-ones qkv: every output element should come out exactly 1 (see below).
    qkv = torch.ones(batch_size, seq_len, 3, n_heads, d_head, dtype=torch.float16)
    flash = flash.to(device)
    qkv = qkv.to(device)
    out, _ = flash(qkv)
    print(out.shape, torch.abs(out).type(torch.float32).sum())
    return out
If I add the following call and run it,
out = test("cuda:0")
the output is
torch.Size([4, 16, 32, 64]) tensor(131072., device='cuda:0')
which is correct: 131072 = 4 * 16 * 32 * 64, i.e. every output element is 1. With all-ones q, k, v, every attention score is equal, so the softmax weights are uniform and the output is the average of identical all-ones value vectors.
If I run this instead,
out = test("cuda:1")
the output is
torch.Size([4, 16, 32, 64]) tensor(0., device='cuda:1')
I’ve verified it’s not a hardware issue by adding
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
at the top: with that remapping, cuda:0 points at the physical GPU 1 and the same test passes, so the GPU itself is fine.
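
For completeness, a minimal sketch of both workarounds (assuming the test function defined above; the second one is a guess that the kernel is being launched on the current CUDA device rather than on qkv's device):

import os

# Workaround 1: hide all GPUs except physical GPU 1, so it becomes "cuda:0".
# This must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

out = test("cuda:0")  # actually runs on physical GPU 1

# Workaround 2 (speculative): if the extension launches its kernel on the
# *current* device instead of qkv's device, pinning the current device to
# cuda:1 around the call may also avoid the all-zeros output.
with torch.cuda.device(1):
    out = test("cuda:1")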

I’ve just pushed a commit that fixed this. Let me know if it works on your side.
I just recompiled it.
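
For anyone verifying the fix, here is a quick correctness check on cuda:1 against a plain PyTorch attention reference (a sketch assuming the same flash_attn v1 interface as the repro above; reference_attention is a helper written just for this check):

import torch
from flash_attn.flash_attention import FlashAttention

def reference_attention(qkv, softmax_scale):
    # Plain PyTorch attention in float32, used only as a reference.
    q, k, v = qkv.float().unbind(dim=2)                  # each (B, S, H, D)
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", probs, v)

device = "cuda:1"
flash = FlashAttention().to(device)
flash.softmax_scale = 1
qkv = torch.randn(4, 16, 3, 32, 64, dtype=torch.float16, device=device)
out, _ = flash(qkv)
ref = reference_attention(qkv, softmax_scale=1)
# Should print a small fp16 rounding error; all-zeros output would show up
# as a large max difference instead.
print((out.float() - ref).abs().max())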