Dear authors, hello!
First off, thank you so much for your excellent work on this project!
While studying and using your test interface, I've observed a phenomenon that I'd love your confirmation of, or deeper insight into.
When processing sequences of the same length, the model's backward pass is significantly slower than its forward pass, especially when compared with a Mamba layer:

```
--- Benchmark for B=6, H=8, N=8192, D=128, E=64, dtype=torch.bfloat16 ---
Lightning Attention 2 - Forward: 0.866 ms, Backward: 11239.307 ms
Mamba Layer - Forward: 7.370 ms, Backward: 107.192 ms
```
Have you ever faced the same phenomenon? Could you please tell me why this happens?

Here is my test code:
```python
import pytest
import torch

from lightning_attn.ops import lightning_attn2
from lightning_attn.utils import _build_slope_tensor
from mamba_ssm import Mamba


def get_params():
    array = [
        # (batch_size, num_heads, seq_len, head_dim_q_k, head_dim_v)
        (6, 8, 2048, 128, 64),
        (6, 8, 4096, 128, 64),
        (6, 8, 8192, 128, 64),
        (6, 16, 2048, 64, 64),
    ]
    return array


@pytest.mark.parametrize("b, h, n, d, e", get_params())
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_lightning2_and_mamba_benchmark(b, h, n, d, e, dtype):
    torch.manual_seed(2024)
    device = torch.device("cuda")

    q = (torch.randn((b, h, n, d), dtype=dtype, device=device) / 10).requires_grad_()
    k = (torch.randn((b, h, n, d), dtype=dtype, device=device) / 10).requires_grad_()
    v = (torch.randn((b, h, n, e), dtype=dtype, device=device) / 10).requires_grad_()
    do = torch.randn((b, h, n, e), dtype=dtype, device=device) / 10
    s = _build_slope_tensor(h).to(q.device).to(torch.float32)

    d_model = h * d
    # (B, H, N, D) -> (B, N, H, D) -> (B, N, H*D); detach from q so that
    # mamba_input is a leaf and the Mamba backward does not flow into q
    mamba_input = (
        q.detach().permute(0, 2, 1, 3).reshape(b, n, d_model).clone().requires_grad_()
    )
    do_mamba = torch.randn((b, n, d_model), dtype=dtype, device=device) / 10

    mamba_layer = Mamba(
        d_model=d_model,
        d_state=16,  # a common setting for d_state
        d_conv=4,    # a common setting for d_conv
        expand=2,    # a common setting for expand
        device=device,
        dtype=dtype,
    )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # warm up the forward pass only; note that the backward pass below is
    # therefore timed on its very first invocation
    for _ in range(5):
        _ = lightning_attn2(q, k, v, s)

    start_event.record()
    o = lightning_attn2(q, k, v, s)
    end_event.record()
    torch.cuda.synchronize()
    la2_fwd_time = start_event.elapsed_time(end_event)

    start_event.record()
    o.backward(do, retain_graph=True)
    end_event.record()
    torch.cuda.synchronize()
    la2_bwd_time = start_event.elapsed_time(end_event)

    dq_la, q.grad = q.grad.clone(), None
    dk_la, k.grad = k.grad.clone(), None
    dv_la, v.grad = v.grad.clone(), None

    # same procedure for the Mamba layer: warm up forward, then time one
    # forward and one backward call
    for _ in range(5):
        _ = mamba_layer(mamba_input)

    start_event.record()
    output_mamba = mamba_layer(mamba_input)
    end_event.record()
    torch.cuda.synchronize()
    mamba_fwd_time = start_event.elapsed_time(end_event)

    start_event.record()
    output_mamba.backward(do_mamba)
    end_event.record()
    torch.cuda.synchronize()
    mamba_bwd_time = start_event.elapsed_time(end_event)

    print(f"\n--- Benchmark for B={b}, H={h}, N={n}, D={d}, E={e}, dtype={dtype} ---")
    print(f"Lightning Attention 2 - Forward: {la2_fwd_time:.3f} ms, Backward: {la2_bwd_time:.3f} ms")
    print(f"Mamba Layer - Forward: {mamba_fwd_time:.3f} ms, Backward: {mamba_bwd_time:.3f} ms")
```