84 changes: 54 additions & 30 deletions modelopt/torch/quantization/algorithms.py
@@ -21,11 +21,11 @@
import warnings
from collections import defaultdict
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from typing import Any

import regex as re
import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm

@@ -41,7 +41,7 @@
from .config import QuantizeConfig, QuantizerAttributeConfig
from .conversion import set_quantizer_by_cfg
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import is_quantized_linear, multi_context
from .utils import is_quantized_linear


def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
@@ -212,7 +212,11 @@ def __init__(
self.active = self.original

self._importance_dict = {
quant_recipe: dict.fromkeys(self.nn_modules, 0.0) for quant_recipe in self.choices
quant_recipe: {
mod: torch.zeros((), device=mod.weight.device, dtype=torch.float32)
for mod in self.nn_modules
}
for quant_recipe in self.choices
}

@property
@@ -238,11 +242,15 @@ def active(self, val: HPType | None):
def importance(self) -> dict:
"""Return the importance dict mapping recipe and importance."""
return {
quant_recipe: sum(importance_dict.values())
quant_recipe: sum(v.cpu().item() for v in importance_dict.values())
for quant_recipe, importance_dict in self._importance_dict.items()
}


def _add_auto_quantize_score(grad_output, output_diff, score_tensor):
score_tensor += ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()


class AutoQuantizeSearcher(BaseSearcher):
"""A searcher for AutoQuantize algorithm.

@@ -261,7 +269,7 @@ class AutoQuantizeSearcher(BaseSearcher):

candidate_stats: dict[str, dict[str, list[float]]]
best: dict[str, Any]
gradient_checkpointing_enable_contexts: list[tuple[Callable, Callable]] = []
custom_support: list[tuple[Callable, Callable, Callable]] = []

rules = [
r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models
@@ -336,15 +344,19 @@ def _get_search_recipes(quantization_formats):
)

@classmethod
def register_gradient_checkpointing_enable_context(
cls, is_supported_checker: Callable, context: Callable
def register_custom_support(
cls,
is_supported_checker: Callable,
grad_ckpt_context: Callable,
is_param_grad_enabled: Callable,
):
"""Register a gradient checkpointing enable context for `AutoQuantize` score estimation.
"""Register custom support for `AutoQuantize` score estimation.

If the `is_supported_checker(model)` returns True, the `context(model)` will be used to enable gradient
checkpointing.
If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
will be used to enable gradient for the parameter.
"""
cls.gradient_checkpointing_enable_contexts.append((is_supported_checker, context))
cls.custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled))

def _get_default_forward_backward_step(self):
def forward_backward_step(model, data):
@@ -361,7 +373,7 @@ def forward_backward_step(model, data):
return forward_backward_step

@torch.enable_grad()
def _estimate_auto_quantize_scores(self):
def _estimate_auto_quantize_scores(self, is_param_grad_enabled):
# TODO: remove the no-quant recipe
def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
module.quant_recipe = QuantRecipe(quant_cfg=None)
@@ -377,7 +389,7 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
module.output_diff_dict = {}
with torch.no_grad():
for recipe in module.get_hparam("quant_recipe").choices:
if recipe.compression >= 1.0:
if recipe == QuantRecipe(quant_cfg=None):
continue
module.quant_recipe = recipe
output_diff = module._forward_original(input, *args, **kwargs)
@@ -392,18 +404,21 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):

def backward_hook(module, grad_input, grad_output):
for recipe, output_diff in module.output_diff_dict.items():
score = ((grad_output[0].float() ** 2) * (output_diff.float() ** 2)).sum()
module.get_hparam("quant_recipe")._importance_dict[recipe][module] += score.item()
module.output_diff_dict[recipe] = None
score_tensor = module.get_hparam("quant_recipe")._importance_dict[recipe][module]
_add_auto_quantize_score(grad_output[0], output_diff, score_tensor)

del module.output_diff_dict

def setup_params_for_score_estimation(name, param, params_metadata):
def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True):
# Let us delete the gradient as soon as they are computed to save memory
# In addition, this method enables gradient for all parameters
# This is needed to make sure the re-entrant activation checkpointing works
params_metadata[name] = {"requires_grad": param.requires_grad}
param.requires_grad = True
param.requires_grad = enable_grad
if not enable_grad:
return
if self.config.get("verbose", False):
print_rank_0(f"AutoQuantize: Enabling gradient for param {name}.")
accum_grad, handle = create_param_grad_clear_hook(param)
params_metadata[name]["accum_grad"] = accum_grad # We need to keep the accum_grad alive
params_metadata[name]["handle"] = handle
@@ -421,7 +436,9 @@ def cleanup_module_after_score_estimation(module):

def cleanup_params_after_score_estimation(name, param, params_metadata):
param.requires_grad = params_metadata[name]["requires_grad"]
params_metadata[name]["handle"].remove()
handle = params_metadata[name].get("handle", None)
if handle is not None:
handle.remove()

for name, module in self.model.named_modules():
if (
@@ -432,10 +449,11 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
setup_module_for_score_estimation(module)

params_metadata = {}

for name, param in self.model.named_parameters():
# TODO: Enabling gradient for all parameters is not needed and making backward slow
# We need to enable gradient only for the the first parameter of the module such as embedding weights
setup_params_for_score_estimation(name, param, params_metadata)
setup_params_for_score_estimation(
name, param, params_metadata, is_param_grad_enabled(name, self.model)
)

gc.collect()
if torch.cuda.is_available():
@@ -588,14 +606,20 @@ def forward_loop(model):
ModeloptStateManager(self.model).state_dict().pop()

self.model.eval()
with multi_context(
*(
context(self.model)
for is_supported_checker, context in self.gradient_checkpointing_enable_contexts
if is_supported_checker(self.model)
)
):
self._estimate_auto_quantize_scores()

def _default_is_param_grad_enabled(pname, model):
return True

grad_checkpointing_ctxt = None
is_param_grad_enabled = _default_is_param_grad_enabled
for is_supported_checker, ctxt_candidate, grad_enabled_candidate in self.custom_support:
if is_supported_checker(self.model):
grad_checkpointing_ctxt = ctxt_candidate
is_param_grad_enabled = grad_enabled_candidate
break

with grad_checkpointing_ctxt(self.model) if grad_checkpointing_ctxt else nullcontext():
self._estimate_auto_quantize_scores(is_param_grad_enabled)

def run_search(self):
"""Search for the best per-layer quantization configuration and return the best model and configuration.
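
Not part of the diff: a minimal sketch of the score arithmetic behind `_add_auto_quantize_score`, using hypothetical tensors for a single layer. The per-recipe importance is the sum over output elements of (upstream gradient)^2 * (quantization-induced output error)^2, and the PR keeps the running total as a zero-dim tensor on the weight's device so the backward hook never forces a CPU sync; `.item()` is only called when `importance` is read.

import torch

def add_auto_quantize_score(grad_output, output_diff, score_tensor):
    # Sensitivity score accumulated by the backward hook:
    # sum of (upstream gradient)^2 * (output error of the candidate recipe)^2.
    score_tensor += ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()

device = "cuda" if torch.cuda.is_available() else "cpu"
# Zero-dim accumulator kept on the same device as the weights, so no per-step host sync.
score = torch.zeros((), device=device, dtype=torch.float32)

# Hypothetical stand-ins for one layer's upstream gradient and the difference between
# its unquantized output and the output under a candidate quant recipe.
grad_output = torch.randn(8, 32, device=device)
output_diff = torch.randn(8, 32, device=device) * 1e-2

add_auto_quantize_score(grad_output, output_diff, score)
print(score.cpu().item())  # synced to CPU only when the final importance is read
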
10 changes: 8 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -739,8 +739,14 @@ def setup_model_for_gradient_checkpointing(model: nn.Module):
model.config.use_cache = use_cache


AutoQuantizeSearcher.register_gradient_checkpointing_enable_context(
_is_supported_hf_model, setup_model_for_gradient_checkpointing
def _is_param_grad_enabled_for_auto_quantize(pname, model):
return "embed" in pname


AutoQuantizeSearcher.register_custom_support(
_is_supported_hf_model,
setup_model_for_gradient_checkpointing,
_is_param_grad_enabled_for_auto_quantize,
)

CUSTOM_MODEL_PLUGINS.update(
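
Not part of the diff: a hedged usage sketch of the new `register_custom_support` hook for a hypothetical non-HuggingFace model family. All names below (`MyCustomModel`, `enable_gradient_checkpointing`, `disable_gradient_checkpointing`) are assumptions for illustration; the registration mirrors the HuggingFace plugin above, where parameter gradients are enabled only for embedding weights.

from contextlib import contextmanager

from modelopt.torch.quantization.algorithms import AutoQuantizeSearcher

def _is_supported_my_model(model):
    # Only activate this support for the hypothetical MyCustomModel class.
    return type(model).__name__ == "MyCustomModel"

@contextmanager
def _grad_ckpt_context_my_model(model):
    # Assumed methods on the hypothetical model; any callable that takes the model
    # and returns a context manager can be passed as grad_ckpt_context.
    model.enable_gradient_checkpointing()
    try:
        yield
    finally:
        model.disable_gradient_checkpointing()

def _is_param_grad_enabled_my_model(pname, model):
    # Enable gradients only where backprop needs them to reach every layer,
    # e.g. the embedding weights, as the HuggingFace plugin does above.
    return "embedding" in pname

AutoQuantizeSearcher.register_custom_support(
    _is_supported_my_model,
    _grad_ckpt_context_my_model,
    _is_param_grad_enabled_my_model,
)
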