
Difference between Triton and torch implementations #23

@wanghanxiao123

Description


Hi, when testing I found that the results of the Triton and torch versions are different. Any idea why?

import torch

# Assumed to be importable from this repo:
# _build_slope_tensor, linear_attn, lightning_attn_func, lightning_attn_recursive

def test_attention_equivalence():
    torch.manual_seed(42)
    bsz, num_heads, seq_len, head_dim = 1, 1, 1, 32
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    v = torch.randn(bsz, num_heads, seq_len, head_dim)
    s = _build_slope_tensor(num_heads)

    def run_impl(func, q, k, v, s):
        return func(q, k, v, s)

    outputs = {
        "Original matmul": run_impl(lambda q, k, v, s: linear_attn(q, k, v, s), q.cuda(), k.cuda(), v.cuda(), s.cuda()),
        "Triton implementation": run_impl(lambda q, k, v, s: lightning_attn_func(q, k, v, s), q.cuda(), k.cuda(), v.cuda(), s.cuda()),
        "Recursive implementation": run_impl(lambda q, k, v, s: lightning_attn_recursive(q, k, v, s), q, k, v, s),
        # "Python block-wise implementation": run_impl(lambda q, k, v, s: lightning_attn_python(q, k, v, s), q, k, v, s)
    }

    # Compare each implementation against the plain matmul baseline
    baseline = outputs["Original matmul"]
    for name, out in outputs.items():
        diff = torch.max(torch.abs(out.cpu() - baseline.cpu())).item()
        print(f"{name:25} max error: {diff:.6f}{' ✅' if diff < 1e-4 else ' ❌'}")

Original matmul           max error: 0.000000 ✅
Triton implementation     max error: 0.008380 ❌
Recursive implementation  max error: 0.000003 ✅
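
For what it's worth, a tolerance-aware comparison may be more informative than a fixed 1e-4 absolute cutoff, since a Triton kernel typically accumulates in a different order (and possibly a different internal precision) than the torch baseline. A minimal sketch, reusing the `outputs` and `baseline` variables from the test above; the `rtol`/`atol` values are illustrative assumptions, not values from this repo:

# Sketch only: `outputs` and `baseline` come from the test above;
# rtol/atol are illustrative and should be tuned for the dtype/kernel.
close = torch.allclose(
    outputs["Triton implementation"].cpu().float(),
    baseline.cpu().float(),
    rtol=1e-3,
    atol=1e-3,
)
print("Triton vs. baseline within tolerance:", close)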
