diff --git a/tests/model_alignment/qwen_align_test.py b/tests/model_alignment/qwen_align_test.py index bcb96813..4a06ab72 100644 --- a/tests/model_alignment/qwen_align_test.py +++ b/tests/model_alignment/qwen_align_test.py @@ -16,6 +16,8 @@ import transformers from tunix.models.qwen2 import model as qwen2_model from tunix.models.qwen2 import params as qwen2_params +from tunix.models.qwen3 import model as qwen3_model +from tunix.models.qwen3 import params as qwen3_params from tunix.sft import utils from tunix.tests import test_common as tc @@ -74,20 +76,35 @@ def get_per_layer_hf_output(model, seq_len: int, num_layer_to_run: int = 1): return logits[0].detach().numpy() -def get_per_layer_jax_output(model, seq_len: int, num_layer_to_run: int = 1): +def get_per_layer_jax_output( + model_name: str, model, seq_len: int, num_layer_to_run: int = 1 +): """Get the first decoder layer output from the Tunix model.""" x = (jnp.arange(seq_len) + 1).reshape(1, -1) positions = jnp.arange(seq_len).reshape(1, -1) attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len))) - sin, cos = qwen2_model._generate_pos_embeddings( # pylint: disable=protected-access - positions, model.config.head_dim, model.config.rope_theta - ) - logits = model.embedder.encode(x) - for i in range(num_layer_to_run): - _, logits = model.layers[i](logits, None, attn_mask, sin, cos) + if model_name in [ + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "Qwen/Qwen2.5-1.5B-Instruct", + ]: + sin, cos = qwen2_model._generate_pos_embeddings( # pylint: disable=protected-access + positions, model.config.head_dim, model.config.rope_theta + ) + + logits = model.embedder.encode(x) + for i in range(num_layer_to_run): + _, logits = model.layers[i](logits, None, attn_mask, sin, cos) + + return logits + elif model_name == "Qwen/Qwen3-0.6B": + logits = model.embedder.encode(x) + for i in range(num_layer_to_run): + _, logits = model.layers[i](logits, positions, None, attn_mask) - return logits + return logits + else: + raise ValueError(f"Unsupported model: {model_name}") class QwenAlignTest(parameterized.TestCase): @@ -97,17 +114,27 @@ class QwenAlignTest(parameterized.TestCase): testcase_name="deepseek_r1_distill_qwen_1_5b", model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", model_config=qwen2_model.ModelConfig.deepseek_r1_distill_qwen_1_5b, + model_params=qwen2_params, tolerance=2e-3, ), dict( testcase_name="qwen2_5_1_5b_instruct", model_name="Qwen/Qwen2.5-1.5B-Instruct", model_config=qwen2_model.ModelConfig.qwen2_5_1_5b, + model_params=qwen2_params, + tolerance=1e-3, + ), + dict( + testcase_name="qwen3_0_6b", + model_name="Qwen/Qwen3-0.6B", + model_config=qwen3_model.ModelConfig.qwen3_0_6b, + model_params=qwen3_params, tolerance=1e-3, ), - # Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8. ) - def test_qwen_model_alignment(self, model_name, model_config, tolerance): + def test_qwen_model_alignment( + self, model_name, model_config, model_params, tolerance + ): model_path = os.path.join(tempfile.gettempdir(), "models", model_name) tc.download_from_huggingface(repo_id=model_name, model_path=model_path) @@ -117,7 +144,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance): ) print("HF model loaded.") - jax_model = qwen2_params.create_model_from_safe_tensors( + jax_model = model_params.create_model_from_safe_tensors( model_path, model_config(), mesh=jax.make_mesh((1, 1), ("fsdp", "tp")), @@ -148,7 +175,9 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance): layer_to_run = model_config().num_layers hf_logits = get_per_layer_hf_output(hf_model, seq_len, layer_to_run) - jax_logits = get_per_layer_jax_output(jax_model, seq_len, layer_to_run) + jax_logits = get_per_layer_jax_output( + model_name, jax_model, seq_len, layer_to_run + ) np.testing.assert_allclose( hf_logits.squeeze(), jax_logits.squeeze(),