@@ -26,6 +26,7 @@ class Einsum(nn.Module):
   weight_name: str = 'w'
   initializer: nn.initializers.Initializer = nn.initializers.normal()
   dtype: jnp.dtype | None = None
+  w_scale: float | None = None

   @nn.compact
   def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
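The new `w_scale` field lets a constant multiplier be folded into the kernel before the contraction (applied in `__call__` in the hunk below). Note the guard is a truthiness test, so `w_scale=0.0` behaves like `None` and leaves the weight unscaled. A minimal usage sketch, assuming the module's existing `shape` field visible in the next hunk; the equation, shapes, and scale value are illustrative, not from this change:

    import jax
    import jax.numpy as jnp

    # Hypothetical: fold a 1/sqrt(8) factor into an 8x16 projection.
    layer = Einsum(shape=(8, 16), w_scale=8 ** -0.5)
    x = jnp.ones((2, 8))
    params = layer.init(jax.random.PRNGKey(0), 'bd,df->bf', x)
    y = layer.apply(params, 'bd,df->bf', x)  # einsum of x with (w * 8**-0.5)
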
@@ -35,24 +36,42 @@ def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
         self.shape,
         self.dtype if self.dtype is not None else None,
     )
+    if self.w_scale:
+      w *= self.w_scale
     return jnp.einsum(eqn, x, w)


+def reduce_precision(x: jax.Array) -> jax.Array:
+  """Round x to the nominal precision of its dtype."""
+  finfo = jnp.finfo(x.dtype)  # jnp.finfo, not np.finfo, so bfloat16 etc. work.
+  return jax.lax.reduce_precision(x, finfo.nexp, finfo.nmant)
+
+
 class RMSNorm(nn.Module):
   """RMSNorm layer."""

+  with_scale: bool = True
+  scale_init: nn.initializers.Initializer = nn.initializers.zeros_init()
+  scale_plus_one: bool = True
+  guard_against_excess_precision: bool = False
+
   @nn.compact
   def __call__(self, x):
-    scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1]))
+    if self.guard_against_excess_precision:
+      x = reduce_precision(x)
+
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)

     # jax.lax.rsqrt is used because it returns different floats than
     # jnp.reciprocal(jnp.sqrt(var + 1e-06)).
     normed_inputs = x * jax.lax.rsqrt(var + 1e-06)

-    # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
-    # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
-    # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
-    scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
-    normed_inputs = normed_inputs * (1 + scale)
+    if self.with_scale:
+      scale = self.param('scale', self.scale_init, (x.shape[-1]))
+      # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale
+      # is a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
+      # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
+      scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
+      normed_inputs = normed_inputs * (
+          1. + scale if self.scale_plus_one else scale)
     return normed_inputs
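`reduce_precision` rounds a tensor to its own dtype's exponent and mantissa budget. Numerically that is a no-op for values already at nominal precision, but it inserts an operation that stops XLA from carrying excess precision in intermediates, which is presumably what `guard_against_excess_precision` guards against. A small sketch of the underlying primitive, with illustrative values:

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 4, dtype=jnp.bfloat16)
    finfo = jnp.finfo(x.dtype)  # jnp.finfo handles bfloat16; np.finfo does not
    y = jax.lax.reduce_precision(x, finfo.nexp, finfo.nmant)
    print((x == y).all())  # True: same exponent/mantissa bits, values unchanged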
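
The new flags let one layer cover both common RMSNorm parameterisations: the default (zeros `scale_init` with `scale_plus_one=True`) learns a residual around 1, a ones initializer with `scale_plus_one=False` gives a conventional gamma, and `with_scale=False` drops the learned scale entirely. A sketch, with illustrative shapes, showing the two parameterisations agree at initialization:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jax.random.normal(jax.random.PRNGKey(1), (2, 16))

    norm_a = RMSNorm()  # scale starts at 0, output multiplied by (1 + scale)
    norm_b = RMSNorm(scale_init=nn.initializers.ones_init(), scale_plus_one=False)

    pa = norm_a.init(jax.random.PRNGKey(0), x)
    pb = norm_b.init(jax.random.PRNGKey(0), x)
    print(jnp.allclose(norm_a.apply(pa, x), norm_b.apply(pb, x)))  # True at init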