Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions tests/model_alignment/qwen_align_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import transformers
from tunix.models.qwen2 import model as qwen2_model
from tunix.models.qwen2 import params as qwen2_params
from tunix.models.qwen3 import model as qwen3_model
from tunix.models.qwen3 import params as qwen3_params
from tunix.sft import utils
from tunix.tests import test_common as tc

Expand Down Expand Up @@ -74,20 +76,35 @@ def get_per_layer_hf_output(model, seq_len: int, num_layer_to_run: int = 1):
return logits[0].detach().numpy()


def get_per_layer_jax_output(model, seq_len: int, num_layer_to_run: int = 1):
def get_per_layer_jax_output(
model_name: str, model, seq_len: int, num_layer_to_run: int = 1
):
"""Get the first decoder layer output from the Tunix model."""
x = (jnp.arange(seq_len) + 1).reshape(1, -1)
positions = jnp.arange(seq_len).reshape(1, -1)
attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len)))
sin, cos = qwen2_model._generate_pos_embeddings( # pylint: disable=protected-access
positions, model.config.head_dim, model.config.rope_theta
)

logits = model.embedder.encode(x)
for i in range(num_layer_to_run):
_, logits = model.layers[i](logits, None, attn_mask, sin, cos)
if model_name in [
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"Qwen/Qwen2.5-1.5B-Instruct",
]:
sin, cos = qwen2_model._generate_pos_embeddings( # pylint: disable=protected-access
positions, model.config.head_dim, model.config.rope_theta
)

logits = model.embedder.encode(x)
for i in range(num_layer_to_run):
_, logits = model.layers[i](logits, None, attn_mask, sin, cos)

return logits
elif model_name == "Qwen/Qwen3-0.6B":
logits = model.embedder.encode(x)
for i in range(num_layer_to_run):
_, logits = model.layers[i](logits, positions, None, attn_mask)

return logits
return logits
else:
raise ValueError(f"Unsupported model: {model_name}")


class QwenAlignTest(parameterized.TestCase):
Expand All @@ -97,17 +114,27 @@ class QwenAlignTest(parameterized.TestCase):
testcase_name="deepseek_r1_distill_qwen_1_5b",
model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
model_config=qwen2_model.ModelConfig.deepseek_r1_distill_qwen_1_5b,
model_params=qwen2_params,
tolerance=2e-3,
),
dict(
testcase_name="qwen2_5_1_5b_instruct",
model_name="Qwen/Qwen2.5-1.5B-Instruct",
model_config=qwen2_model.ModelConfig.qwen2_5_1_5b,
model_params=qwen2_params,
tolerance=1e-3,
),
dict(
testcase_name="qwen3_0_6b",
model_name="Qwen/Qwen3-0.6B",
model_config=qwen3_model.ModelConfig.qwen3_0_6b,
model_params=qwen3_params,
tolerance=1e-3,
),
# Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8.
)
def test_qwen_model_alignment(self, model_name, model_config, tolerance):
def test_qwen_model_alignment(
self, model_name, model_config, model_params, tolerance
):
model_path = os.path.join(tempfile.gettempdir(), "models", model_name)

tc.download_from_huggingface(repo_id=model_name, model_path=model_path)
Expand All @@ -117,7 +144,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
)
print("HF model loaded.")

jax_model = qwen2_params.create_model_from_safe_tensors(
jax_model = model_params.create_model_from_safe_tensors(
model_path,
model_config(),
mesh=jax.make_mesh((1, 1), ("fsdp", "tp")),
Expand Down Expand Up @@ -148,7 +175,9 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):

layer_to_run = model_config().num_layers
hf_logits = get_per_layer_hf_output(hf_model, seq_len, layer_to_run)
jax_logits = get_per_layer_jax_output(jax_model, seq_len, layer_to_run)
jax_logits = get_per_layer_jax_output(
model_name, jax_model, seq_len, layer_to_run
)
np.testing.assert_allclose(
hf_logits.squeeze(),
jax_logits.squeeze(),
Expand Down