1616import  transformers 
1717from  tunix .models .qwen2  import  model  as  qwen2_model 
1818from  tunix .models .qwen2  import  params  as  qwen2_params 
19+ from  tunix .models .qwen3  import  model  as  qwen3_model 
20+ from  tunix .models .qwen3  import  params  as  qwen3_params 
1921from  tunix .sft  import  utils 
2022from  tunix .tests  import  test_common  as  tc 
2123
@@ -74,20 +76,35 @@ def get_per_layer_hf_output(model, seq_len: int, num_layer_to_run: int = 1):
7476  return  logits [0 ].detach ().numpy ()
7577
7678
def get_per_layer_jax_output(
    model_name: str, model, seq_len: int, num_layer_to_run: int = 1
):
  """Runs the first `num_layer_to_run` decoder layers of a Tunix model.

  Feeds a deterministic ramp input (token ids 1..seq_len) through the
  embedder and the first `num_layer_to_run` decoder layers, mirroring
  `get_per_layer_hf_output` so the two can be compared numerically.

  Args:
    model_name: HuggingFace repo id; selects the per-layer call convention
      (Qwen2-family layers take precomputed sin/cos rope embeddings, Qwen3
      layers take raw positions and compute rope internally).
    model: The Tunix (JAX) model instance.
    seq_len: Sequence length of the synthetic input.
    num_layer_to_run: Number of leading decoder layers to apply.

  Returns:
    The activations after the last executed decoder layer.

  Raises:
    ValueError: If `model_name` is not one of the supported models.
  """
  x = (jnp.arange(seq_len) + 1).reshape(1, -1)
  positions = jnp.arange(seq_len).reshape(1, -1)
  attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len)))

  if model_name in [
      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
      "Qwen/Qwen2.5-1.5B-Instruct",
  ]:
    # Qwen2-family layers expect precomputed rotary sin/cos tables.
    sin, cos = qwen2_model._generate_pos_embeddings(  # pylint: disable=protected-access
        positions, model.config.head_dim, model.config.rope_theta
    )
    layer_args = (None, attn_mask, sin, cos)
  elif model_name == "Qwen/Qwen3-0.6B":
    # Qwen3 layers take positions directly and build rope internally.
    layer_args = (positions, None, attn_mask)
  else:
    # Fixed: the original raised with an unterminated f-string literal
    # (missing closing quote), which was a SyntaxError.
    raise ValueError(f"Unsupported model: {model_name}")

  logits = model.embedder.encode(x)
  for i in range(num_layer_to_run):
    _, logits = model.layers[i](logits, *layer_args)

  return logits
91108
92109
93110class  QwenAlignTest (parameterized .TestCase ):
@@ -97,17 +114,27 @@ class QwenAlignTest(parameterized.TestCase):
97114          testcase_name = "deepseek_r1_distill_qwen_1_5b" , 
98115          model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" , 
99116          model_config = qwen2_model .ModelConfig .deepseek_r1_distill_qwen_1_5b , 
117+           model_params = qwen2_params , 
100118          tolerance = 2e-3 , 
101119      ), 
102120      dict ( 
103121          testcase_name = "qwen2_5_1_5b_instruct" , 
104122          model_name = "Qwen/Qwen2.5-1.5B-Instruct" , 
105123          model_config = qwen2_model .ModelConfig .qwen2_5_1_5b , 
124+           model_params = qwen2_params , 
125+           tolerance = 1e-3 , 
126+       ), 
127+       dict ( 
128+           testcase_name = "qwen3_0_6b" , 
129+           model_name = "Qwen/Qwen3-0.6B" , 
130+           model_config = qwen3_model .ModelConfig .qwen3_0_6b , 
131+           model_params = qwen3_params , 
106132          tolerance = 1e-3 , 
107133      ), 
108-       # Note: Qwen/Qwen2.5-7B-Instruct will OOM on v5e-8.  
109134  ) 
110-   def  test_qwen_model_alignment (self , model_name , model_config , tolerance ):
135+   def  test_qwen_model_alignment (
136+       self , model_name , model_config , model_params , tolerance 
137+   ):
111138    model_path  =  os .path .join (tempfile .gettempdir (), "models" , model_name )
112139
113140    tc .download_from_huggingface (repo_id = model_name , model_path = model_path )
@@ -117,7 +144,7 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
117144    )
118145    print ("HF model loaded." )
119146
120-     jax_model  =  qwen2_params .create_model_from_safe_tensors (
147+     jax_model  =  model_params .create_model_from_safe_tensors (
121148        model_path ,
122149        model_config (),
123150        mesh = jax .make_mesh ((1 , 1 ), ("fsdp" , "tp" )),
@@ -148,7 +175,9 @@ def test_qwen_model_alignment(self, model_name, model_config, tolerance):
148175
149176    layer_to_run  =  model_config ().num_layers 
150177    hf_logits  =  get_per_layer_hf_output (hf_model , seq_len , layer_to_run )
151-     jax_logits  =  get_per_layer_jax_output (jax_model , seq_len , layer_to_run )
178+     jax_logits  =  get_per_layer_jax_output (
179+         model_name , jax_model , seq_len , layer_to_run 
180+     )
152181    np .testing .assert_allclose (
153182        hf_logits .squeeze (),
154183        jax_logits .squeeze (),
0 commit comments