From 84875cb5b948e008ee1e4a7d536e8202893e3cca Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 14 Aug 2025 07:57:50 +0000 Subject: [PATCH 1/6] load hf ckpt --- paddlenlp/trainer/trainer.py | 211 +++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 514402ba3ec7..13ea3927f865 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -93,6 +93,10 @@ ) except: pass +from collections import defaultdict + +from safetensors import safe_open + from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance from ..transformers.model_utils import ( PretrainedModel, @@ -212,6 +216,120 @@ def in_auto_parallel_align_mode(): __all__ = ["Trainer"] +# 预编译正则表达式提升性能 +_LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$") +_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") +_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$") + +# 保持原有映射关系 +custom_name_map = { + "self_attn.fused_rms_norm_linear.rms_norm_weight": "input_layernorm.weight", + "self_attn.memory_recompute_att.kv_ln_weigh": "self_attn.kv_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.kv_down_weight": "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.memory_recompute_att.kv_up_weight": "self_attn.kv_b_proj.weight", + "self_attn.memory_recompute_att.q_ln_weight": "self_attn.q_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.q_down_weight": "self_attn.q_a_proj.weight", + "self_attn.memory_recompute_att.q_up_weight": "self_attn.q_b_proj.weight", +} + + +def paddle_name_to_hf_names(paddle_name: str) -> List[str]: + """ + 将Paddle模型参数名称转换为Hugging Face格式的名称列表 + + 参数: + paddle_name: Paddle格式的参数名称 + + 返回: + Hugging Face格式的参数名称列表(可能拆分多个参数) + """ + # 基础路径解析 + m = _LAYER_RE.match(paddle_name) + if not m: + return [] + + segment_id = int(m.group(1)) + id_in_segment = int(m.group(2)) + rest = m.group(3) or "" + + # 1. 生成HF前缀 + hf_prefix = _get_hf_prefix(segment_id, id_in_segment) + + # 2. 处理子路径转换 + if rest in custom_name_map: + return [f"{hf_prefix}.{custom_name_map[rest]}"] + + if expert_names := _handle_expert_weights(hf_prefix, rest): + return expert_names + + if mlp_names := _handle_mlp_weights(hf_prefix, rest): + return mlp_names + + # 3. 默认处理 + return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix] + + +def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str: + """生成Hugging Face格式的层级前缀""" + # 特殊层级映射 + special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (28, 3): "model", (28, 4): "lm_head"} + + if (segment_id, id_in_segment) in special_cases: + return special_cases[(segment_id, id_in_segment)] + + # 通用层级计算 + layer_idx = segment_id + id_in_segment - 1 + return f"model.layers.{layer_idx}" + + +def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + """处理专家网络权重拆分""" + # 处理专家w1权重(拆分为gate_proj和up_proj) + if m := _EXPERT_W1_RE.match(rest): + expert_id = int(m.group(1)) + return [ + f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight", + f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight", + ] + + # 处理专家w2权重(映射为down_proj) + if m := _EXPERT_W2_RE.match(rest): + expert_id = int(m.group(1)) + return [f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight"] + + return None + + +def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + """处理普通MLP权重拆分""" + if rest == "mlp.w1": + return [f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"] + + if rest == "mlp.w2": + return [f"{hf_prefix}.mlp.down_proj.weight"] + + return None + + +def prepare_tensor(tensor, dst_shape): + if isinstance(tensor, list): + return paddle.concat( + [ + paddle.transpose(tensor[0], perm=[1, 0]).contiguous(), + paddle.transpose(tensor[1], perm=[1, 0]).contiguous(), + ], + axis=-1, + ) + if tensor.shape == dst_shape: + return tensor + if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape: + return paddle.transpose(tensor, perm=[1, 0]).contiguous() + if len(tensor.shape) == 1: + return tensor[0 : dst_shape[0]] + if len(tensor.shape) == 2: + return paddle.transpose(tensor, perm=[1, 0]).contiguous()[:, 0 : dst_shape[1]] + + class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PaddlePaddle, optimized for PaddleNLP. @@ -1009,6 +1127,99 @@ def _inner_training_loop( if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() + print("================================== load safe tensor ==================================") + print("---- paddle param ----") + if self.state.global_step == 0: + for n, p in model.named_parameters(): + print("{}:{}".format(n, p.shape)) + + # 1. 加载参数-文件映射表 + weight_map_path = "/root/paddlejob/workspace/env_run/zhangbo/model.safetensors.index.json" + with open(weight_map_path, "r") as f: + weight_map = json.load(f)["weight_map"] + print("weight_map: ", weight_map) + + # 2. 创建反向索引:文件 -> 参数列表 + file_to_params = defaultdict(list) + for param_name, filename in weight_map.items(): + file_to_params[filename].append(param_name) + + # 2. 收集模型需要的文件列表 + required_files = set() + file_to_pd_param_name = defaultdict(list) + pd_param_name_to_file = defaultdict(list) + for pd_name, _ in model.named_parameters(): + hf_name = paddle_name_to_hf_names(pd_name) + print("pd_name: ", pd_name) + print("hf_name: ", hf_name) + if hf_name[0] in weight_map: + filename = weight_map[hf_name[0]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + pd_param_name_to_file[pd_name].append(filename) + if len(hf_name) > 1 and hf_name[1] in weight_map: + filename = weight_map[hf_name[1]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + if filename != pd_param_name_to_file[pd_name][0]: + pd_param_name_to_file[pd_name].append(filename) + else: + print(f"Warning: {pd_name} not found in weight map") + print("---- required_files ----") + print(required_files) + print("---- file_to_pd_param_name ----") + print(file_to_pd_param_name) + print("---- pd_param_name_to_file ----") + print(pd_param_name_to_file) + + # 3. 按文件分组加载 + ckpt_pre = "/root/paddlejob/new_disk/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/" + check_list = [] + print("---- start load param ----") + for filename in required_files: + try: + with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: + print("open file: ", ckpt_pre + filename) + # 加载该文件包含的所有参数 + pd_params = file_to_pd_param_name[filename] + print("load for params: ", pd_params) + for pd_param in pd_params: + if pd_param in check_list: + continue + hf_name = paddle_name_to_hf_names(pd_param) + if len(hf_name) == 1: + tensor = f.get_tensor(hf_name[0]) + model.state_dict()[pd_param].set_value( + paddle.cast( + prepare_tensor(tensor, model.state_dict()[pd_param].shape), + model.state_dict()[pd_param].dtype, + ) + ) + else: + files = pd_param_name_to_file[pd_name] + if len(files) == 1: + tensor0 = f.get_tensor(hf_name[0]) + tensor1 = f.get_tensor(hf_name[1]) + else: + if weight_map[hf_name[0]] == filename: + tensor0 = f.get_tensor(hf_name[0]) + with safe_open( + ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu" + ) as f_other: + tensor1 = f_other.get_tensor(hf_name[1]) + else: + with safe_open( + ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu" + ) as f_other: + tensor0 = f_other.get_tensor(hf_name[1]) + tensor1 = f.get_tensor(hf_name[1]) + model.state_dict()[pd_param].set_value(prepare_tensor([tensor0, tensor1], None)) + check_list.append(pd_param) + + except Exception as e: + print(f"Error loading {filename}: {str(e)}") + raise + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( train_dataloader.batch_sampler, DistributedBatchSampler From 76cb10b1484d81cb1e4706c3555297232db3db67 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 14 Aug 2025 09:27:56 +0000 Subject: [PATCH 2/6] load hf ckpt --- paddlenlp/trainer/trainer.py | 37 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 13ea3927f865..315701ba1107 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -93,10 +93,6 @@ ) except: pass -from collections import defaultdict - -from safetensors import safe_open - from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance from ..transformers.model_utils import ( PretrainedModel, @@ -216,6 +212,10 @@ def in_auto_parallel_align_mode(): __all__ = ["Trainer"] +from collections import defaultdict + +from safetensors import safe_open + # 预编译正则表达式提升性能 _LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$") _EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") @@ -1137,7 +1137,6 @@ def _inner_training_loop( weight_map_path = "/root/paddlejob/workspace/env_run/zhangbo/model.safetensors.index.json" with open(weight_map_path, "r") as f: weight_map = json.load(f)["weight_map"] - print("weight_map: ", weight_map) # 2. 创建反向索引:文件 -> 参数列表 file_to_params = defaultdict(list) @@ -1150,27 +1149,23 @@ def _inner_training_loop( pd_param_name_to_file = defaultdict(list) for pd_name, _ in model.named_parameters(): hf_name = paddle_name_to_hf_names(pd_name) - print("pd_name: ", pd_name) - print("hf_name: ", hf_name) if hf_name[0] in weight_map: filename = weight_map[hf_name[0]] required_files.add(filename) file_to_pd_param_name[filename].append(pd_name) pd_param_name_to_file[pd_name].append(filename) - if len(hf_name) > 1 and hf_name[1] in weight_map: - filename = weight_map[hf_name[1]] - required_files.add(filename) - file_to_pd_param_name[filename].append(pd_name) - if filename != pd_param_name_to_file[pd_name][0]: - pd_param_name_to_file[pd_name].append(filename) else: - print(f"Warning: {pd_name} not found in weight map") - print("---- required_files ----") - print(required_files) - print("---- file_to_pd_param_name ----") - print(file_to_pd_param_name) - print("---- pd_param_name_to_file ----") - print(pd_param_name_to_file) + print(f"Warning: {pd_name} -> {hf_name[0]} not found in weight map") + + if len(hf_name) > 1: + if hf_name[1] in weight_map: + filename = weight_map[hf_name[1]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + if filename != pd_param_name_to_file[pd_name][0]: + pd_param_name_to_file[pd_name].append(filename) + else: + print(f"Warning: {pd_name} -> {hf_name[1]} not found in weight map") # 3. 按文件分组加载 ckpt_pre = "/root/paddlejob/new_disk/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/" @@ -1182,10 +1177,10 @@ def _inner_training_loop( print("open file: ", ckpt_pre + filename) # 加载该文件包含的所有参数 pd_params = file_to_pd_param_name[filename] - print("load for params: ", pd_params) for pd_param in pd_params: if pd_param in check_list: continue + print("load for pd_param: ", pd_param) hf_name = paddle_name_to_hf_names(pd_param) if len(hf_name) == 1: tensor = f.get_tensor(hf_name[0]) From 79c2d38588879b0db96d2ebfd43120e06bd8cb99 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 14 Aug 2025 11:08:44 +0000 Subject: [PATCH 3/6] fix --- paddlenlp/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 315701ba1107..019ad2face24 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1206,7 +1206,7 @@ def _inner_training_loop( with safe_open( ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu" ) as f_other: - tensor0 = f_other.get_tensor(hf_name[1]) + tensor0 = f_other.get_tensor(hf_name[0]) tensor1 = f.get_tensor(hf_name[1]) model.state_dict()[pd_param].set_value(prepare_tensor([tensor0, tensor1], None)) check_list.append(pd_param) From 8a148085ff202ab6f43b19dd5d506897273b3f8b Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 14 Aug 2025 11:51:13 +0000 Subject: [PATCH 4/6] fix --- paddlenlp/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 019ad2face24..3b18692181c7 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1191,7 +1191,7 @@ def _inner_training_loop( ) ) else: - files = pd_param_name_to_file[pd_name] + files = pd_param_name_to_file[pd_param] if len(files) == 1: tensor0 = f.get_tensor(hf_name[0]) tensor1 = f.get_tensor(hf_name[1]) From a4bf4a3354ec607716af0ca867b0c704f59c4c02 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 14 Aug 2025 12:04:07 +0000 Subject: [PATCH 5/6] fix --- paddlenlp/trainer/trainer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 3b18692181c7..2db095d3cab7 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -283,7 +283,6 @@ def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str: def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: - """处理专家网络权重拆分""" # 处理专家w1权重(拆分为gate_proj和up_proj) if m := _EXPERT_W1_RE.match(rest): expert_id = int(m.group(1)) @@ -301,7 +300,6 @@ def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: - """处理普通MLP权重拆分""" if rest == "mlp.w1": return [f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"] @@ -1127,12 +1125,6 @@ def _inner_training_loop( if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() - print("================================== load safe tensor ==================================") - print("---- paddle param ----") - if self.state.global_step == 0: - for n, p in model.named_parameters(): - print("{}:{}".format(n, p.shape)) - # 1. 加载参数-文件映射表 weight_map_path = "/root/paddlejob/workspace/env_run/zhangbo/model.safetensors.index.json" with open(weight_map_path, "r") as f: @@ -1174,7 +1166,6 @@ def _inner_training_loop( for filename in required_files: try: with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: - print("open file: ", ckpt_pre + filename) # 加载该文件包含的所有参数 pd_params = file_to_pd_param_name[filename] for pd_param in pd_params: From 04840dad1ed1d4f7c4429e4b4f9b729cc00d944b Mon Sep 17 00:00:00 2001 From: phlrain <--global> Date: Sat, 16 Aug 2025 14:45:59 +0800 Subject: [PATCH 6/6] fix model load bug --- paddlenlp/trainer/trainer.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 2db095d3cab7..223de77b91ce 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -220,11 +220,13 @@ def in_auto_parallel_align_mode(): _LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$") _EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") _EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$") +_SHARE_EXPERT_W1_RE = re.compile(r"^mlp\.shared_experts\.w1(?:\.weight)?$") +_SHARE_EXPERT_W2_RE = re.compile(r"^mlp\.shared_experts\.w2(?:\.weight)?$") # 保持原有映射关系 custom_name_map = { "self_attn.fused_rms_norm_linear.rms_norm_weight": "input_layernorm.weight", - "self_attn.memory_recompute_att.kv_ln_weigh": "self_attn.kv_a_layernorm.weight", + "self_attn.memory_recompute_att.kv_ln_weight": "self_attn.kv_a_layernorm.weight", "self_attn.fused_rms_norm_linear.kv_down_weight": "self_attn.kv_a_proj_with_mqa.weight", "self_attn.memory_recompute_att.kv_up_weight": "self_attn.kv_b_proj.weight", "self_attn.memory_recompute_att.q_ln_weight": "self_attn.q_a_layernorm.weight", @@ -244,6 +246,8 @@ def paddle_name_to_hf_names(paddle_name: str) -> List[str]: Hugging Face格式的参数名称列表(可能拆分多个参数) """ # 基础路径解析 + if paddle_name == "_layers.local_shared_layers.DeepseekV2_shared_weight.embed_tokens.weight": + return ["model.embed_tokens.weight"] m = _LAYER_RE.match(paddle_name) if not m: return [] @@ -261,6 +265,8 @@ def paddle_name_to_hf_names(paddle_name: str) -> List[str]: if expert_names := _handle_expert_weights(hf_prefix, rest): return expert_names + if shared_mlp_names := _handle_shared_expert_weights(hf_prefix, rest): + return shared_mlp_names if mlp_names := _handle_mlp_weights(hf_prefix, rest): return mlp_names @@ -298,6 +304,19 @@ def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: return None +def _handle_shared_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + # 处理专家w1权重(拆分为gate_proj和up_proj) + if m := _SHARE_EXPERT_W1_RE.match(rest): + return [ + f"{hf_prefix}.mlp.shared_experts.gate_proj.weight", + f"{hf_prefix}.mlp.shared_experts.up_proj.weight", + ] + + # 处理专家w2权重(映射为down_proj) + if m := _SHARE_EXPERT_W2_RE.match(rest): + return [f"{hf_prefix}.mlp.shared_experts.down_proj.weight"] + + return None def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: if rest == "mlp.w1": @@ -1148,6 +1167,9 @@ def _inner_training_loop( pd_param_name_to_file[pd_name].append(filename) else: print(f"Warning: {pd_name} -> {hf_name[0]} not found in weight map") + import sys + sys.exit() + if len(hf_name) > 1: if hf_name[1] in weight_map: