24 add cross attention labels text #60
Conversation
- module and config created to do that - mainly attached to the TextEmbedder (it aggregates the token embeddings to produce a sentence embedding, instead of naive averaging) - rest of the code has been adapted, especially categorical variable handling in TextClassificationModel
- used as a namespace afterwards, so not converting it throws a bug
- given a parameter, retrieve the attention matrix - compatible with captum attributions - update tests accordingly
Pull request overview
Adds optional “label attention” cross-attention to the text classification pipeline so the model can produce label-specific sentence embeddings and (optionally) return label×token attention matrices for explainability.
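The mechanism described above — labels acting as queries over token embeddings — can be sketched minimally as follows. This is an illustrative implementation of the general technique, not the package's actual `LabelAttentionClassifier`; all names are assumptions:

```python
import torch

def label_cross_attention(label_queries, token_embeddings, attention_mask):
    """Each label embedding attends over a sentence's token embeddings.

    label_queries:    (num_labels, dim)  -- learned label embeddings (queries)
    token_embeddings: (batch, seq, dim)
    attention_mask:   (batch, seq)       -- 1 for real tokens, 0 for padding

    Returns label-specific sentence embeddings (batch, num_labels, dim)
    and the label x token attention matrix (batch, num_labels, seq).
    """
    dim = label_queries.size(-1)
    # (batch, num_labels, seq): scaled similarity of every label to every token
    scores = torch.einsum("ld,bsd->bls", label_queries, token_embeddings) / dim**0.5
    # Mask out padding positions so they receive zero attention weight
    scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
    attn = scores.softmax(dim=-1)  # label x token attention matrix
    # Weighted sum of token embeddings per label
    sentence = torch.einsum("bls,bsd->bld", attn, token_embeddings)
    return sentence, attn
```

The returned attention matrix is what the explainability path would expose alongside (or instead of) captum attributions.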
Changes:
- Introduces `LabelAttentionConfig`/`LabelAttentionClassifier` and integrates label attention into `TextEmbedder`.
- Updates model forward/predict paths to support returning label-attention matrices and adjusts classifier head output shape when label attention is enabled.
- Extends pipeline tests to cover label-attention-enabled training/prediction.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| `torchTextClassifiers/torchTextClassifiers.py` | Wires label-attention config through initialization, updates predict() explainability options, and deserializes label-attention config on load. |
| `torchTextClassifiers/model/model.py` | Propagates label-attention enablement, adjusts forward pass to optionally return attention matrices, and normalizes embeddings before the head. |
| `torchTextClassifiers/model/lightning.py` | Minor formatting-only change in validation_step. |
| `torchTextClassifiers/model/components/text_embedder.py` | Adds label-attention config/module and changes embedder outputs to include sentence embeddings + optional attention matrices. |
| `torchTextClassifiers/model/components/__init__.py` | Exports LabelAttentionConfig. |
| `tests/test_pipeline.py` | Adds a label-attention-enabled pipeline test and updates explainability assertions for new return keys. |
Comments suppressed due to low confidence (1)
torchTextClassifiers/model/components/text_embedder.py:209
`TextEmbedder._get_sentence_embedding` now sometimes returns a raw tensor (for aggregation_method 'first'/'last'), but `TextEmbedder.forward` unconditionally treats the result as a dict and calls `.values()`. This will crash when aggregation_method != 'mean'. Make `_get_sentence_embedding` return a consistent structure (e.g., always a dict with `sentence_embedding` and `label_attention_matrix`).
```python
if self.attention_config is not None:
    if self.attention_config.aggregation_method is not None:  # default is "mean"
        if self.attention_config.aggregation_method == "first":
            return token_embeddings[:, 0, :]
        elif self.attention_config.aggregation_method == "last":
            lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
            return token_embeddings[
                torch.arange(token_embeddings.size(0)),
                lengths - 1,
                :,
            ]
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@meilame-tayebjee I've opened a new pull request, #61, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #62, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
@meilame-tayebjee I've opened a new pull request, #64, to work on those changes. Once the pull request is ready, I'll request review from you.
Pull request overview
Copilot reviewed 6 out of 7 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (1)
torchTextClassifiers/model/components/text_embedder.py:226
- `TextEmbedder._get_sentence_embedding` returns a raw Tensor for aggregation_method == 'last', but callers now expect a dict. This path will break whenever aggregation_method is set to 'last'. Return `{'sentence_embedding': ..., 'label_attention_matrix': None}` instead.
```python
elif self.attention_config.aggregation_method == "last":
    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
    return token_embeddings[
        torch.arange(token_embeddings.size(0)),
        lengths - 1,
        :,
    ]
```
@meilame-tayebjee I've opened a new pull request, #65, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #66, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #67, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #68, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #69, to work on those changes. Once the pull request is ready, I'll request review from you.
This pull request introduces label attention as an optional feature in the text classification pipeline, allowing the model to generate label-specific sentence embeddings using a cross-attention mechanism. The changes include new configuration classes, updates to the `TextEmbedder` and model logic, and new tests to ensure label attention works as intended.

**Label Attention Mechanism:**
- Added `LabelAttentionConfig` and `LabelAttentionClassifier` to enable label-specific sentence embeddings using cross-attention, where labels act as queries over token embeddings. [1] [2]
- Updated `TextEmbedderConfig` and `TextEmbedder` to support label attention, including a new output structure and logic to handle label attention matrices. [1] [2] [3] [4] [5] [6] [7]

**Model and Pipeline Integration:**
- Updated the model logic (`model.py`) to validate and propagate the label attention configuration, including enforcing that the classification head outputs a single value when label attention is enabled and updating the number of classes accordingly.

**Testing Enhancements:**
- Added a new pipeline test, `test_label_attention_enabled`, and corresponding updates to the helper functions. [1] [2] [3] [4] [5] [6] [7]

**Miscellaneous:**
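The head-shape constraint described above — the classification head outputting a single value per label when label attention is enabled — can be sketched as follows. The tensor names are illustrative assumptions, not the package's actual API:

```python
import torch
import torch.nn as nn

# Hypothetical shape sketch: with label attention, the head maps each
# label-specific embedding to one logit, so per-label scores come out
# stacked as (batch, num_labels) rather than via a (dim -> num_labels) head.
batch, num_labels, dim = 2, 4, 8
label_specific = torch.randn(batch, num_labels, dim)  # output of label cross-attention
head = nn.Linear(dim, 1)                              # "single value" head per label
logits = head(label_specific).squeeze(-1)             # (batch, num_labels)
print(logits.shape)
```

This is why the number of classes must be adjusted when label attention is enabled: the label dimension, not the head width, carries the class count.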