
Is the lightning_attn interface really trainable? #25

@Arthur-wza

Dear authors,
When I use lightning_attn_func to replace LlamaAttention in a Llama model, I find that the model does not converge well.
Also, I can't find a trained model in this GitHub project. Could you please share your experience of training with lightning_attn_func?
Thank you very much!

Here is the code that calls it. If there are any mistakes, please let me know.
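from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import repeat_kv

from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor


class LightningAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # ...........

    def forward(
        self,
        # ... (arguments elided)
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # ..............
        # Expand grouped K/V heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Per-head decay rates (slopes), kept in float32.
        s = _build_slope_tensor(self.config.num_attention_heads).to(
            device=query_states.device, dtype=torch.float32
        )

        # q, k, v are (bsz, num_heads, q_len, head_dim) at this point.
        attn_output = lightning_attn_func(query_states, key_states, value_states, s)

        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum(F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp))
        else:
            attn_output = self.o_proj(attn_output)

        # The return annotation (and Llama decoder layers) expect
        # (attn_output, attn_weights, past_key_value); lightning attention
        # does not materialize attention weights, so that slot is None.
        return attn_output, None, past_key_value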
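`_build_slope_tensor` is a private helper in `lightning_attn.utils`, so its exact behaviour is an assumption here. If it follows the ALiBi slope schedule (a geometric sequence of per-head rates), the hypothetical `build_alibi_slopes` below is a sketch of what `s` would contain; printing both side by side is a quick way to check which decay rate each head actually receives.

```python
import math

import torch


def build_alibi_slopes(n_heads: int) -> torch.Tensor:
    """Hypothetical ALiBi-style slope schedule (assumed, not lightning_attn's code).

    Returns a float32 tensor of shape (n_heads, 1, 1) so it broadcasts
    over per-head (seq_len, seq_len) decay masks.
    """

    def slopes_power_of_2(n: int) -> list:
        # Geometric sequence: first slope is 2^(-8/n), ratio is the same value.
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n_heads).is_integer():
        slopes = slopes_power_of_2(n_heads)
    else:
        # Non-power-of-two head counts: use the closest lower power of two,
        # then fill the remainder with every other slope of the next power.
        closest = 2 ** math.floor(math.log2(n_heads))
        slopes = slopes_power_of_2(closest)
        slopes += slopes_power_of_2(2 * closest)[0::2][: n_heads - closest]

    return torch.tensor(slopes, dtype=torch.float32).reshape(n_heads, 1, 1)
```

For 8 heads this gives slopes 1/2, 1/4, ..., 1/256. Because the slopes depend only on the head count, one design option is to compute them once in `__init__` and store them with `register_buffer(..., persistent=False)` instead of rebuilding them on every forward call.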

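When debugging convergence, it can help to compare the kernel against a slow but obviously correct reference. The sketch below assumes `lightning_attn_func(q, k, v, s)` computes causal linear attention with a per-head exponential decay, o_i = sum over j <= i of exp(-s_h * (i - j)) * (q_i . k_j) * v_j, with no softmax and no 1/sqrt(d) scaling. That reading of the kernel is an assumption, and `naive_decayed_linear_attn` is a hypothetical helper, not part of `lightning_attn`.

```python
import torch


def naive_decayed_linear_attn(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
    """Hypothetical O(n^2) reference for causal linear attention with decay.

    Assumed shapes: q, k are (batch, heads, seq_len, d), v is
    (batch, heads, seq_len, e), s is (heads, 1, 1).
    """
    n = q.shape[-2]
    idx = torch.arange(n, device=q.device)
    # dist[i, j] = i - j; negative entries are future keys (j > i).
    dist = (idx[:, None] - idx[None, :]).to(torch.float32)
    s = s.to(device=q.device, dtype=torch.float32)
    # Per-head decay mask of shape (heads, n, n); future positions zeroed.
    decay = torch.exp(-s * dist.clamp(min=0)) * (dist >= 0)
    # Work in float32 for a stable reference, then cast back to q's dtype.
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * decay
    return torch.matmul(scores, v.float()).to(q.dtype)
```

The same output can be produced by the recurrence S_t = exp(-s_h) * S_(t-1) + k_t^T v_t, o_t = q_t S_t, which is what makes this form of attention linear in sequence length. One possible workflow: run the kernel and this reference on the same small float32 inputs, compare outputs and gradients with `torch.allclose`, and only then look at the training setup. Under this reading there is also no softmax-style row normalization, so the output scale is not bounded the way softmax attention's is, which by itself may change training dynamics inside a Llama block.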
Thanks again!