
Understanding the difference in Lightning Attention 2's performance between training and inference #24

@Arthur-wza

Description


Dear authors, Hello!
First off, thank you so much for your excellent work on this project!
While studying and using your test interface, I've observed a phenomenon that I'd love to get your confirmation or deeper insights on.
When processing sequences of the same length, Lightning Attention 2's backward pass is significantly slower than its forward pass, especially when compared with a Mamba layer:

--- Benchmark for B=6, H=8, N=8192, D=128, E=64, dtype=torch.bfloat16 ---
Lightning Attention 2 - Forward: 0.866 ms, Backward: 11239.307 ms
Mamba Layer - Forward: 7.370 ms, Backward: 107.192 ms

Have you ever observed the same phenomenon? Could you please tell me why this happens?

Here is my test code:
import pytest
import torch
import time

from lightning_attn.ops import lightning_attn2, linear_attn
from lightning_attn.utils import _build_slope_tensor

from mamba_ssm import Mamba


def get_params():
    array = [
        # (batch_size, num_heads, seq_len, head_dim_q_k, head_dim_v)
        (6, 8, 2048, 128, 64),
        (6, 8, 4096, 128, 64),
        (6, 8, 8192, 128, 64),
        (6, 16, 2048, 64, 64),
    ]
    return array


@pytest.mark.parametrize("b, h, n, d, e", get_params())
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_lightning2_and_mamba_benchmark(b, h, n, d, e, dtype):
    torch.manual_seed(2024)
    device = torch.device("cuda")

    # Inputs for Lightning Attention 2: q/k are (B, H, N, D), v is (B, H, N, E)
    q = (torch.randn((b, h, n, d), dtype=dtype, device=device) / 10).requires_grad_()
    k = (torch.randn((b, h, n, d), dtype=dtype, device=device) / 10).requires_grad_()
    v = (torch.randn((b, h, n, e), dtype=dtype, device=device) / 10).requires_grad_()
    do = torch.randn((b, h, n, e), dtype=dtype, device=device) / 10
    s = _build_slope_tensor(h).to(q.device).to(torch.float32)

    # Input for the Mamba layer: (B, H, N, D) -> (B, N, H, D) -> (B, N, H*D)
    d_model = h * d
    mamba_input = q.permute(0, 2, 1, 3).reshape(b, n, d_model).clone().requires_grad_()
    do_mamba = torch.randn((b, n, d_model), dtype=dtype, device=device) / 10

    mamba_layer = Mamba(
        d_model=d_model,
        d_state=16,  # a common setting for d_state
        d_conv=4,    # a common setting for d_conv
        expand=2,    # a common setting for expand
        device=device,
        dtype=dtype,
    )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warm up Lightning Attention 2 (forward only)
    for _ in range(5):
        _ = lightning_attn2(q, k, v, s)

    # Time a single Lightning Attention 2 forward pass
    start_event.record()
    o = lightning_attn2(q, k, v, s)
    end_event.record()
    torch.cuda.synchronize()
    la2_fwd_time = start_event.elapsed_time(end_event)

    # Time a single Lightning Attention 2 backward pass
    start_event.record()
    o.backward(do, retain_graph=True)
    end_event.record()
    torch.cuda.synchronize()
    la2_bwd_time = start_event.elapsed_time(end_event)

    dq_la, q.grad = q.grad.clone(), None
    dk_la, k.grad = k.grad.clone(), None
    dv_la, v.grad = v.grad.clone(), None

    # Warm up the Mamba layer (forward only)
    for _ in range(5):
        _ = mamba_layer(mamba_input)

    # Time a single Mamba forward pass
    start_event.record()
    output_mamba = mamba_layer(mamba_input)
    end_event.record()
    torch.cuda.synchronize()
    mamba_fwd_time = start_event.elapsed_time(end_event)

    # Time a single Mamba backward pass
    start_event.record()
    output_mamba.backward(do_mamba)
    end_event.record()
    torch.cuda.synchronize()
    mamba_bwd_time = start_event.elapsed_time(end_event)

    print(f"\n--- Benchmark for B={b}, H={h}, N={n}, D={d}, E={e}, dtype={dtype} ---")
    print(f"Lightning Attention 2 - Forward: {la2_fwd_time:.3f} ms, Backward: {la2_bwd_time:.3f} ms")
    print(f"Mamba Layer           - Forward: {mamba_fwd_time:.3f} ms, Backward: {mamba_bwd_time:.3f} ms")
