Dear authors:
When I use lightning_attn_func to replace LlamaAttention in the Llama model, I find that the model does not converge well.
Also, I cannot find a trained model in this GitHub project, so could you please share your experience training with lightning_attn_func?
Thank you very much!
Here is my calling code; if there is any mistake, please let me know.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import repeat_kv

from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor


class LightningAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        ...  # projection / config setup elided

    def forward(
        self,
        # ... arguments elided ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...  # q/k/v projection and reshaping elided

        # expand k/v heads to match the number of query heads (GQA)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # one decay slope per attention head, kept in fp32
        s = _build_slope_tensor(self.config.num_attention_heads).to(query_states.device).to(torch.float32)

        attn_output = lightning_attn_func(
            query_states,
            key_states,
            value_states,
            s,
        )

        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value
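
For reference, here is a minimal standalone call I would expect to work. The input layout (batch, num_heads, seq_len, head_dim) is my assumption, inferred from the transpose(1, 2) above, and the sizes are toy values:

# Minimal sketch, assuming lightning_attn_func takes q, k, v in
# (batch, num_heads, seq_len, head_dim) layout plus a per-head slope tensor.
import torch
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

b, h, n, d = 2, 32, 1024, 64                      # toy sizes, not from my model
q = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16)
s = _build_slope_tensor(h).to(q.device).to(torch.float32)  # one decay slope per head

o = lightning_attn_func(q, k, v, s)               # output keeps the (b, h, n, d) layout
print(o.shape)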
Thanks again!