
Commit db6e197

wwkong authored and The gemma Authors committed
Implement a LoRA interceptor for nn.DenseGeneral.
PiperOrigin-RevId: 774791006
1 parent e552410 commit db6e197

3 files changed: +130 −1 lines changed


gemma/peft/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@
 from gemma.peft._lora import LoRADenseAdapter
 from gemma.peft._lora import LoRAEinsum
 from gemma.peft._lora import LoRAEinsumAdapter
+from gemma.peft._lora import LoRADenseGeneral
+from gemma.peft._lora import LoRADenseGeneralAdapter
 from gemma.peft._tree_utils import fuse_params
 from gemma.peft._tree_utils import merge_params
 from gemma.peft._tree_utils import split_params

gemma/peft/_lora.py

Lines changed: 108 additions & 0 deletions
@@ -20,7 +20,9 @@
 from flax import linen as nn
 from flax.typing import Array  # pylint: disable=g-importing-member
 from gemma.peft import _einsum_utils
+import jax
 import jax.numpy as jnp
+import numpy as np


 class LoRADenseAdapter(nn.Module):
@@ -170,3 +172,109 @@ def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array:
         b_init=self.b_init,
     )
     return self.wrapped(inputs) + adapter(inputs)
+
+
+class LoRADenseGeneralAdapter(nn.Module):
+  """LoRA general dense module.
+
+  This module only does the x @ A @ B computation.
+  Use `LoRADenseGeneral` to wrap a `nn.DenseGeneral` layer.
+
+  Attributes:
+    rank: The rank of the LoRA decomposition.
+    features: The number of output features.
+    axis: int or tuple with axes to apply the transformation on.
+    batch_dims: tuple with batch axes.
+    dtype: The dtype to use for the LoRA weights.
+    a_init: The initializer for the A matrix.
+    b_init: The initializer for the B matrix.
+  """
+
+  _: dataclasses.KW_ONLY
+
+  rank: int
+  features: int | Sequence[int]
+  axis: int | Sequence[int]
+  batch_dims: Sequence[int]
+
+  dtype: jnp.dtype = jnp.float_
+  a_init: nn.initializers.Initializer = nn.initializers.kaiming_uniform()
+  b_init: nn.initializers.Initializer = nn.initializers.zeros_init()
+
+  @nn.compact
+  def __call__(self, inputs: Array) -> Array:
+    """Mostly copied from `flax.nn.DenseGeneral`."""
+
+    # Normalize inputs
+    batch_dims = nn.linear._canonicalize_tuple(self.batch_dims)
+    if batch_dims:
+      max_dim = np.max(batch_dims)
+      if set(batch_dims) != set(range(max_dim + 1)):
+        raise ValueError(
+            'batch_dims %s must be consecutive leading '
+            'dimensions starting from 0.'
+            % str(batch_dims)
+        )
+
+    n_dim = inputs.ndim
+    batch_dims = nn.linear._normalize_axes(batch_dims, n_dim)
+    features = nn.linear._canonicalize_tuple(self.features)
+    axis = nn.linear._normalize_axes(
+        nn.linear._canonicalize_tuple(self.axis),
+        n_dim,
+    )
+
+    # Create LoRA params
+    batch_shape = tuple(inputs.shape[i] for i in batch_dims)
+    a_shape = batch_shape + tuple(inputs.shape[i] for i in axis) + (self.rank,)
+    a = self.param('a', self.a_init, a_shape, self.dtype)
+    b_shape = (*batch_shape, self.rank, *features)
+    b = self.param('b', self.b_init, b_shape, self.dtype)
+
+    # Contract across given axes.
+    n_batch_dims, n_axis = len(batch_dims), len(axis)
+    batch_ind = tuple(range(n_batch_dims))
+    contract_ind = tuple(range(n_batch_dims, n_batch_dims + n_axis))
+    inputs = nn.dtypes.promote_dtype(inputs, dtype=self.dtype)[0]
+    # low_rank = x @ A
+    low_rank_dot = ((axis, contract_ind), (batch_dims, batch_ind))
+    low_rank = jax.lax.dot_general(inputs, a, low_rank_dot)
+    # out = low_rank @ B
+    low_rank_cind = [n_dim - n_axis]
+    b_cind = [n_batch_dims]
+    out_dot = ((low_rank_cind, b_cind), (batch_dims, batch_ind))
+    return jax.lax.dot_general(low_rank, b, out_dot)
+
+
+class LoRADenseGeneral(nn.Module):
+  """Wrapper around `nn.DenseGeneral` which adds a LoRA adapter."""
+
+  _: dataclasses.KW_ONLY
+
+  rank: int
+  wrapped: nn.DenseGeneral
+
+  dtype: jnp.dtype = jnp.float_
+  a_init: nn.initializers.Initializer = nn.initializers.kaiming_uniform()
+  b_init: nn.initializers.Initializer = nn.initializers.zeros_init()
+
+  def __post_init__(self):
+    super().__post_init__()
+    # Share scope, to make the wrapper module transparent with respect to the
+    # parameters (instead of nesting `{'params': {'wrapped': params}}`).
+    if self.scope is not None:
+      nn.share_scope(self, self.wrapped)
+
+  @nn.compact
+  def __call__(self, inputs: Array) -> Array:
+    adapter = LoRADenseGeneralAdapter(
+        name='lora',
+        rank=self.rank,
+        features=self.wrapped.features,
+        axis=self.wrapped.axis,
+        batch_dims=self.wrapped.batch_dims,
+        dtype=self.dtype,
+        a_init=self.a_init,
+        b_init=self.b_init,
+    )
+    return self.wrapped(inputs) + adapter(inputs)
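
The adapter's forward pass is just two `jax.lax.dot_general` contractions: low_rank = x @ A followed by out = low_rank @ B. Below is a minimal sketch, not part of this commit, that replays those contractions for the simplest configuration assumed here (no batch_dims, a single contracted input axis, and the rank-1 / features=(2, 3) shapes the test below uses):

import jax
import jax.numpy as jnp

# Shapes mirroring the test: inputs (1, 4), rank 1, features (2, 3).
x = jnp.ones((1, 4))
a = jnp.ones((4, 1))      # A: (in_features, rank)
b = jnp.zeros((1, 2, 3))  # B: (rank, *features)

# low_rank = x @ A: contract x's axis 1 with a's axis 0, no batch dims.
low_rank = jax.lax.dot_general(x, a, (((1,), (0,)), ((), ())))
assert low_rank.shape == (1, 1)

# out = low_rank @ B: contract low_rank's axis 1 with b's axis 0.
out = jax.lax.dot_general(low_rank, b, (((1,), (0,)), ((), ())))
assert out.shape == (1, 2, 3)

Because b_init defaults to zeros, the adapter contributes nothing at initialization, so the wrapped layer's output is unchanged until B is trained.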

gemma/peft/_lora_test.py

Lines changed: 20 additions & 1 deletion
@@ -23,8 +23,10 @@
 def _dense_to_lora(module):
   if isinstance(module, nn.Dense):
     return peft.LoRADense(rank=1, wrapped=module)
-  if isinstance(module, nn.Einsum):
+  elif isinstance(module, nn.Einsum):
     return peft.LoRAEinsum(rank=1, wrapped=module)
+  elif isinstance(module, nn.DenseGeneral):
+    return peft.LoRADenseGeneral(rank=1, wrapped=module)
   else:
     return module

@@ -46,11 +48,19 @@ def __call__(self, x):
         shape=(4, 2, 3),
         einsum_str='bi,imn->bmn',
     )(x)
+
+    # Test DenseGeneral
+    y4 = nn.DenseGeneral(
+        features=(2, 3),
+        axis=-1,
+    )(x)
+
     return {
         'y0': y0,
         'y1': y1,
         'y2': y2,
         'y3': y3,
+        'y4': y4,
     }

@@ -88,11 +98,20 @@ def test_lora():
               'b': f32[1, 2, 3],
           },
       },
+      'DenseGeneral_0': {
+          'kernel': f32[4, 2, 3],
+          'bias': f32[2, 3],
+          'lora': {
+              'a': f32[4, 1],
+              'b': f32[1, 2, 3],
+          },
+      },
   },
 })
 assert etree.spec_like(out) == etree.spec_like({
     'y0': f32[1, 2],
     'y1': f32[1, 2],
     'y2': f32[1, 3],
     'y3': f32[1, 2, 3],
+    'y4': f32[1, 2, 3],
 })
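
For reference, a standalone usage sketch (assuming `gemma.peft` is importable; not part of this commit) that exercises the same wrapper the test covers and shows the flat parameter layout produced by `nn.share_scope`:

import jax
import jax.numpy as jnp
from flax import linen as nn
from gemma import peft

dense = nn.DenseGeneral(features=(2, 3), axis=-1)
lora_dense = peft.LoRADenseGeneral(rank=1, wrapped=dense)

x = jnp.ones((1, 4))
params = lora_dense.init(jax.random.PRNGKey(0), x)
# With `nn.share_scope`, the LoRA params sit next to the wrapped layer's,
# e.g. {'params': {'kernel': ..., 'bias': ..., 'lora': {'a': ..., 'b': ...}}}.
y = lora_dense.apply(params, x)
assert y.shape == (1, 2, 3)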
