"""Check the model correctness for Tunix nnx implemented models.

The test compares the output logits of the Tunix model against the HF PyTorch
model; we expect the logits difference to be within 1e-3 in fp32.
"""

import os
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
import torch
import transformers
from tunix.models.gemma3 import model as gemma3_model
from tunix.models.gemma3 import params_safetensors as gemma3_params
from tunix.sft import utils
from tunix.tests import test_common as tc

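# Large negative value used for masked-out positions in the additive attention
# mask (effectively -inf for an fp32 softmax).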
K_MASK = -2.3819763e38


def create_pytorch_causal_mask(seq_len):
  """Creates a causal attention mask for a sequence of a given length.

  Args:
    seq_len: The length of the sequence.

  Returns:
    An additive float tensor of shape (seq_len, seq_len) where:
    - mask[i, j] is 0 if token i can attend to token j (j <= i).
    - mask[i, j] is K_MASK if token i cannot attend to token j (j > i).
  """
  # Create a lower triangular matrix of ones
  mask = torch.ones(seq_len, seq_len, dtype=torch.float).tril(diagonal=0)
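  # Convert the 0/1 mask into an additive mask: allowed positions become 0 and
  # disallowed positions become K_MASK.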
  mask = mask.masked_fill(mask == 0, K_MASK)
  mask = mask.masked_fill(mask == 1, 0)
  return mask


def get_hf_output(model, seq_len: int):
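  # Dummy token ids [1, ..., seq_len], matching position ids, and an additive
  # causal attention mask of shape (1, 1, seq_len, seq_len).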
  x = (torch.arange(seq_len) + 1).reshape(1, -1)
  position_ids = torch.arange(seq_len).reshape(1, -1)
  attn_mask = create_pytorch_causal_mask(seq_len).unsqueeze(0).unsqueeze(0)
  return model(x, attn_mask, position_ids).logits.detach().numpy()


def get_jax_output(model, seq_len: int):
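  # Same dummy token ids and positions as the HF path; the causal attention
  # mask is built with the Tunix utility instead.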
  x = (jnp.arange(seq_len) + 1).reshape(1, -1)
  positions = jnp.arange(seq_len).reshape(1, -1)
  attn_mask = utils.make_causal_attn_mask(jnp.ones((1, seq_len)))
  output, _ = model(x, positions, None, attn_mask)
  return output


class GemmaAlignTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="gemma3_270m_it",
          model_name="google/gemma-3-270m-it",
          model_config=gemma3_model.ModelConfig.gemma3_270m,
          tolerance=1e-3,
      ),
  )
  def test_gemma_model_alignment(self, model_name, model_config, tolerance):
    model_path = os.path.join(tempfile.gettempdir(), "models", model_name)

    tc.download_from_huggingface(repo_id=model_name, model_path=model_path)

    hf_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, dtype=torch.float32
    )
    print("HF model loaded.")

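    # Load the same safetensors checkpoint into the Tunix nnx model on a
    # single-device (1x1) mesh.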
    jax_model = gemma3_params.create_model_from_safe_tensors(
        model_path,
        model_config(),
        mesh=jax.make_mesh((1, 1), ("fsdp", "tp")),
        dtype=jnp.float32,
    )
    print("JAX model loaded.")

    # Make sure the model weights are the same (check the embedding table and
    # the first layer's query projection weight).
    hf_emb_weight = hf_model.get_decoder().embed_tokens.weight.detach().numpy()
    jax_emb_weight = jax_model.embedder.input_embedding.value
    np.testing.assert_equal(
        hf_emb_weight,
        jax_emb_weight,
        err_msg=(
            "Embedding weights are not equal; are you sure the loaded HF and"
            " JAX model weights are identical?"
        ),
    )
    hf_query_weight = (
        hf_model.get_decoder()
        .layers[0]
        .self_attn.q_proj.weight.detach()
        .numpy()
    )
    jax_query_weight = jax_model.layers[0].attn.q_einsum.w
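    # Reshape the 3D JAX query weight to HF's 2D q_proj layout (heads and
    # head_dim flattened into the row axis) before comparing.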
    _, d, _ = jax_query_weight.shape
    jax_query_weight = jax_query_weight.transpose(0, 2, 1).reshape(-1, d)
    np.testing.assert_equal(
        hf_query_weight,
        jax_query_weight,
        err_msg=(
            "Query weights are not equal; are you sure the loaded HF and JAX"
            " model weights are identical?"
        ),
    )
    print("Model weights check passed :)")

    seq_len = 128

    # Check the entire model output (logits).
    hf_output = get_hf_output(hf_model, seq_len)
    jax_output = get_jax_output(jax_model, seq_len)
    np.testing.assert_allclose(
        hf_output.squeeze(),
        jax_output.squeeze(),
        atol=tolerance,
        rtol=tolerance,
    )

    print("Logits are close! Model alignment check passed :)")

    # Clean up the downloaded checkpoint.
    tc.delete_directory(model_path)


if __name__ == "__main__":
  absltest.main()