1616import transformers
1717from tunix .models .qwen2 import model as qwen2_model
1818from tunix .models .qwen2 import params as qwen2_params
19+ from tunix .models .qwen3 import model as qwen3_model
20+ from tunix .models .qwen3 import params as qwen3_params
1921from tunix .sft import utils
2022from tunix .tests import test_common as tc
2123
@@ -74,20 +76,35 @@ def get_per_layer_hf_output(model, seq_len: int, num_layer_to_run: int = 1):
7476 return logits [0 ].detach ().numpy ()
7577
7678
77- def get_per_layer_jax_output (model , seq_len : int , num_layer_to_run : int = 1 ):
79+ def get_per_layer_jax_output (
80+ model_name : str , model , seq_len : int , num_layer_to_run : int = 1
81+ ):
7882 """Get the first decoder layer output from the Tunix model."""
7983 x = (jnp .arange (seq_len ) + 1 ).reshape (1 , - 1 )
8084 positions = jnp .arange (seq_len ).reshape (1 , - 1 )
8185 attn_mask = utils .make_causal_attn_mask (jnp .ones ((1 , seq_len )))
82- sin , cos = qwen2_model ._generate_pos_embeddings ( # pylint: disable=protected-access
83- positions , model .config .head_dim , model .config .rope_theta
84- )
8586
86- logits = model .embedder .encode (x )
87- for i in range (num_layer_to_run ):
88- _ , logits = model .layers [i ](logits , None , attn_mask , sin , cos )
87+ if model_name in [
88+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" ,
89+ "Qwen/Qwen2.5-1.5B-Instruct" ,
90+ ]:
91+ sin , cos = qwen2_model ._generate_pos_embeddings ( # pylint: disable=protected-access
92+ positions , model .config .head_dim , model .config .rope_theta
93+ )
94+
95+ logits = model .embedder .encode (x )
96+ for i in range (num_layer_to_run ):
97+ _ , logits = model .layers [i ](logits , None , attn_mask , sin , cos )
98+
99+ return logits
100+ elif model_name == "Qwen/Qwen3-0.6B" :
101+ logits = model .embedder .encode (x )
102+ for i in range (num_layer_to_run ):
103+ _ , logits = model .layers [i ](logits , positions , None , attn_mask )
89104
90- return logits
105+ return logits
106+ else :
107+ raise ValueError (f"Unsupported model: { model_name } " )
91108
92109
93110class QwenAlignTest (parameterized .TestCase ):
@@ -97,17 +114,27 @@ class QwenAlignTest(parameterized.TestCase):
97114 testcase_name = "deepseek_r1_distill_qwen_1_5b" ,
98115 model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" ,
99116 model_config = qwen2_model .ModelConfig .deepseek_r1_distill_qwen_1_5b ,
117+ model_params = qwen2_params ,
100118 tolerance = 2e-3 ,
101119 ),
102120 dict (
103121 testcase_name = "qwen2_5_1_5b_instruct" ,
104122 model_name = "Qwen/Qwen2.5-1.5B-Instruct" ,
105123 model_config = qwen2_model .ModelConfig .qwen2_5_1_5b ,
124+ model_params = qwen2_params ,
125+ tolerance = 1e-3 ,
126+ ),
127+ dict (
128+ testcase_name = "qwen3_0_6b" ,
129+ model_name = "Qwen/Qwen3-0.6B" ,
130+ model_config = qwen3_model .ModelConfig .qwen3_0_6b ,
131+ model_params = qwen3_params ,
106132 tolerance = 1e-3 ,
107133 ),
108- # Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8.
109134 )
110- def test_qwen_model_alignment (self , model_name , model_config , tolerance ):
135+ def test_qwen_model_alignment (
136+ self , model_name , model_config , model_params , tolerance
137+ ):
111138 model_path = os .path .join (tempfile .gettempdir (), "models" , model_name )
112139
113140 tc .download_from_huggingface (repo_id = model_name , model_path = model_path )
@@ -117,7 +144,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
117144 )
118145 print ("HF model loaded." )
119146
120- jax_model = qwen2_params .create_model_from_safe_tensors (
147+ jax_model = model_params .create_model_from_safe_tensors (
121148 model_path ,
122149 model_config (),
123150 mesh = jax .make_mesh ((1 , 1 ), ("fsdp" , "tp" )),
@@ -148,7 +175,9 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
148175
149176 layer_to_run = model_config ().num_layers
150177 hf_logits = get_per_layer_hf_output (hf_model , seq_len , layer_to_run )
151- jax_logits = get_per_layer_jax_output (jax_model , seq_len , layer_to_run )
178+ jax_logits = get_per_layer_jax_output (
179+ model_name , jax_model , seq_len , layer_to_run
180+ )
152181 np .testing .assert_allclose (
153182 hf_logits .squeeze (),
154183 jax_logits .squeeze (),
0 commit comments