Commit c4c77d4

Implement huggingface checkpoint loading
1 parent 2ca3dda commit c4c77d4

File tree

1 file changed: +119 -14 lines changed

examples/pre-training/ernie/pretrain.py

Lines changed: 119 additions & 14 deletions
@@ -63,6 +63,8 @@
 
 from config import get_config
 
+from safetensors import safe_open
+
 try:
     from paddleformers.trainer.trainer_utils import log_trainer_start
 except ImportError:
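For reference, `safe_open` (imported above) reads individual tensors lazily from a `.safetensors` shard instead of loading the whole file; a minimal sketch, where the shard file name and tensor key are hypothetical:

# Minimal sketch of the safetensors lazy-read API used by this commit.
# The shard file name and tensor key below are hypothetical examples.
from safetensors import safe_open

with safe_open("model-00001-of-00004.safetensors", framework="paddle", device="cpu") as f:
    print(f.keys())                                     # names of tensors stored in this shard
    tensor = f.get_tensor("model.embed_tokens.weight")  # loads only this one tensor
    print(tensor.shape, tensor.dtype)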
@@ -202,6 +204,118 @@ def _collate_data(data, stack_fn=Stack()):
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
 
+def load_huggingface_checkpoint(model, args):
+    fused_rms_norm_replace = [
+        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
+        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
+    ]
+    shared_layers_prefix = "shared_layers.embed_weight_share."
+    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
+
+    logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
+    with open(
+        os.path.join(args.model_name_or_path, "model.safetensors.index.json")
+    ) as f:
+        weight_map = json.load(f)["weight_map"]
+
+    ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
+    ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
+    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
+
+    def param_to_weight(name):
+        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
+        for src, dst in fused_rms_norm_replace:
+            name = name.replace(src, dst)
+        if m := re.search(r"mlp\.experts\.(\d+)", name):
+            expert_id = expert_offset + int(m.group(1))
+            s, e = m.span()
+            name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
+        if isinstance(model, ErnieMoEForCausalLM):
+            return name
+
+        # for PP>1, we also need to handle special layers and adjust layer_idx
+        if name.startswith(shared_layers_prefix):
+            return "ernie." + name[len(shared_layers_prefix) :]
+        layer_idx, stem = name.split(".", maxsplit=1)
+        if stem == "weight":
+            return unnamed_layers.pop(0)
+        if stem.startswith("mtp"):
+            return f"ernie.{stem}"
+        return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
+
+    def try_torch_format(weight_key):
+        if weight_key.startswith("ernie."):
+            weight_key = "model." + weight_key[6:]
+
+        key_decompose = [weight_key]
+        if ".up_gate_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".up_gate_proj.", ".gate_proj."),
+                weight_key.replace(".up_gate_proj.", ".up_proj."),
+            ]
+        elif ".qkv_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".qkv_proj.", ".q_proj."),
+                weight_key.replace(".qkv_proj.", ".k_proj."),
+                weight_key.replace(".qkv_proj.", ".v_proj."),
+            ]
+
+        tensor_decompose = []
+        for key in key_decompose:
+            if not (weight_file := weight_map.get(key)):
+                return None
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="paddle",
+                device="cpu",
+            ) as f:
+                tensor = f.get_tensor(key)
+            if "_proj." in key or ".gate." in key:
+                tensor = tensor.T.contiguous()
+            tensor_decompose.append(tensor)
+
+        if len(tensor_decompose) == 1:
+            return tensor_decompose[0]
+        else:
+            return paddle.concat(tensor_decompose, axis=-1)
+
+    def auto_fix_shape(param, weight):
+        assert len(param.shape) == len(weight.shape), "rank not match"
+        if (
+            len(param.shape) == 2
+            and param.shape[0] == weight.shape[1]
+            and param.shape[1] == weight.shape[0]
+        ):
+            return weight.T.contiguous()
+        assert all(
+            p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
+        ), "weight too small"
+        indices = tuple(slice(0, dim) for dim in param.shape)
+        return weight[indices].contiguous()
+
+    for name, param in model.named_parameters():
+        weight_key = param_to_weight(name)
+        if weight_file := weight_map.get(weight_key):
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="paddle",
+            ) as f:
+                weight = f.get_tensor(weight_key)
+        elif (weight := try_torch_format(weight_key)) is None:
+            logger.warning(
+                f"param `{name}`'s weight `{weight_key}` not found. "
+                "Skip initializing."
+            )
+            continue
+        if param.shape != weight.shape:
+            logger.warning(
+                f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
+                f"{param.shape} and {weight.shape}. Auto fixing."
+            )
+            weight = auto_fix_shape(param, weight)
+        param.copy_(weight)
+
+
 def main():
     if set_affinity is not None:
         set_affinity_code = set_affinity()
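As an aside, the expert-index substitution in `param_to_weight` above maps a rank-local expert id to its global id in the checkpoint under expert parallelism; a small self-contained illustration of that remapping, where the parameter name and offset are made up:

import re

# Illustration only: with expert parallelism, each rank owns a contiguous block
# of experts, so local expert i corresponds to global expert expert_offset + i
# in the checkpoint. This mirrors the regex logic used in param_to_weight.
def remap_expert(name, expert_offset):
    if m := re.search(r"mlp\.experts\.(\d+)", name):
        expert_id = expert_offset + int(m.group(1))
        s, e = m.span()
        name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
    return name

# e.g. 64 experts split over 8 EP ranks -> 8 experts per rank; rank 2 starts at offset 16
print(remap_expert("ernie.layers.3.mlp.experts.5.up_gate_proj.weight", expert_offset=16))
# -> ernie.layers.3.mlp.experts.21.up_gate_proj.weight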
@@ -520,21 +634,12 @@ def sname_to_tname(pp_model):
         cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
         register_pp_reshard_information(cfg.num_hidden_layers)
 
-        if args.from_scratch:
-            model = ErnieMoEForCausalLMPipe(cfg)
-        else:
-            model = ErnieMoEForCausalLMPipe.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLMPipe(cfg)
     else:
-        if args.from_scratch:
-            model = ErnieMoEForCausalLM(cfg)
-        else:
-            model = ErnieMoEForCausalLM.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLM(cfg)
+
+    if not args.from_scratch:
+        load_huggingface_checkpoint(model, args)
 
     cfg = model.config
     logger.info(f"using model type:{type(model)}")
