63 | 63 |
64 | 64 | from config import get_config |
65 | 65 |
| 66 | +from safetensors import safe_open |
| 67 | + |
66 | 68 | try: |
67 | 69 | from paddleformers.trainer.trainer_utils import log_trainer_start |
68 | 70 | except ImportError: |
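The new `safe_open` import enables shard-by-shard reads from a sharded Hugging Face checkpoint, which the loader added below relies on. A minimal sketch of that access pattern (directory and key are hypothetical):

```python
import json
import os

from safetensors import safe_open

ckpt_dir = "./hf-ckpt"  # hypothetical checkpoint directory
# the index file maps every weight name to the shard file that stores it
with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

key = "model.embed_tokens.weight"
# safe_open opens the shard lazily and reads single tensors on demand
with safe_open(os.path.join(ckpt_dir, weight_map[key]), framework="paddle", device="cpu") as f:
    tensor = f.get_tensor(key)
```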
@@ -202,6 +204,118 @@ def _collate_data(data, stack_fn=Stack()): |
202 | 204 | return train_dataset, valid_dataset, test_dataset, _collate_data |
203 | 205 |
204 | 206 |
| 207 | +def load_huggingface_checkpoint(model, args): |
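| | +    """Load a sharded HF safetensors checkpoint into the model in place, remapping names."""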
| 208 | + fused_rms_norm_replace = [ |
| 209 | + ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"), |
| 210 | + ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"), |
| 211 | + ] |
| 212 | + shared_layers_prefix = "shared_layers.embed_weight_share." |
| 213 | + unnamed_layers = ["ernie.norm.weight", "lm_head.weight"] |
| 214 | + |
| 215 | + logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}") |
| 216 | + with open( |
| 217 | + os.path.join(args.model_name_or_path, "model.safetensors.index.json") |
| 218 | + ) as f: |
| 219 | + weight_map = json.load(f)["weight_map"] |
| 220 | + |
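| | +    # under expert parallelism each rank holds moe_num_experts / ep_degree experts,
| | +    # so local expert i on this rank corresponds to global expert expert_offset + i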
| 221 | + ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size() |
| 222 | + ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank() |
| 223 | + expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank |
| 224 | + |
| 225 | + def param_to_weight(name): |
| 226 | +        # for PP=1, only the fused_rms_norm names and the expert ids need remapping
| 227 | + for src, dst in fused_rms_norm_replace: |
| 228 | + name = name.replace(src, dst) |
| 229 | + if m := re.search(r"mlp\.experts\.(\d+)", name): |
| 230 | + expert_id = expert_offset + int(m.group(1)) |
| 231 | + s, e = m.span() |
| 232 | + name = name[:s] + f"mlp.experts.{expert_id}" + name[e:] |
| 233 | + if isinstance(model, ErnieMoEForCausalLM): |
| 234 | + return name |
| 235 | + |
| 236 | + # for PP>1, we also need to handle special layers and adjust layer_idx |
| 237 | + if name.startswith(shared_layers_prefix): |
| 238 | + return "ernie." + name[len(shared_layers_prefix) :] |
| 239 | + layer_idx, stem = name.split(".", maxsplit=1) |
| 240 | + if stem == "weight": |
| 241 | + return unnamed_layers.pop(0) |
| 242 | + if stem.startswith("mtp"): |
| 243 | + return f"ernie.{stem}" |
| 244 | + return f"ernie.layers.{int(layer_idx) - 1}.{stem}" |
| 245 | + |
| 246 | + def try_torch_format(weight_key): |
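| | +        # fall back to torch-format keys: "model." prefix, with q/k/v (or gate/up)
| | +        # projections stored as separate tensors that must be re-fused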
| 247 | + if weight_key.startswith("ernie."): |
| 248 | + weight_key = "model." + weight_key[6:] |
| 249 | + |
| 250 | + key_decompose = [weight_key] |
| 251 | + if ".up_gate_proj." in weight_key: |
| 252 | + key_decompose = [ |
| 253 | + weight_key.replace(".up_gate_proj.", ".gate_proj."), |
| 254 | + weight_key.replace(".up_gate_proj.", ".up_proj."), |
| 255 | + ] |
| 256 | + elif ".qkv_proj." in weight_key: |
| 257 | + key_decompose = [ |
| 258 | + weight_key.replace(".qkv_proj.", ".q_proj."), |
| 259 | + weight_key.replace(".qkv_proj.", ".k_proj."), |
| 260 | + weight_key.replace(".qkv_proj.", ".v_proj."), |
| 261 | + ] |
| 262 | + |
| 263 | + tensor_decompose = [] |
| 264 | + for key in key_decompose: |
| 265 | + if not (weight_file := weight_map.get(key)): |
| 266 | + return None |
| 267 | + with safe_open( |
| 268 | + os.path.join(args.model_name_or_path, weight_file), |
| 269 | + framework="paddle", |
| 270 | + device="cpu", |
| 271 | + ) as f: |
| 272 | + tensor = f.get_tensor(key) |
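| | +                # torch stores Linear weights as [out, in]; paddle expects [in, out]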
| 273 | + if "_proj." in key or ".gate." in key: |
| 274 | + tensor = tensor.T.contiguous() |
| 275 | + tensor_decompose.append(tensor) |
| 276 | + |
| 277 | + if len(tensor_decompose) == 1: |
| 278 | + return tensor_decompose[0] |
| 279 | + else: |
| 280 | + return paddle.concat(tensor_decompose, axis=-1) |
| 281 | + |
| 282 | + def auto_fix_shape(param, weight): |
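| | +        # recover from two benign mismatches: a transposed 2-D weight, or a weight
| | +        # that is larger than the param along some dims (sliced down to fit)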
| 283 | + assert len(param.shape) == len(weight.shape), "rank not match" |
| 284 | + if ( |
| 285 | + len(param.shape) == 2 |
| 286 | + and param.shape[0] == weight.shape[1] |
| 287 | + and param.shape[1] == weight.shape[0] |
| 288 | + ): |
| 289 | + return weight.T.contiguous() |
| 290 | + assert all( |
| 291 | + p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape) |
| 292 | + ), "weight too small" |
| 293 | + indices = tuple(slice(0, dim) for dim in param.shape) |
| 294 | + return weight[indices].contiguous() |
| 295 | + |
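| | +    # copy weights param by param: try the paddle-format key first, then the
| | +    # torch-format fallback; params with no matching weight keep their random init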
| 296 | + for name, param in model.named_parameters(): |
| 297 | + weight_key = param_to_weight(name) |
| 298 | + if weight_file := weight_map.get(weight_key): |
| 299 | + with safe_open( |
| 300 | + os.path.join(args.model_name_or_path, weight_file), |
| 301 | + framework="paddle", |
| 302 | + ) as f: |
| 303 | + weight = f.get_tensor(weight_key) |
| 304 | + elif (weight := try_torch_format(weight_key)) is None: |
| 305 | + logger.warning( |
| 306 | + f"param `{name}`'s weight `{weight_key}` not found. " |
| 307 | +                "Skipping initialization."
| 308 | + ) |
| 309 | + continue |
| 310 | + if param.shape != weight.shape: |
| 311 | + logger.warning( |
| 312 | + f"param `{name}`'s shape doesn't match weight `{weight_key}`: " |
| 313 | +                f"{param.shape} vs {weight.shape}. Auto-fixing."
| 314 | + ) |
| 315 | + weight = auto_fix_shape(param, weight) |
| 316 | + param.copy_(weight) |
| 317 | + |
| 318 | + |
205 | 319 | def main(): |
206 | 320 | if set_affinity is not None: |
207 | 321 | set_affinity_code = set_affinity() |
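For intuition, the fused-projection path in `try_torch_format` amounts to transposing each torch-layout piece into paddle's `[in, out]` layout and concatenating along the output dimension. A toy sketch with made-up shapes, not part of the patch:

```python
import paddle

hidden = 4
# torch stores each Linear weight as [out_features, in_features]
q, k, v = (paddle.rand([hidden, hidden]) for _ in range(3))

# transpose to paddle's [in, out] layout, then fuse along the output dim,
# mirroring `paddle.concat(tensor_decompose, axis=-1)` in the patch above
qkv = paddle.concat([t.T.contiguous() for t in (q, k, v)], axis=-1)
assert qkv.shape == [hidden, 3 * hidden]
```

The same concat rule covers `up_gate_proj`, which fuses two pieces instead of three.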
@@ -520,21 +634,12 @@ def sname_to_tname(pp_model): |
520 | 634 | cfg.enable_delay_scale_loss = args.enable_delay_scale_loss |
521 | 635 | register_pp_reshard_information(cfg.num_hidden_layers) |
522 | 636 |
523 | | - if args.from_scratch: |
524 | | - model = ErnieMoEForCausalLMPipe(cfg) |
525 | | - else: |
526 | | - model = ErnieMoEForCausalLMPipe.from_pretrained( |
527 | | - args.model_name_or_path, |
528 | | - config=cfg, |
529 | | - ) |
| 637 | + model = ErnieMoEForCausalLMPipe(cfg) |
530 | 638 | else: |
531 | | - if args.from_scratch: |
532 | | - model = ErnieMoEForCausalLM(cfg) |
533 | | - else: |
534 | | - model = ErnieMoEForCausalLM.from_pretrained( |
535 | | - args.model_name_or_path, |
536 | | - config=cfg, |
537 | | - ) |
| 639 | + model = ErnieMoEForCausalLM(cfg) |
| 640 | + |
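| | +    # model construction and weight loading are decoupled: build the model first,
| | +    # then overwrite its random init from the checkpoint unless from_scratch is set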
| 641 | + if not args.from_scratch: |
| 642 | + load_huggingface_checkpoint(model, args) |
538 | 643 |
539 | 644 | cfg = model.config |
540 | 645 | logger.info(f"using model type:{type(model)}") |
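To make the expert-id remapping in `param_to_weight` concrete, here is a standalone sketch with assumed values (`moe_num_experts=64`, `ep_degree=4`, `ep_rank=1`; the parameter name is made up):

```python
import re

expert_offset = (64 // 4) * 1  # this rank owns global experts 16..31
name = "ernie.layers.0.mlp.experts.3.up_gate_proj.weight"
if m := re.search(r"mlp\.experts\.(\d+)", name):
    s, e = m.span()
    name = name[:s] + f"mlp.experts.{expert_offset + int(m.group(1))}" + name[e:]
print(name)  # ernie.layers.0.mlp.experts.19.up_gate_proj.weight
```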