Skip to content

Commit b35dcd9

Browse files
committed
Implement huggingface checkpoint loading
1 parent 850fd92 commit b35dcd9

File tree

1 file changed

+119
-14
lines changed

1 file changed

+119
-14
lines changed

examples/pre-training/ernie/pretrain.py

Lines changed: 119 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363

6464
from config import get_config
6565

66+
from safetensors import safe_open
67+
6668
try:
6769
from paddleformers.trainer.trainer_utils import log_trainer_start
6870
except ImportError:
@@ -164,6 +166,118 @@ def _collate_data(data, stack_fn=Stack()):
164166
return train_dataset, valid_dataset, test_dataset, _collate_data
165167

166168

169+
def load_huggingface_checkpoint(model, args):
170+
fused_rms_norm_replace = [
171+
("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
172+
("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
173+
]
174+
shared_layers_prefix = "shared_layers.embed_weight_share."
175+
unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
176+
177+
logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
178+
with open(
179+
os.path.join(args.model_name_or_path, "model.safetensors.index.json")
180+
) as f:
181+
weight_map = json.load(f)["weight_map"]
182+
183+
ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
184+
ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
185+
expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
186+
187+
def param_to_weight(name):
188+
# for PP=1, we only need to substitute the fused_rms_norm and expert_id
189+
for src, dst in fused_rms_norm_replace:
190+
name = name.replace(src, dst)
191+
if m := re.search(r"mlp\.experts\.(\d+)", name):
192+
expert_id = expert_offset + int(m.group(1))
193+
s, e = m.span()
194+
name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
195+
if isinstance(model, ErnieMoEForCausalLM):
196+
return name
197+
198+
# for PP>1, we also need to handle special layers and adjust layer_idx
199+
if name.startswith(shared_layers_prefix):
200+
return "ernie." + name[len(shared_layers_prefix) :]
201+
layer_idx, stem = name.split(".", maxsplit=1)
202+
if stem == "weight":
203+
return unnamed_layers.pop(0)
204+
if stem.startswith("mtp"):
205+
return f"ernie.{stem}"
206+
return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
207+
208+
def try_torch_format(weight_key):
209+
if weight_key.startswith("ernie."):
210+
weight_key = "model." + weight_key[6:]
211+
212+
key_decompose = [weight_key]
213+
if ".up_gate_proj." in weight_key:
214+
key_decompose = [
215+
weight_key.replace(".up_gate_proj.", ".gate_proj."),
216+
weight_key.replace(".up_gate_proj.", ".up_proj."),
217+
]
218+
elif ".qkv_proj." in weight_key:
219+
key_decompose = [
220+
weight_key.replace(".qkv_proj.", ".q_proj."),
221+
weight_key.replace(".qkv_proj.", ".k_proj."),
222+
weight_key.replace(".qkv_proj.", ".v_proj."),
223+
]
224+
225+
tensor_decompose = []
226+
for key in key_decompose:
227+
if not (weight_file := weight_map.get(key)):
228+
return None
229+
with safe_open(
230+
os.path.join(args.model_name_or_path, weight_file),
231+
framework="paddle",
232+
device="cpu",
233+
) as f:
234+
tensor = f.get_tensor(key)
235+
if "_proj." in key or ".gate." in key:
236+
tensor = tensor.T.contiguous()
237+
tensor_decompose.append(tensor)
238+
239+
if len(tensor_decompose) == 1:
240+
return tensor_decompose[0]
241+
else:
242+
return paddle.concat(tensor_decompose, axis=-1)
243+
244+
def auto_fix_shape(param, weight):
245+
assert len(param.shape) == len(weight.shape), "rank not match"
246+
if (
247+
len(param.shape) == 2
248+
and param.shape[0] == weight.shape[1]
249+
and param.shape[1] == weight.shape[0]
250+
):
251+
return weight.T.contiguous()
252+
assert all(
253+
p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
254+
), "weight too small"
255+
indices = tuple(slice(0, dim) for dim in param.shape)
256+
return weight[indices].contiguous()
257+
258+
for name, param in model.named_parameters():
259+
weight_key = param_to_weight(name)
260+
if weight_file := weight_map.get(weight_key):
261+
with safe_open(
262+
os.path.join(args.model_name_or_path, weight_file),
263+
framework="paddle",
264+
) as f:
265+
weight = f.get_tensor(weight_key)
266+
elif (weight := try_torch_format(weight_key)) is None:
267+
logger.warning(
268+
f"param `{name}`'s weight `{weight_key}` not found. "
269+
"Skip initializing."
270+
)
271+
continue
272+
if param.shape != weight.shape:
273+
logger.warning(
274+
f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
275+
f"{param.shape} and {weight.shape}. Auto fixing."
276+
)
277+
weight = auto_fix_shape(param, weight)
278+
param.copy_(weight)
279+
280+
167281
def main():
168282
if set_affinity is not None:
169283
set_affinity_code = set_affinity()
@@ -482,21 +596,12 @@ def sname_to_tname(pp_model):
482596
cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
483597
register_pp_reshard_information(cfg.num_hidden_layers)
484598

485-
if args.from_scratch:
486-
model = ErnieMoEForCausalLMPipe(cfg)
487-
else:
488-
model = ErnieMoEForCausalLMPipe.from_pretrained(
489-
args.model_name_or_path,
490-
config=cfg,
491-
)
599+
model = ErnieMoEForCausalLMPipe(cfg)
492600
else:
493-
if args.from_scratch:
494-
model = ErnieMoEForCausalLM(cfg)
495-
else:
496-
model = ErnieMoEForCausalLM.from_pretrained(
497-
args.model_name_or_path,
498-
config=cfg,
499-
)
601+
model = ErnieMoEForCausalLM(cfg)
602+
603+
if not args.from_scratch:
604+
load_huggingface_checkpoint(model, args)
500605

501606
cfg = model.config
502607
logger.info(f"using model type:{type(model)}")

0 commit comments

Comments
 (0)