
Commit bbc354e

jiangyangmu authored and The tunix Authors committed

test: add model alignment test for qwen3 models.

PiperOrigin-RevId: 824835255

1 parent 02f6b3b commit bbc354e

File tree

1 file changed: +41 -12 lines changed

tests/model_alignment/qwen_align_test.py

Lines changed: 41 additions & 12 deletions
@@ -16,6 +16,8 @@
 import transformers
 from tunix.models.qwen2 import model as qwen2_model
 from tunix.models.qwen2 import params as qwen2_params
+from tunix.models.qwen3 import model as qwen3_model
+from tunix.models.qwen3 import params as qwen3_params
 from tunix.sft import utils
 from tunix.tests import test_common as tc
 

@@ -74,20 +76,35 @@ def get_per_layer_hf_output(model, seq_len: int, num_layer_to_run: int = 1):
   return logits[0].detach().numpy()
 
 
-def get_per_layer_jax_output(model, seq_len: int, num_layer_to_run: int = 1):
+def get_per_layer_jax_output(
+    model_name: str, model, seq_len: int, num_layer_to_run: int = 1
+):
   """Get the first decoder layer output from the Tunix model."""
   x = (jnp.arange(seq_len) + 1).reshape(1, -1)
   positions = jnp.arange(seq_len).reshape(1, -1)
   attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len)))
-  sin, cos = qwen2_model._generate_pos_embeddings(  # pylint: disable=protected-access
-      positions, model.config.head_dim, model.config.rope_theta
-  )
 
-  logits = model.embedder.encode(x)
-  for i in range(num_layer_to_run):
-    _, logits = model.layers[i](logits, None, attn_mask, sin, cos)
+  if model_name in [
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+      "Qwen/Qwen2.5-1.5B-Instruct",
+  ]:
+    sin, cos = qwen2_model._generate_pos_embeddings(  # pylint: disable=protected-access
+        positions, model.config.head_dim, model.config.rope_theta
+    )
+
+    logits = model.embedder.encode(x)
+    for i in range(num_layer_to_run):
+      _, logits = model.layers[i](logits, None, attn_mask, sin, cos)
+
+    return logits
+  elif model_name == "Qwen/Qwen3-0.6B":
+    logits = model.embedder.encode(x)
+    for i in range(num_layer_to_run):
+      _, logits = model.layers[i](logits, positions, None, attn_mask)
 
-  return logits
+    return logits
+  else:
+    raise ValueError(f"Unsupported model: {model_name}")
 
 
 class QwenAlignTest(parameterized.TestCase):
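The substantive change in this hunk is that qwen2 and qwen3 decoder layers have different call signatures: qwen2 layers take precomputed rotary sin/cos tables, while qwen3 layers take the raw positions (and presumably derive the rotary embeddings internally), so the helper now dispatches on model_name. Below is a minimal sketch of that dispatch pattern; DummyQwen2Layer and DummyQwen3Layer are hypothetical stand-ins that only mirror the call signatures visible in the diff, not the real tunix layers.

# Sketch only: hypothetical stand-ins for tunix.models.qwen2 / qwen3 layers.
import jax.numpy as jnp


class DummyQwen2Layer:

  def __call__(self, x, cache, attn_mask, sin, cos):
    # qwen2-style: the caller precomputes rotary sin/cos tables.
    return None, x


class DummyQwen3Layer:

  def __call__(self, x, positions, cache, attn_mask):
    # qwen3-style: raw positions are passed in place of sin/cos.
    return None, x


def run_layers(model_name, layers, x, positions, attn_mask, sin=None, cos=None):
  """Mirrors the per-model dispatch in get_per_layer_jax_output."""
  if model_name in (
      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
      "Qwen/Qwen2.5-1.5B-Instruct",
  ):
    for layer in layers:
      _, x = layer(x, None, attn_mask, sin, cos)
  elif model_name == "Qwen/Qwen3-0.6B":
    for layer in layers:
      _, x = layer(x, positions, None, attn_mask)
  else:
    raise ValueError(f"Unsupported model: {model_name}")
  return x


# Example: four tokens through one dummy qwen3-style layer.
x = jnp.ones((1, 4, 8))
positions = jnp.arange(4).reshape(1, -1)
out = run_layers("Qwen/Qwen3-0.6B", [DummyQwen3Layer()], x, positions, None)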
@@ -97,17 +114,27 @@ class QwenAlignTest(parameterized.TestCase):
           testcase_name="deepseek_r1_distill_qwen_1_5b",
           model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
           model_config=qwen2_model.ModelConfig.deepseek_r1_distill_qwen_1_5b,
+          model_params=qwen2_params,
           tolerance=2e-3,
       ),
       dict(
           testcase_name="qwen2_5_1_5b_instruct",
           model_name="Qwen/Qwen2.5-1.5B-Instruct",
           model_config=qwen2_model.ModelConfig.qwen2_5_1_5b,
+          model_params=qwen2_params,
+          tolerance=1e-3,
+      ),
+      dict(
+          testcase_name="qwen3_0_6b",
+          model_name="Qwen/Qwen3-0.6B",
+          model_config=qwen3_model.ModelConfig.qwen3_0_6b,
+          model_params=qwen3_params,
           tolerance=1e-3,
       ),
-      # Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8.
   )
-  def test_qwen_model_alignment(self, model_name, model_config, tolerance):
+  def test_qwen_model_alignment(
+      self, model_name, model_config, model_params, tolerance
+  ):
     model_path = os.path.join(tempfile.gettempdir(), "models", model_name)
 
     tc.download_from_huggingface(repo_id=model_name, model_path=model_path)
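Threading the params module through as model_params is what keeps the test body model-agnostic: the same create_model_from_safe_tensors call site now serves both families. A tiny sketch of the pattern, using hypothetical SimpleNamespace stand-ins rather than the real qwen2_params / qwen3_params modules:

# Sketch only: fake loader modules standing in for tunix's params modules.
import types

qwen2_like = types.SimpleNamespace(
    create_model_from_safe_tensors=lambda path, cfg, mesh=None: f"qwen2<{path}>"
)
qwen3_like = types.SimpleNamespace(
    create_model_from_safe_tensors=lambda path, cfg, mesh=None: f"qwen3<{path}>"
)

for model_params, path in ((qwen2_like, "/tmp/q2"), (qwen3_like, "/tmp/q3")):
  # One call site, any model family -- exactly what the diff buys.
  print(model_params.create_model_from_safe_tensors(path, cfg=None))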
@@ -117,7 +144,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
     )
     print("HF model loaded.")
 
-    jax_model = qwen2_params.create_model_from_safe_tensors(
+    jax_model = model_params.create_model_from_safe_tensors(
         model_path,
         model_config(),
         mesh=jax.make_mesh((1, 1), ("fsdp", "tp")),
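The (1, 1) mesh places everything on a single device while keeping the named "fsdp" and "tp" axes, so the model's sharding annotations stay valid without an accelerator slice. A minimal standalone check of that call (jax.make_mesh is available in recent JAX releases):

# Single-device mesh with the same axis names the test uses.
import jax

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
print(mesh.shape)  # {'fsdp': 1, 'tp': 1}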
@@ -148,7 +175,9 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
 
     layer_to_run = model_config().num_layers
     hf_logits = get_per_layer_hf_output(hf_model, seq_len, layer_to_run)
-    jax_logits = get_per_layer_jax_output(jax_model, seq_len, layer_to_run)
+    jax_logits = get_per_layer_jax_output(
+        model_name, jax_model, seq_len, layer_to_run
+    )
     np.testing.assert_allclose(
         hf_logits.squeeze(),
         jax_logits.squeeze(),
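The hunk ends before assert_allclose's remaining keyword arguments, so exactly how the per-model tolerance enters the comparison is not shown here. A common pattern (an assumption, not the confirmed call) is to pass it as atol:

# Hedged sketch: atol= is an assumption; the actual kwargs are
# truncated out of the diff above.
import numpy as np

tolerance = 1e-3  # e.g. the qwen3_0_6b case above
hf_logits = np.zeros((4, 8), dtype=np.float32)
jax_logits = hf_logits + 5e-4  # within tolerance
np.testing.assert_allclose(hf_logits, jax_logits, atol=tolerance)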
