
FlashAttention returns all zeros when device is 'cuda:1'

See original GitHub issue

Testing code:

import torch
from flash_attn.flash_attention import FlashAttention

def test(device):
    flash = FlashAttention()
    d_head = 64
    n_heads = 32
    flash.softmax_scale = 1  # / (d_head ** 0.5)
    batch_size = 4
    seq_len = 16
    # packed QKV of all ones: (batch, seq_len, 3, n_heads, d_head), fp16 as FlashAttention requires
    qkv = torch.ones(batch_size, seq_len, 3, n_heads, d_head, dtype=torch.float16)
    flash = flash.to(device)
    qkv = qkv.to(device)
    out, _ = flash(qkv)
    # the sum of |out| should be non-zero for this input
    print(out.shape, torch.abs(out).type(torch.float32).sum())
    return out

If I add the following code and run it,

out = test("cuda:0")

the output is

torch.Size([4, 16, 32, 64]) tensor(131072., device='cuda:0')
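(This is the expected value: with all-ones q, k, v and softmax_scale = 1, the attention weights are uniform and the output is all ones, so the sum of |out| is 4 × 16 × 32 × 64 = 131072.)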

If I run this instead,

out = test("cuda:1")

the output is

torch.Size([4, 16, 32, 64]) tensor(0., device='cuda:1')

I’ve verified it’s not a hardware issue by adding

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

at the top; then it works with cuda:0.
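For reference, a minimal sketch of what that remapping does (the assumption here is that the variable is set before CUDA is initialized, i.e. before the first CUDA call, otherwise it has no effect):

import os
# must be set before CUDA is initialized; before importing torch is the safest place
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
print(torch.cuda.device_count())      # 1 -- only the physical second GPU is visible
print(torch.cuda.get_device_name(0))  # physical GPU 1, now exposed as cuda:0

With only one device visible, the original test("cuda:0") call runs on the physical second GPU, which is how the hardware was ruled out.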

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 1
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

3 reactions
tridao commented, Oct 16, 2022

I’ve just pushed a commit that fixed this. Let me know if it works on your side.

1 reaction
geekinglcq commented, Oct 21, 2022

> Great! @geekinglcq Did recompilation fix it or do you need to set torch.cuda.set_device(qkv.device())?

I just recompiled it.
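For anyone still on a pre-fix build, the exchange above points at two options: recompile/reinstall flash-attn so it picks up the fixed commit (the confirmed resolution), or select the target device before launching the kernel. A minimal sketch of the latter, assuming the older build (note that qkv.device is an attribute, not a method, and whether set_device alone suffices was exactly the open question here):

import torch
from flash_attn.flash_attention import FlashAttention

device = torch.device("cuda:1")
flash = FlashAttention().to(device)
qkv = torch.ones(4, 16, 3, 32, 64, dtype=torch.float16, device=device)

torch.cuda.set_device(qkv.device)  # make cuda:1 the current CUDA device before the kernel launch
out, _ = flash(qkv)
print(torch.abs(out).type(torch.float32).sum())  # non-zero if the workaround applies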


