| 
 | 1 | +import contextlib  | 
 | 2 | +from functools import partial  | 
 | 3 | + | 
 | 4 | +import pytest  | 
 | 5 | +import torch  | 
 | 6 | +from transformers import AutoModelForCausalLM  | 
 | 7 | + | 
 | 8 | +from llmcompressor.modeling.moe_context import moe_calibration_context  | 
 | 9 | +from llmcompressor.modeling.qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock  | 
 | 10 | +from llmcompressor.utils.dev import skip_weights_download  | 
 | 11 | +from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context  | 
 | 12 | +from tests.testing_utils import requires_cadence, requires_gpu  | 
 | 13 | + | 
 | 14 | + | 
 | 15 | +@requires_cadence("weekly")  | 
 | 16 | +@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-Next-80B-A3B-Instruct"])  | 
 | 17 | +def test_calib_replace_qwen3moe_all_experts(model_stub):  | 
 | 18 | +    with skip_weights_download():  | 
 | 19 | +        model = AutoModelForCausalLM.from_pretrained(model_stub)  | 
 | 20 | + | 
 | 21 | +    # Qwen3MoE layer replacement is temporary within the context  | 
 | 22 | +    with contextlib.ExitStack() as stack:  | 
 | 23 | +        stack.enter_context(calibration_forward_context(model))  | 
 | 24 | +        stack.enter_context(DisableQuantization(model))  | 
 | 25 | +        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))  | 
 | 26 | + | 
 | 27 | +        # Find one MoE layer  | 
 | 28 | +        moe_layer = None  | 
 | 29 | +        for name, module in model.named_modules():  | 
 | 30 | +            if isinstance(module, CalibrationQwen3NextSparseMoeBlock):  | 
 | 31 | +                moe_layer = module  | 
 | 32 | +                break  | 
 | 33 | + | 
 | 34 | +        assert moe_layer is not None  | 
 | 35 | + | 
 | 36 | +        num_experts = len(moe_layer.experts)  | 
 | 37 | +        expert_triggered = [False for _ in range(num_experts)]  | 
 | 38 | + | 
 | 39 | +        # Define the hook function  | 
 | 40 | +        def hook_fn(i, module, input, output):  | 
 | 41 | +            expert_triggered[i] = True  | 
 | 42 | + | 
 | 43 | +        # Attach hooks using functools.partial to bind each index  | 
 | 44 | +        for i, expert in enumerate(moe_layer.experts):  | 
 | 45 | +            expert.register_forward_hook(partial(hook_fn, i))  | 
 | 46 | + | 
 | 47 | +        # Create dummy input tensor that simulates hidden_states  | 
 | 48 | +        hidden_dim = model.config.hidden_size  | 
 | 49 | +        batch, seq_len = 4, 32  | 
 | 50 | +        sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)  | 
 | 51 | + | 
 | 52 | +        # Forward through the MoE layer directly  | 
 | 53 | +        with torch.no_grad():  | 
 | 54 | +            _ = moe_layer(sample)  | 
 | 55 | + | 
 | 56 | +        # Assert all experts are used  | 
 | 57 | +        assert all(  | 
 | 58 | +            expert_triggered  | 
 | 59 | +        ), f"Not all experts were triggered: {expert_triggered}"  | 
 | 60 | + | 
 | 61 | + | 
 | 62 | +@requires_gpu  | 
 | 63 | +def test_calib_qwen3_moe_module():  | 
 | 64 | +    from transformers import Qwen3NextConfig  | 
 | 65 | +    from transformers.models.qwen3_next.modeling_qwen3_next import (  | 
 | 66 | +        Qwen3NextSparseMoeBlock,  | 
 | 67 | +    )  | 
 | 68 | + | 
 | 69 | +    config = Qwen3NextConfig()  | 
 | 70 | +    with torch.device("cuda"):  | 
 | 71 | +        original = Qwen3NextSparseMoeBlock(config).eval()  | 
 | 72 | + | 
 | 73 | +    # Create dummy input tensor that simulates hidden_states  | 
 | 74 | +    hidden_dim = config.hidden_size  | 
 | 75 | +    batch, seq_len = 4, 32  | 
 | 76 | +    sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")  | 
 | 77 | + | 
 | 78 | +    with calibration_forward_context(original):  | 
 | 79 | +        true_output = original(sample)  | 
 | 80 | + | 
 | 81 | +    module = CalibrationQwen3NextSparseMoeBlock(  | 
 | 82 | +        original, config, calibrate_all_experts=True  | 
 | 83 | +    )  | 
 | 84 | + | 
 | 85 | +    with calibration_forward_context(module):  | 
 | 86 | +        output = module(sample)  | 
 | 87 | +        assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10  | 
 | 88 | +        assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10  | 
 | 89 | + | 
 | 90 | +    module = CalibrationQwen3NextSparseMoeBlock(  | 
 | 91 | +        original, config, calibrate_all_experts=False  | 
 | 92 | +    )  | 
 | 93 | +    with calibration_forward_context(module):  | 
 | 94 | +        output = module(sample)  | 
 | 95 | +        assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10  | 
 | 96 | +        assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10  | 
0 commit comments