63 | 63 |
64 | 64 | from config import get_config |
65 | 65 |
| 66 | +from safetensors import safe_open |
| 67 | + |
66 | 68 | try: |
67 | 69 | from paddleformers.trainer.trainer_utils import log_trainer_start |
68 | 70 | except ImportError: |
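
The new `safe_open` import gives lazy, per-tensor access to sharded safetensors checkpoints, so each rank reads only the weights it actually owns instead of materializing whole shards. A minimal sketch of the access pattern (shard and tensor names are hypothetical):

```python
from safetensors import safe_open

# Hypothetical shard file and tensor name, for illustration only.
with safe_open("model-00001-of-00004.safetensors", framework="paddle", device="cpu") as f:
    print(f.keys())                                # tensor names stored in this shard
    w = f.get_tensor("model.embed_tokens.weight")  # reads just this one tensor
```
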
@@ -164,6 +166,118 @@ def _collate_data(data, stack_fn=Stack()): |
164 | 166 | return train_dataset, valid_dataset, test_dataset, _collate_data |
165 | 167 |
166 | 168 |
| 169 | +def load_huggingface_checkpoint(model, args): |
| 170 | + fused_rms_norm_replace = [ |
| 171 | + ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"), |
| 172 | + ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"), |
| 173 | + ] |
| 174 | +    shared_layers_prefix = "shared_layers.embed_weight_share."  # PP shared-embedding params
| 175 | +    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]  # popped in parameter order
| 176 | + |
| 177 | + logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}") |
| 178 | + with open( |
| 179 | + os.path.join(args.model_name_or_path, "model.safetensors.index.json") |
| 180 | + ) as f: |
| 181 | + weight_map = json.load(f)["weight_map"] |
| 182 | + |
| 183 | + ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size() |
| 184 | + ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank() |
| 185 | +    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank  # first expert id owned by this EP rank
| 186 | + |
| 187 | + def param_to_weight(name): |
| 188 | + # for PP=1, we only need to substitute the fused_rms_norm and expert_id |
| 189 | + for src, dst in fused_rms_norm_replace: |
| 190 | + name = name.replace(src, dst) |
| 191 | + if m := re.search(r"mlp\.experts\.(\d+)", name): |
| 192 | + expert_id = expert_offset + int(m.group(1)) |
| 193 | + s, e = m.span() |
| 194 | + name = name[:s] + f"mlp.experts.{expert_id}" + name[e:] |
| 195 | + if isinstance(model, ErnieMoEForCausalLM): |
| 196 | + return name |
| 197 | + |
| 198 | + # for PP>1, we also need to handle special layers and adjust layer_idx |
| 199 | + if name.startswith(shared_layers_prefix): |
| 200 | + return "ernie." + name[len(shared_layers_prefix) :] |
| 201 | + layer_idx, stem = name.split(".", maxsplit=1) |
| 202 | + if stem == "weight": |
| 203 | +            return unnamed_layers.pop(0)  # final norm first, then lm_head
| 204 | + if stem.startswith("mtp"): |
| 205 | + return f"ernie.{stem}" |
| 206 | + return f"ernie.layers.{int(layer_idx) - 1}.{stem}" |
| 207 | + |
| 208 | + def try_torch_format(weight_key): |
| 209 | + if weight_key.startswith("ernie."): |
| 210 | + weight_key = "model." + weight_key[6:] |
| 211 | + |
| 212 | + key_decompose = [weight_key] |
| 213 | + if ".up_gate_proj." in weight_key: |
| 214 | + key_decompose = [ |
| 215 | + weight_key.replace(".up_gate_proj.", ".gate_proj."), |
| 216 | + weight_key.replace(".up_gate_proj.", ".up_proj."), |
| 217 | + ] |
| 218 | + elif ".qkv_proj." in weight_key: |
| 219 | + key_decompose = [ |
| 220 | + weight_key.replace(".qkv_proj.", ".q_proj."), |
| 221 | + weight_key.replace(".qkv_proj.", ".k_proj."), |
| 222 | + weight_key.replace(".qkv_proj.", ".v_proj."), |
| 223 | + ] |
| 224 | + |
| 225 | + tensor_decompose = [] |
| 226 | + for key in key_decompose: |
| 227 | + if not (weight_file := weight_map.get(key)): |
| 228 | + return None |
| 229 | + with safe_open( |
| 230 | + os.path.join(args.model_name_or_path, weight_file), |
| 231 | + framework="paddle", |
| 232 | + device="cpu", |
| 233 | + ) as f: |
| 234 | + tensor = f.get_tensor(key) |
| 235 | +                if "_proj." in key or ".gate." in key:  # torch Linear is [out, in]; Paddle wants [in, out]
| 236 | + tensor = tensor.T.contiguous() |
| 237 | + tensor_decompose.append(tensor) |
| 238 | + |
| 239 | + if len(tensor_decompose) == 1: |
| 240 | + return tensor_decompose[0] |
| 241 | + else: |
| 242 | + return paddle.concat(tensor_decompose, axis=-1) |
| 243 | + |
| 244 | + def auto_fix_shape(param, weight): |
| 245 | +        assert len(param.shape) == len(weight.shape), "rank mismatch"
| 246 | + if ( |
| 247 | + len(param.shape) == 2 |
| 248 | + and param.shape[0] == weight.shape[1] |
| 249 | + and param.shape[1] == weight.shape[0] |
| 250 | + ): |
| 251 | + return weight.T.contiguous() |
| 252 | + assert all( |
| 253 | + p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape) |
| 254 | + ), "weight too small" |
| 255 | +        indices = tuple(slice(0, dim) for dim in param.shape)  # trim weight down to the param's shape
| 256 | + return weight[indices].contiguous() |
| 257 | + |
| 258 | + for name, param in model.named_parameters(): |
| 259 | + weight_key = param_to_weight(name) |
| 260 | + if weight_file := weight_map.get(weight_key): |
| 261 | + with safe_open( |
| 262 | + os.path.join(args.model_name_or_path, weight_file), |
| 263 | + framework="paddle", |
| 264 | + ) as f: |
| 265 | + weight = f.get_tensor(weight_key) |
| 266 | + elif (weight := try_torch_format(weight_key)) is None: |
| 267 | + logger.warning( |
| 268 | + f"param `{name}`'s weight `{weight_key}` not found. " |
| 269 | + "Skip initializing." |
| 270 | + ) |
| 271 | + continue |
| 272 | + if param.shape != weight.shape: |
| 273 | + logger.warning( |
| 274 | + f"param `{name}`'s shape doesn't match weight `{weight_key}`: " |
| 275 | + f"{param.shape} and {weight.shape}. Auto fixing." |
| 276 | + ) |
| 277 | + weight = auto_fix_shape(param, weight) |
| 278 | + param.copy_(weight) |
| 279 | + |
| 280 | + |
167 | 281 | def main(): |
168 | 282 | if set_affinity is not None: |
169 | 283 | set_affinity_code = set_affinity() |
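
`try_torch_format` handles checkpoints exported by the torch implementation: `ernie.*` keys become `model.*` keys, fused projections are rebuilt from their unfused shards, and linear weights are transposed because torch `nn.Linear` stores `[out_features, in_features]` while Paddle Linear stores `[in_features, out_features]`. A toy sketch of the qkv reassembly, with made-up shapes:

```python
import paddle

# Toy dimensions, not the real model's.
hidden, head_dim = 8, 4

# Torch-format shards store Linear weights as [out, in].
q = paddle.randn([head_dim, hidden])
k = paddle.randn([head_dim, hidden])
v = paddle.randn([head_dim, hidden])

# Transpose each shard to Paddle's [in, out] layout, then fuse along
# the output axis, mirroring the ".qkv_proj." branch above.
qkv = paddle.concat([t.T for t in (q, k, v)], axis=-1)
assert qkv.shape == [hidden, 3 * head_dim]
```

The `.up_gate_proj.` branch works the same way, concatenating the transposed `gate_proj` and `up_proj` shards along the output axis.
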
@@ -482,21 +596,12 @@ def sname_to_tname(pp_model): |
482 | 596 | cfg.enable_delay_scale_loss = args.enable_delay_scale_loss |
483 | 597 | register_pp_reshard_information(cfg.num_hidden_layers) |
484 | 598 |
485 | | - if args.from_scratch: |
486 | | - model = ErnieMoEForCausalLMPipe(cfg) |
487 | | - else: |
488 | | - model = ErnieMoEForCausalLMPipe.from_pretrained( |
489 | | - args.model_name_or_path, |
490 | | - config=cfg, |
491 | | - ) |
| 599 | + model = ErnieMoEForCausalLMPipe(cfg) |
492 | 600 | else: |
493 | | - if args.from_scratch: |
494 | | - model = ErnieMoEForCausalLM(cfg) |
495 | | - else: |
496 | | - model = ErnieMoEForCausalLM.from_pretrained( |
497 | | - args.model_name_or_path, |
498 | | - config=cfg, |
499 | | - ) |
| 601 | + model = ErnieMoEForCausalLM(cfg) |
| 602 | + |
| 603 | + if not args.from_scratch: |
| 604 | + load_huggingface_checkpoint(model, args) |
500 | 605 |
501 | 606 | cfg = model.config |
502 | 607 | logger.info(f"using model type:{type(model)}") |
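
With the loader in place, both branches now build the model directly from `cfg` and, when not training from scratch, initialize it through the single `load_huggingface_checkpoint` path instead of two separate `from_pretrained` calls. One detail worth spelling out is the expert-id remapping under expert parallelism: each EP rank owns a contiguous slice of the experts, so local expert `i` on rank `r` corresponds to checkpoint expert `r * (num_experts // ep_degree) + i`. A self-contained sketch with toy values (the parameter name is hypothetical):

```python
import re

# Toy EP setup: 64 experts split across 4 ranks; this is rank 2.
num_experts, ep_degree, ep_rank = 64, 4, 2
expert_offset = (num_experts // ep_degree) * ep_rank  # experts 32..47 live on this rank

name = "ernie.layers.3.mlp.experts.5.up_gate_proj.weight"  # local expert 5
remapped = re.sub(
    r"mlp\.experts\.(\d+)",
    lambda m: f"mlp.experts.{expert_offset + int(m.group(1))}",
    name,
)
assert remapped == "ernie.layers.3.mlp.experts.37.up_gate_proj.weight"
```
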