-
Notifications
You must be signed in to change notification settings - Fork 27
Open
Description
Hi — while testing I found that the results of the Triton and Torch versions are different. Any idea why?
def test_attention_equivalence(atol: float = 1e-4) -> None:
    """Compare lightning-attention implementations against the naive matmul baseline.

    Builds one small random (q, k, v, slope) input, runs each implementation on
    it, and prints every implementation's maximum absolute deviation from the
    baseline (`linear_attn`).

    Args:
        atol: pass threshold on the max absolute elementwise difference;
            defaults to the original hard-coded 1e-4.
    """
    torch.manual_seed(42)
    bsz, num_heads, seq_len, head_dim = 1, 1, 1, 32
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    v = torch.randn(bsz, num_heads, seq_len, head_dim)
    s = _build_slope_tensor(num_heads)

    # Baseline and Triton kernels are fed CUDA tensors (the Triton kernel can
    # only run on GPU); the recursive reference implementation runs on CPU.
    # NOTE(review): with seq_len == 1 the attention reduces to a single step —
    # consider a longer sequence to actually exercise the decay/recurrence.
    outputs = {
        "原始矩阵乘法": linear_attn(q.cuda(), k.cuda(), v.cuda(), s.cuda()),
        "Triton实现": lightning_attn_func(q.cuda(), k.cuda(), v.cuda(), s.cuda()),
        "递归实现": lightning_attn_recursive(q, k, v, s),
        # "Python分块实现": lightning_attn_python(q, k, v, s),
    }

    # Compare every implementation against the baseline; the baseline compared
    # with itself is trivially 0 and serves as a sanity row in the printout.
    baseline = outputs["原始矩阵乘法"]
    for name, out in outputs.items():
        diff = torch.max(torch.abs(out.cpu() - baseline.cpu())).item()
        print(f"{name:15} 最大误差: {diff:.6f}{' ✅' if diff < atol else ' ❌'}")
原始矩阵乘法 最大误差: 0.000000 ✅
Triton实现 最大误差: 0.008380 ❌
递归实现 最大误差: 0.000003 ✅
Metadata
Metadata
Assignees
Labels
No labels