Commit c254c19

HDCharles and kylesayrs authored
Fixing untie to be used only as needed and automatic (#1963)
When models have shared input/output embeddings, we previously had to call the untie function to separate them before modifying those layers. This increased memory usage and was applied at all times, regardless of whether those layers were actually targeted for modification.

Change: automatically detect when a transform or the quantization mixin needs to untie shared embeddings. The untying code is also wrapped in a try/except so that, if it is invoked on a model that can't be untied, it logs a warning rather than raising an error.

New tests are added for this functionality, and old tests are updated to use the automatic untying. The new tests were initially written with claude-code and then rewritten by hand.

TEST PLAN:
pytest tests/llmcompressor/modifiers/quantization/test_handling_shared_embeddings.py

---------

Signed-off-by: HDCharles <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
1 parent 8d366bd commit c254c19
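The automatic detection described above reduces to two checks: do the input and output embeddings share a weight, and is either of them among the modules a modifier will touch. A minimal sketch of that idea, reusing the compressed-tensors matching helper that appears later in this diff; the function name would_untie is illustrative and not part of the library:

```python
import torch
from compressed_tensors.utils import match_named_modules


def would_untie(model: torch.nn.Module, targets: list[str], ignore: list[str]) -> bool:
    """Illustrative mirror of the new behavior: untie only when the input and
    output embeddings share a weight AND one of them is actually targeted."""
    inp, out = model.get_input_embeddings(), model.get_output_embeddings()
    if inp is None or out is None or inp.weight is not out.weight:
        return False  # nothing shared, so nothing to untie
    matched = {module for _, module in match_named_modules(model, targets, ignore)}
    return inp in matched or out in matched
```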

File tree

10 files changed, +422 -35 lines


src/llmcompressor/args/model_arguments.py

Lines changed: 3 additions & 2 deletions

@@ -64,10 +64,11 @@ class ModelArguments:
     )
 
     tie_word_embeddings: bool = field(
-        default=False,
+        default=True,
         metadata={
             "help": "Whether the model's input and output word embeddings "
-            "should be tied. Note that this is only relevant if the "
+            "should attempt to be left tied. False means always untie."
+            " Note that this is only relevant if the "
             "model has a output word embedding layer."
         },
     )

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 2 additions & 2 deletions

@@ -233,7 +233,7 @@ def oneshot(
     processor: Optional[Union[str, ProcessorMixin]] = None,
     use_auth_token: bool = False,
     precision: str = "auto",
-    tie_word_embeddings: bool = False,
+    tie_word_embeddings: bool = True,
     trust_remote_code_model: bool = False,
     save_compressed: bool = True,
     model_revision: str = "main",
@@ -282,7 +282,7 @@
         models.
     :param precision: Precision to cast model weights to, default to auto.
     :param tie_word_embeddings: Whether the model's input and output word embeddings
-        should be tied.
+        should be left tied if possible. False means always untie.
     :param trust_remote_code_model: Whether to allow for custom models to execute
         their own modeling files.
     :param save_compressed: Whether to compress sparse models during save.
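
With the default now True, a plain oneshot call leaves the embeddings tied unless a modifier targets them; passing False restores the old always-untie behavior. A hedged example (model and recipe paths are placeholders):

```python
from llmcompressor import oneshot

oneshot(
    model="path/or/hub-id-of-a-tied-embedding-model",  # placeholder
    recipe="recipe.yaml",                              # placeholder
    tie_word_embeddings=False,  # force untying regardless of what is targeted
)
```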

src/llmcompressor/entrypoints/utils.py

Lines changed: 0 additions & 16 deletions

@@ -59,7 +59,6 @@ def pre_process(
     Raises:
         FileNotFoundError: If the model or processor path is invalid.
     """
-    _warn_tied_embeddings(model_args.tie_word_embeddings)
 
     # Initialize model
     if isinstance(model_args.model, (str, PosixPath)):
@@ -150,21 +149,6 @@ def post_process(
     reset_session()
 
 
-def _warn_tied_embeddings(tie_word_embeddings: bool = False):
-    """
-    Logs a warning if the model has tied word embeddings.
-    The `tie_word_embeddings` flag may cause issues during saving in the one-shot
-    calibration workflow due to shared tensor addresses.
-    """
-    if tie_word_embeddings:
-        logger.debug(
-            "The tie_word_embeddings flag is by default set to False. "
-            "This guarantees that the one-shot algorithm saves the final "
-            "weights without errors. Detected tie_word_embeddings=True. "
-            "This may cause issues with the one-shot algorithm on save."
-        )
-
-
 def initialize_model_from_path(
     model_args: ModelArguments,
     training_args: Optional[TrainingArguments] = None,

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 9 additions & 0 deletions

@@ -34,6 +34,9 @@
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_if_target_shared_embedding,
+)
 
 __all__ = ["QuantizationMixin"]
 
@@ -179,6 +182,12 @@ def start_calibration(self, model: torch.nn.Module):
 
         :param model: model to prepare for calibration
         """
+
+        matched_module_generator = (
+            x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore)
+        )
+        untie_if_target_shared_embedding(model, matched_module_generator)
+
         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
             self._calibration_hooks |= self._initialize_hooks(module)
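
In practice this means a quantization modifier's targets/ignore lists now decide whether untying happens at all. An illustrative pair of configurations (the W4A16 scheme name is a common preset, not something introduced by this diff):

```python
from llmcompressor.modifiers.quantization import QuantizationModifier

# lm_head ignored: the shared embedding/lm_head weight is never matched, so
# start_calibration leaves the embeddings tied and no extra memory is used.
skip_head = QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# lm_head included: the shared weight is matched, so the mixin unties the
# embeddings automatically before attaching observers.
with_head = QuantizationModifier(targets="Linear", scheme="W4A16")
```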

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 14 additions & 1 deletion

@@ -7,11 +7,14 @@
     TransformScheme,
     apply_transform_config,
 )
-from compressed_tensors.utils import TorchDtype
+from compressed_tensors.utils import TorchDtype, match_named_modules
 from pydantic import Field, ValidationInfo, field_validator
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_if_target_shared_embedding,
+)
 
 __all__ = ["QuIPModifier"]
 
@@ -100,6 +103,16 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
 
+        def matched_module_generator():
+            for scheme in self.transform_config.config_groups.values():
+                for arg in scheme.apply:
+                    gen = match_named_modules(state.model, arg.targets, arg.ignore)
+                    for _, module in gen:
+                        yield module
+
+        # Untie embeddings if they will be targeted by transforms
+        untie_if_target_shared_embedding(state.model, matched_module_generator())
+
         apply_transform_config(state.model, self.transform_config)
 
     def on_event(self, state: State, event: Event, **kwargs):

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,9 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modeling import center_embeddings, fuse_norm_linears
 from llmcompressor.modifiers import Modifier
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_word_embeddings,
+)
 
 from .mappings import SpinQuantMapping, infer_mapping_from_model
 from .norm_mappings import NormMapping, infer_norm_mapping_from_model
@@ -148,6 +151,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
 
+        # needed any time embeddings/lm_head is modified
+        untie_word_embeddings(state.model)
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
         self._center_embeddings(state.model)

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 84 additions & 2 deletions

@@ -1,5 +1,6 @@
 import os
 import weakref
+from collections.abc import Generator
 from functools import wraps
 from typing import Optional
 
@@ -126,8 +127,15 @@ def untie_word_embeddings(model: PreTrainedModel):
 
     :param model: model to fix
     """
-    input_embed = model.get_input_embeddings()
-    output_embed = model.get_output_embeddings()
+    try:
+        input_embed = model.get_input_embeddings()
+        output_embed = model.get_output_embeddings()
+    except NotImplementedError as e:
+        logger.warning(
+            f"cannot untie model of type {model.__class__} which doesn't have "
+            f"get_input_embeddings and get_output_embeddings implmented\n{e}"
+        )
+        return
 
     for module in (input_embed, output_embed):
         if module is None or not hasattr(module, "weight"):
@@ -149,6 +157,80 @@ def untie_word_embeddings(model: PreTrainedModel):
     model.config.tie_word_embeddings = False
 
 
+def _get_embeddings_or_warn(
+    model: torch.nn.Module,
+) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
+    if not (
+        hasattr(model, "get_input_embeddings")
+        and hasattr(model, "get_output_embeddings")
+    ):
+        logger.warning(
+            f"{model.__class__} doesn't have attribute get_input_embeddings and"
+            " get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+
+    try:
+        input_embeddings, output_embeddings = (
+            model.get_input_embeddings(),
+            model.get_output_embeddings(),
+        )
+    except NotImplementedError as e:
+        logger.warning(
+            f"{model.__class__} doesn't have get_input_embeddings and "
+            "get_output_embeddings implemented."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+            f"\n{e}"
+        )
+        return None, None
+
+    if not (
+        isinstance(input_embeddings, torch.nn.Module)
+        and isinstance(output_embeddings, torch.nn.Module)
+    ):
+        logger.warning(
+            f"expected modules from {model.__class__} get_input_embeddings and"
+            f" get_output_embeddings but got {type(input_embeddings)}"
+            f" and {type(output_embeddings)}."
+            "\nThis can cause"
+            " problems when quantizing layers with shared weights"
+        )
+        return None, None
+    return input_embeddings, output_embeddings
+
+
+def untie_if_target_shared_embedding(
+    model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module]
+):
+    """
+    Helper method that checks for shared input/output embedding and unties them
+    if either shows up in the matched_module_generator
+
+    :param model: model to untie if embeddings are shared and targeted by
+        matched_module_generator
+    :param matched_module_generator: Generator of all modules (not names) which
+        will be modified by quantization or transformation
+    """
+    input_embeddings, output_embeddings = _get_embeddings_or_warn(model)
+
+    if None in (input_embeddings, output_embeddings):  # if couldn't find embeddings
+        return
+
+    if (
+        input_embeddings.weight is not output_embeddings.weight
+    ):  # if not shared, can ignore
+        return
+
+    # if shared, check if either is targeted
+    for module in matched_module_generator:
+        if module in (input_embeddings, output_embeddings):
+            untie_word_embeddings(model)
+            return
+
+
 def get_model_compressor(
     model: torch.nn.Module,
     sparsity_config: Optional[SparsityCompressionConfig] = None,
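
A hedged end-to-end sketch of the new helper; the checkpoint name is a placeholder and the example assumes a model whose lm_head weight is tied to its input embedding:

```python
from compressed_tensors.utils import match_named_modules
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.compression.compressed_tensors_utils import (
    untie_if_target_shared_embedding,
)

# Placeholder checkpoint; assumed to have tie_word_embeddings=True in its config.
model = AutoModelForCausalLM.from_pretrained("org/tied-embedding-model")

# Modules a hypothetical scheme would touch (all Linear layers, lm_head included).
matched = (module for _, module in match_named_modules(model, ["Linear"], []))

# Unties only because the targeted lm_head shares its weight with the input
# embedding; for an untied model or an untargeted lm_head this is a no-op.
untie_if_target_shared_embedding(model, matched)
assert model.get_input_embeddings().weight is not model.get_output_embeddings().weight
```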
