20 | 20 | from flax import linen as nn |
21 | 21 | from flax.typing import Array # pylint: disable=g-importing-member |
22 | 22 | from gemma.peft import _einsum_utils |
| 23 | +import jax |
23 | 24 | import jax.numpy as jnp |
| 25 | +import numpy as np |
24 | 26 |
25 | 27 |
26 | 28 | class LoRADenseAdapter(nn.Module): |
@@ -170,3 +172,109 @@ def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array: |
170 | 172 | b_init=self.b_init, |
171 | 173 | ) |
172 | 174 | return self.wrapped(inputs) + adapter(inputs) |
| 175 | + |
| 176 | + |
| 177 | +class LoRADenseGeneralAdapter(nn.Module): |
| 178 | + """LoRA general dense module. |
| 179 | +
| 180 | + This module only does the x @ A @ B computation. |
| 181 | + Use `LoRADenseGeneral` to wrap a `nn.DenseGeneral` layer. |
| 182 | +
| 183 | + Attributes: |
| 184 | + rank: The rank of the LoRA decomposition. |
| 185 | + features: The number of output features. |
| 186 | + axis: int or tuple with axes to apply the transformation on. |
| 187 | + batch_dims: tuple with batch axes. |
| 188 | + dtype: The dtype to use for the LoRA weights. |
| 189 | + a_init: The initializer for the A matrix. |
| 190 | + b_init: The initializer for the B matrix. |
| 191 | + """ |
| 192 | + |
| 193 | + _: dataclasses.KW_ONLY |
| 194 | + |
| 195 | + rank: int |
| 196 | + features: int | Sequence[int] |
| 197 | + axis: int | Sequence[int] |
| 198 | + batch_dims: Sequence[int] |
| 199 | + |
| 200 | + dtype: jnp.dtype = jnp.float_ |
| 201 | + a_init: nn.initializers.Initializer = nn.initializers.kaiming_uniform() |
| 202 | + b_init: nn.initializers.Initializer = nn.initializers.zeros_init() |
| 203 | + |
| 204 | + @nn.compact |
| 205 | + def __call__(self, inputs: Array) -> Array: |
| 206 | + """Mostly copied from `flax.nn.DenseGeneral`.""" |
| 207 | + |
| 208 | + # Normalize inputs |
| 209 | + batch_dims = nn.linear._canonicalize_tuple(self.batch_dims) |
| 210 | + if batch_dims: |
| 211 | + max_dim = np.max(batch_dims) |
| 212 | + if set(batch_dims) != set(range(max_dim + 1)): |
| 213 | + raise ValueError( |
| 214 | + 'batch_dims %s must be consecutive leading ' |
| 215 | + 'dimensions starting from 0.' |
| 216 | + % str(batch_dims) |
| 217 | + ) |
| 218 | + |
| 219 | + n_dim = inputs.ndim |
| 220 | + batch_dims = nn.linear._normalize_axes(batch_dims, n_dim) |
| 221 | + features = nn.linear._canonicalize_tuple(self.features) |
| 222 | + axis = nn.linear._normalize_axes( |
| 223 | + nn.linear._canonicalize_tuple(self.axis), |
| 224 | + n_dim, |
| 225 | + ) |
| 226 | + |
| 227 | + # Create LoRA params |
| 228 | + batch_shape = tuple(inputs.shape[i] for i in batch_dims) |
| 229 | + a_shape = batch_shape + tuple(inputs.shape[i] for i in axis) + (self.rank,) |
| 230 | + a = self.param('a', self.a_init, a_shape, self.dtype) |
| 231 | + b_shape = (*batch_shape, self.rank, *features) |
| 232 | + b = self.param('b', self.b_init, b_shape, self.dtype) |
| 233 | + |
| 234 | + # Contract across given axes. |
| 235 | + n_batch_dims, n_axis = len(batch_dims), len(axis) |
| 236 | + batch_ind = tuple(range(n_batch_dims)) |
| 237 | + contract_ind = tuple(range(n_batch_dims, n_batch_dims + n_axis)) |
| 238 | + inputs = nn.dtypes.promote_dtype(inputs, dtype=self.dtype)[0] |
| 239 | + # low_rank = x @ A |
| 240 | + low_rank_dot = ((axis, contract_ind), (batch_dims, batch_ind)) |
| 241 | + low_rank = jax.lax.dot_general(inputs, a, low_rank_dot) |
| 242 | + # out = low_rank @ B |
| 243 | + low_rank_cind = [n_dim - n_axis] |
| 244 | + b_cind = [n_batch_dims] |
| 245 | + out_dot = ((low_rank_cind, b_cind), (batch_dims, batch_ind)) |
| 246 | + return jax.lax.dot_general(low_rank, b, out_dot) |
| 247 | + |
| 248 | + |
| 249 | +class LoRADenseGeneral(nn.Module): |
| 250 | + """Wrapper around `nn.DenseGeneral` which adds a LoRA adapter.""" |
| 251 | + |
| 252 | + _: dataclasses.KW_ONLY |
| 253 | + |
| 254 | + rank: int |
| 255 | + wrapped: nn.DenseGeneral |
| 256 | + |
| 257 | + dtype: jnp.dtype = jnp.float_ |
| 258 | + a_init: nn.initializers.Initializer = nn.initializers.kaiming_uniform() |
| 259 | + b_init: nn.initializers.Initializer = nn.initializers.zeros_init() |
| 260 | + |
| 261 | + def __post_init__(self): |
| 262 | + super().__post_init__() |
| 263 | + # Share scope, to make the wrapper module transparent with respect to the |
| 264 | + # parameters (instead of nesting `{'params': {'wrapped': params}}`). |
| 265 | + if self.scope is not None: |
| 266 | + nn.share_scope(self, self.wrapped) |
| 267 | + |
| 268 | + @nn.compact |
| 269 | + def __call__(self, inputs: Array) -> Array: |
| 270 | + adapter = LoRADenseGeneralAdapter( |
| 271 | + name='lora', |
| 272 | + rank=self.rank, |
| 273 | + features=self.wrapped.features, |
| 274 | + axis=self.wrapped.axis, |
| 275 | + batch_dims=self.wrapped.batch_dims, |
| 276 | + dtype=self.dtype, |
| 277 | + a_init=self.a_init, |
| 278 | + b_init=self.b_init, |
| 279 | + ) |
| 280 | + return self.wrapped(inputs) + adapter(inputs) |
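
Standalone, `LoRADenseGeneralAdapter` computes only the low-rank update: the input is contracted over `axis` (with `batch_dims` carried through as batched LoRA weights) into a `rank`-sized bottleneck via `a`, then expanded to `features` via `b`. A minimal usage sketch, assuming the classes above are importable from `gemma.peft._lora` (the import path is an assumption, not confirmed by this diff):

import jax
import jax.numpy as jnp

from gemma.peft import _lora  # assumed import path for the classes in this diff

# Contract the last two axes of a (batch, heads, head_dim) input down to
# 16 features through a rank-4 bottleneck; no batched LoRA weights here.
adapter = _lora.LoRADenseGeneralAdapter(
    rank=4,
    features=16,
    axis=(-2, -1),
    batch_dims=(),
)
x = jnp.ones((2, 8, 32))
params = adapter.init(jax.random.PRNGKey(0), x)
y = adapter.apply(params, x)
assert y.shape == (2, 16)  # 'a' has shape (8, 32, 4), 'b' has shape (4, 16).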
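For the wrapper itself, the `__post_init__` comment notes that `nn.share_scope` is there to keep the wrapped layer's parameters at the wrapper's level instead of nesting them under `wrapped`. A sketch of wrapping an `nn.DenseGeneral` inside a parent module, again assuming the `gemma.peft._lora` import path; `MyBlock` is a made-up module for illustration:

import jax
import jax.numpy as jnp
from flax import linen as nn

from gemma.peft import _lora  # assumed import path, as above


class MyBlock(nn.Module):  # hypothetical parent module, not part of the diff
  @nn.compact
  def __call__(self, x):
    dense = nn.DenseGeneral(features=(4, 16), axis=-1)
    # Wrap the dense layer; since `b_init` is zeros, the adapter contributes
    # nothing at initialization and the output matches the plain DenseGeneral.
    return _lora.LoRADenseGeneral(rank=2, wrapped=dense)(x)


x = jnp.ones((3, 32))
params = MyBlock().init(jax.random.PRNGKey(0), x)
y = MyBlock().apply(params, x)
assert y.shape == (3, 4, 16)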