
Commit 3fe890d

jiangyangmu (The tunix Authors) authored and committed
test: add model alignment test for qwen3 models.
PiperOrigin-RevId: 824835255
1 parent 763fcd5 commit 3fe890d

File tree

2 files changed: +160 −12 lines changed

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
"""Check the model correctness for Tunix nnx implemented models.

The test compares the first N decoder layer outputs between the Tunix model
and the HF PyTorch model; typically we expect the logits difference to be
within 1e-3 in fp32.
"""

import os
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
import torch
import transformers
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma3 import params_safetensors as gemma3_params
from tunix.sft import utils
from tunix.tests import test_common as tc

# Large negative value used to mask out attention logits.
K_MASK = -2.3819763e38


def create_pytorch_causal_mask(seq_len):
  """Creates an additive causal attention mask for a given sequence length.

  Args:
    seq_len: The length of the sequence.

  Returns:
    A float tensor of shape (seq_len, seq_len) where:
    - mask[i, j] is 0 if token i can attend to token j (j <= i).
    - mask[i, j] is K_MASK if token i cannot attend to token j (j > i).
  """
  # Create a lower triangular matrix of ones, then convert it into an
  # additive mask: 0 where attention is allowed, K_MASK where it is not.
  mask = torch.ones(seq_len, seq_len, dtype=torch.float).tril(diagonal=0)
  mask = mask.masked_fill(mask == 0, K_MASK)
  mask = mask.masked_fill(mask == 1, 0)
  return mask
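
# Illustrative sanity check (not part of the original test): for seq_len=3
# the additive mask keeps the lower triangle at 0 and pushes the strict
# upper triangle to K_MASK, i.e. it equals
# torch.triu(torch.full((3, 3), K_MASK), diagonal=1):
#
#   [[0.0,    K_MASK, K_MASK],
#    [0.0,    0.0,    K_MASK],
#    [0.0,    0.0,    0.0   ]]
#
# assert torch.equal(
#     create_pytorch_causal_mask(3),
#     torch.triu(torch.full((3, 3), K_MASK), diagonal=1),
# )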


def get_hf_output(model, seq_len: int):
  # Feed deterministic token ids 1..seq_len so both models see the same input.
  x = (torch.arange(seq_len) + 1).reshape(1, -1)
  position_ids = torch.arange(seq_len).reshape(1, -1)
  attn_mask = create_pytorch_causal_mask(seq_len).unsqueeze(0).unsqueeze(0)
  return model(x, attn_mask, position_ids).logits.detach().numpy()


def get_jax_output(model, seq_len: int):
  x = (jnp.arange(seq_len) + 1).reshape(1, -1)
  positions = jnp.arange(seq_len).reshape(1, -1)
  attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len)))
  output, _ = model(x, positions, None, attn_mask)
  return output
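
# Note on mask conventions (observational, not part of the original test):
# the PyTorch side uses an additive float mask (0 where attention is allowed,
# K_MASK where it is not), while the JAX side gets its mask from
# utils.make_causal_attn_mask. Assuming that helper returns a boolean mask
# with True meaning "may attend", the two agree elementwise as:
#   add_mask = create_pytorch_causal_mask(4).numpy()
#   bool_mask = np.asarray(utils.make_causal_attn_mask(jnp.ones((1, 4))))
#   assert (bool_mask.squeeze() == (add_mask == 0)).all()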


class GemmaAlignTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="gemma3_270m_it",
          model_name="google/gemma-3-270m-it",
          model_config=gemma3_model.ModelConfig.gemma3_270m,
          tolerance=1e-3,
      ),
  )
  def test_gemma_model_alignment(self, model_name, model_config, tolerance):
    model_path = os.path.join(tempfile.gettempdir(), "models", model_name)

    tc.download_from_huggingface(repo_id=model_name, model_path=model_path)

    hf_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, dtype=torch.float32
    )
    print("HF model loaded.")

    # Load the same checkpoint into the Tunix JAX model on a trivial
    # single-device (1, 1) mesh.
    jax_model = gemma3_params.create_model_from_safe_tensors(
        model_path,
        model_config(),
        mesh=jax.make_mesh((1, 1), ("fsdp", "tp")),
        dtype=jnp.float32,
    )
    print("JAX model loaded.")

    # Make sure the model weights are the same (check only the embedding and
    # the first layer's query weight).
    hf_emb_weight = hf_model.get_decoder().embed_tokens.weight.detach().numpy()
    jax_emb_weight = jax_model.embedder.input_embedding.value
    np.testing.assert_equal(
        hf_emb_weight,
        jax_emb_weight,
        err_msg=(
            "Embedding weights are not equal; are you sure the loaded HF and"
            " JAX model weights are identical?"
        ),
    )
    hf_query_weight = (
        hf_model.get_decoder()
        .layers[0]
        .self_attn.q_proj.weight.detach()
        .numpy()
    )
    # HF stores q_proj.weight as (num_heads * head_dim, d_model); the Tunix
    # einsum weight is (num_heads, d_model, head_dim), so transpose and
    # flatten it into the HF layout before comparing.
    jax_query_weight = jax_model.layers[0].attn.q_einsum.w
    _, d, _ = jax_query_weight.shape
    jax_query_weight = jax_query_weight.transpose(0, 2, 1).reshape(-1, d)
    np.testing.assert_equal(
        hf_query_weight,
        jax_query_weight,
        err_msg=(
            "Query weights are not equal; are you sure the loaded HF and JAX"
            " model weights are identical?"
        ),
    )
    print("Model weights check passed :)")
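
    # The module docstring mentions per-layer comparisons; this Gemma test
    # only checks the final logits. A hypothetical per-layer probe
    # (illustrative only, not part of this commit) could grab an intermediate
    # HF hidden state via the standard transformers flag:
    #   hidden = hf_model(
    #       x, attn_mask, position_ids, output_hidden_states=True
    #   ).hidden_states[n]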

    seq_len = 128

    # Do a check on the entire model output.
    hf_output = get_hf_output(hf_model, seq_len)
    jax_output = get_jax_output(jax_model, seq_len)
    np.testing.assert_allclose(
        hf_output.squeeze(),
        jax_output.squeeze(),
        atol=tolerance,
        rtol=tolerance,
    )

    print("Logits are close! Model alignment check passed :)")

    # Clean up the downloaded checkpoint.
    tc.delete_directory(model_path)


if __name__ == "__main__":
  absltest.main()
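
With parameterized.named_parameters, the single test method above expands into one generated test per dict, suffixed with its testcase_name (e.g. test_gemma_model_alignment_gemma3_270m_it). A minimal self-contained sketch of the same mechanism:

from absl.testing import absltest
from absl.testing import parameterized


class DemoTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(testcase_name="small", n=2),
      dict(testcase_name="large", n=64),
  )
  def test_square_is_nonnegative(self, n):
    self.assertGreaterEqual(n * n, 0)


if __name__ == "__main__":
  # Runs DemoTest.test_square_is_nonnegative_small and ..._large.
  absltest.main()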

tests/model_alignment/qwen_align_test.py

Lines changed: 25 additions & 12 deletions
@@ -16,6 +16,8 @@
 import transformers
 from tunix.models.qwen2 import model as qwen2_model
 from tunix.models.qwen2 import params as qwen2_params
+from tunix.models.qwen3 import model as qwen3_model
+from tunix.models.qwen3 import params as qwen3_params
 from tunix.sft import utils
 from tunix.tests import test_common as tc

@@ -97,17 +99,27 @@ class QwenAlignTest(parameterized.TestCase):
           testcase_name="deepseek_r1_distill_qwen_1_5b",
           model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
           model_config=qwen2_model.ModelConfig.deepseek_r1_distill_qwen_1_5b,
+          model_loader=qwen2_params,
           tolerance=2e-3,
       ),
       dict(
           testcase_name="qwen2_5_1_5b_instruct",
           model_name="Qwen/Qwen2.5-1.5B-Instruct",
           model_config=qwen2_model.ModelConfig.qwen2_5_1_5b,
+          model_loader=qwen2_params,
+          tolerance=1e-3,
+      ),
+      dict(
+          testcase_name="qwen3_0_6b",
+          model_name="Qwen/Qwen3-0.6B",
+          model_config=qwen3_model.ModelConfig.qwen3_0_6b,
+          model_loader=qwen3_params,
           tolerance=1e-3,
       ),
-      # Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8.
   )
-  def test_qwen_model_alignment(self, model_name, model_config, tolerance):
+  def test_qwen_model_alignment(
+      self, model_name, model_config, model_loader, tolerance
+  ):
     model_path = os.path.join(tempfile.gettempdir(), "models", model_name)

     tc.download_from_huggingface(repo_id=model_name, model_path=model_path)
@@ -117,7 +129,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
     )
     print("HF model loaded.")

-    jax_model = qwen2_params.create_model_from_safe_tensors(
+    jax_model = model_loader.create_model_from_safe_tensors(
         model_path,
         model_config(),
         mesh=jax.make_mesh((1, 1), ("fsdp", "tp")),
@@ -146,15 +158,16 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):

     seq_len = 128

-    layer_to_run = model_config().num_layers
-    hf_logits = get_per_layer_hf_output(hf_model, seq_len, layer_to_run)
-    jax_logits = get_per_layer_jax_output(jax_model, seq_len, layer_to_run)
-    np.testing.assert_allclose(
-        hf_logits.squeeze(),
-        jax_logits.squeeze(),
-        atol=tolerance,
-        rtol=tolerance,
-    )
+    if model_loader == qwen2_params:
+      layer_to_run = model_config().num_layers
+      hf_logits = get_per_layer_hf_output(hf_model, seq_len, layer_to_run)
+      jax_logits = get_per_layer_jax_output(jax_model, seq_len, layer_to_run)
+      np.testing.assert_allclose(
+          hf_logits.squeeze(),
+          jax_logits.squeeze(),
+          atol=tolerance,
+          rtol=tolerance,
+      )

     # Do a check on entire model output
     hf_output = get_hf_output(hf_model, seq_len)