From f6cefedfc313ccaef6f16befeb9134554e7fb834 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Mon, 3 Nov 2025 04:47:49 +0000 Subject: [PATCH] adapt fc --- examples/pre-training/ernie/pretrain.py | 23 +- .../ernie/src/trainers/pretraining_trainer.py | 2 +- .../models/ernie/configuration.py | 3 +- .../pre-training/models/ernie/modeling.py | 654 ++++++++++++++---- .../pre-training/models/ernie/modeling_moe.py | 81 ++- .../pre-training/models/ernie/modeling_pp.py | 219 ++++-- examples/pre-training/models/moe/moe_layer.py | 23 +- .../models/sequence_parallel_utils.py | 20 +- 8 files changed, 802 insertions(+), 223 deletions(-) diff --git a/examples/pre-training/ernie/pretrain.py b/examples/pre-training/ernie/pretrain.py index d5486415d..f5a84384b 100644 --- a/examples/pre-training/ernie/pretrain.py +++ b/examples/pre-training/ernie/pretrain.py @@ -296,16 +296,16 @@ def formatv(v): and not args.overwrite_output_dir ): last_checkpoint = get_last_checkpoint(args.output_dir) - if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0: - raise ValueError( - f"Output directory ({args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) + # if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0: + # raise ValueError( + # f"Output directory ({args.output_dir}) already exists and is not empty. " + # "Use --overwrite_output_dir to overcome." + # ) + # elif last_checkpoint is not None and args.resume_from_checkpoint is None: + # logger.info( + # f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + # "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
+ # ) def compute_metrics(p): preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions @@ -439,6 +439,7 @@ def sname_to_tname(pp_model): cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size cfg.fp16_opt_level = args.fp16_opt_level cfg.moe_group = args.moe_group + cfg.moe_group_name = args.moe_group cfg.dtype = dtype cfg.use_fp8 = args.use_fp8 cfg.enable_mtp_magic_send = args.enable_mtp_magic_send @@ -502,7 +503,7 @@ def sname_to_tname(pp_model): logger.info(f"using model type:{type(model)}") paddle.set_default_dtype("float32") - logger.info(f"using model={type(model)}, cfg={cfg}") + # logger.info(f"using model={type(model)}, cfg={cfg}") train_dataset, eval_dataset, test_dataset, data_collator = ( create_pretrained_dataset(args) diff --git a/examples/pre-training/ernie/src/trainers/pretraining_trainer.py b/examples/pre-training/ernie/src/trainers/pretraining_trainer.py index 9308c69a2..9a811cc71 100644 --- a/examples/pre-training/ernie/src/trainers/pretraining_trainer.py +++ b/examples/pre-training/ernie/src/trainers/pretraining_trainer.py @@ -1260,7 +1260,7 @@ def _maybe_log_save_evaluate( ) logs["learning_rate"] = float(self._get_learning_rate()) logs["global_step"] = int(self.state.global_step) - + logs["loss_md5"] = paddle.to_tensor(logs["loss"])._md5sum() divisor = 2**30 current_device = framework._current_expected_place_() diff --git a/examples/pre-training/models/ernie/configuration.py b/examples/pre-training/models/ernie/configuration.py index a8cedf29b..90a41b1ee 100644 --- a/examples/pre-training/models/ernie/configuration.py +++ b/examples/pre-training/models/ernie/configuration.py @@ -149,6 +149,7 @@ def __init__( global_aux_loss=False, moe_dropout_prob=0.0, moe_group="world", + moe_group_name="world", num_experts_per_tok: int = 8, moe_intermediate_size: Union[int, list] = 0, moe_num_shared_experts: int = 0, @@ -356,6 +357,7 @@ def update_nested_dict(default_dict, update_dict): self.moe_layer_interval = moe_layer_interval self.moe_dropout_prob = moe_dropout_prob self.moe_group = moe_group + self.moe_group_name = moe_group_name self.num_experts_per_tok = num_experts_per_tok self.moe_num_shared_experts = moe_num_shared_experts self.moe_num_dense_experts = moe_num_dense_experts @@ -395,7 +397,6 @@ def update_nested_dict(default_dict, update_dict): self.use_linear_residual_norm_recompute = use_linear_residual_norm_recompute self.use_rms_qkv_recompute = use_rms_qkv_recompute - assert aux_loss_type in ["", "default", "seq_aux_loss", "switch_aux_loss"] self.aux_loss_type = aux_loss_type diff --git a/examples/pre-training/models/ernie/modeling.py b/examples/pre-training/models/ernie/modeling.py index 79993d1a2..e7d8a049a 100644 --- a/examples/pre-training/models/ernie/modeling.py +++ b/examples/pre-training/models/ernie/modeling.py @@ -61,6 +61,9 @@ ) from paddleformers.transformers.model_utils import PretrainedModel, register_base_model from paddleformers.utils.tools import get_env_device +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + build_sharded_state_dict, +) from .configuration import ErnieMoEConfig @@ -72,7 +75,9 @@ try: from paddle.nn.functional.flash_attention import flash_attention - logger.warning("Use flash attention in scaled-dot-product. Attention mask is deprecated") + logger.warning( + "Use flash attention in scaled-dot-product. 
Attention mask is deprecated" + ) except (ImportError, ModuleNotFoundError): flash_attention = None @@ -153,7 +158,9 @@ def gqa_qkv_split_func( q_list = paddle.split(q, tensor_parallel_degree, axis=-1) k_list = paddle.split(k, tensor_parallel_degree, axis=-1) v_list = paddle.split(v, tensor_parallel_degree, axis=-1) - ret = [paddle.concat([q, k, v], axis=-1) for q, k, v in zip(q_list, k_list, v_list)] + ret = [ + paddle.concat([q, k, v], axis=-1) for q, k, v in zip(q_list, k_list, v_list) + ] return ret else: q = paddle.split(q, tensor_parallel_degree, axis=-1)[tensor_parallel_rank] @@ -205,7 +212,9 @@ def parallel_matmul( logits += bias else: if fuse_linear: - logits = paddle.incubate.nn.functional.fused_linear(input_parallel, y, bias) + logits = paddle.incubate.nn.functional.fused_linear( + input_parallel, y, bias + ) else: logits = F.linear(input_parallel, y, bias) @@ -216,7 +225,9 @@ def parallel_matmul( else: if fuse_linear: - logits = paddle.incubate.nn.functional.fused_linear(x, y, bias, transpose_weight=transpose_y) + logits = paddle.incubate.nn.functional.fused_linear( + x, y, bias, transpose_weight=transpose_y + ) else: logits = paddle.matmul(x, y, transpose_y=transpose_y) if bias is not None: @@ -224,7 +235,9 @@ def parallel_matmul( return logits -def calc_lm_head_logits(config, hidden_states, weight, bias, tensor_parallel_output=None, training=True): +def calc_lm_head_logits( + config, hidden_states, weight, bias, tensor_parallel_output=None, training=True +): if config.sequence_parallel: if config.use_sparse_head_and_loss_fn: pass @@ -233,10 +246,16 @@ def calc_lm_head_logits(config, hidden_states, weight, bias, tensor_parallel_out if lm_head_use_gather: hidden_states = GatherOp.apply(hidden_states) if not config.using_dynamic_sequence_length: - hidden_states = hidden_states.reshape([-1, config.seqlen, hidden_states.shape[-1]]) + hidden_states = hidden_states.reshape( + [-1, config.seqlen, hidden_states.shape[-1]] + ) else: - assert config.micro_batch_size, "micro_batch_size should be set when using dygramic sequence length." - hidden_states = hidden_states.reshape([config.micro_batch_size, -1, hidden_states.shape[-1]]) + assert ( + config.micro_batch_size + ), "micro_batch_size should be set when using dygramic sequence length." 
+ hidden_states = hidden_states.reshape( + [config.micro_batch_size, -1, hidden_states.shape[-1]] + ) if tensor_parallel_output is None: tensor_parallel_output = config.tensor_parallel_output @@ -277,7 +296,9 @@ def masked_fill(x, mask, value): return paddle.where(mask, y, x) -def mem_eff_attn(query, key, value, pack_offset, drop_prob=0.0, dtype=paddle.bfloat16, training=True): +def mem_eff_attn( + query, key, value, pack_offset, drop_prob=0.0, dtype=paddle.bfloat16, training=True +): pack_offset = pack_offset.numpy() shape = pack_offset.shape assert len(shape) == 2, len(shape) @@ -300,7 +321,9 @@ def cast(x): return x.astype(dtype) if x.dtype != dtype else x if len(seqlens) == 1: - out, _ = flash_attention(query, key, value, drop_prob, causal=True, training=training) + out, _ = flash_attention( + query, key, value, drop_prob, causal=True, training=training + ) else: mask = BlockDiagonalCausalMask.from_seqlens(seqlens) out = memory_efficient_attention( @@ -326,7 +349,9 @@ def inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset): row_start_indices = np.repeat(cumsum_item[1:], record_lens) attn_mask_row_start_indices.append(row_start_indices[None, None, ...]) attn_mask_row_start_indices = np.concatenate(attn_mask_row_start_indices, axis=0) - return paddle.to_tensor(attn_mask_row_start_indices, dtype=paddle.int32), int(min_start_row) + return paddle.to_tensor(attn_mask_row_start_indices, dtype=paddle.int32), int( + min_start_row + ) def scaled_dot_product_attention( @@ -360,14 +385,20 @@ def scaled_dot_product_attention( can_use_fa = config.use_flash_attn and flash_attention is not None can_use_fa_sparse_mask = ( - config.use_mem_eff_attn and inbatch_pack_offset is not None and flashmask_attention is not None + config.use_mem_eff_attn + and inbatch_pack_offset is not None + and flashmask_attention is not None ) if not can_use_fa and not can_use_fa_sparse_mask: if query_states.shape[-2] != key_states.shape[-2]: - key_states = key_states.repeat_interleave(num_heads // num_key_value_heads, axis=-2) + key_states = key_states.repeat_interleave( + num_heads // num_key_value_heads, axis=-2 + ) if query_states.shape[-2] != value_states.shape[-2]: - value_states = value_states.repeat_interleave(num_heads // num_key_value_heads, axis=-2) + value_states = value_states.repeat_interleave( + num_heads // num_key_value_heads, axis=-2 + ) if can_use_fa: assert not (config.use_mem_eff_attn and inbatch_pack_offset is not None) @@ -384,7 +415,9 @@ def scaled_dot_product_attention( return attn_output, attn_weights else: - query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt(head_dim) + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt( + head_dim + ) key_states = paddle.transpose(key_states, [0, 2, 1, 3]) value_states = paddle.transpose(value_states, [0, 2, 1, 3]) @@ -408,14 +441,20 @@ def scaled_dot_product_attention( attn_weights = attention_mask + attn_weights attn_weights = paddle.maximum( attn_weights, - paddle.to_tensor(float(finfo(query_states.dtype).min), dtype=query_states.dtype), + paddle.to_tensor( + float(finfo(query_states.dtype).min), dtype=query_states.dtype + ), ) if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + attn_weights = F.softmax( + attn_weights, axis=-1, dtype="float32" + ).astype(query_states.dtype) else: - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + attn_weights = 
F.softmax(attn_weights, axis=-1, dtype="float32").astype( + query_states.dtype + ) else: attn_weights = attn_weights.cast(paddle.float32) attention_mask = attention_mask.cast(paddle.float32) @@ -453,12 +492,18 @@ def _make_causal_mask(input_ids_shape, past_key_values_length, dtype): mask = paddle.full((target_length, target_length), float(finfo(dtype).min)) mask_cond = paddle.arange(mask.shape[-1]) - mask = masked_fill(mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0) + mask = masked_fill( + mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0 + ) if past_key_values_length > 0: - mask = paddle.concat([paddle.zeros([target_length, past_key_values_length]), mask], axis=-1) + mask = paddle.concat( + [paddle.zeros([target_length, past_key_values_length]), mask], axis=-1 + ) - return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + return mask[None, None, :, :].expand( + [batch_size, 1, target_length, target_length + past_key_values_length] + ) def _expand_mask(mask, dtype, tgt_length): @@ -470,10 +515,14 @@ def _expand_mask(mask, dtype, tgt_length): batch_size, src_length = mask.shape[0], mask.shape[-1] tgt_length = tgt_length if tgt_length is not None else src_length - expanded_mask = mask[:, None, None, :].expand([batch_size, 1, tgt_length, src_length]) + expanded_mask = mask[:, None, None, :].expand( + [batch_size, 1, tgt_length, src_length] + ) inverted_mask = 1.0 - expanded_mask - return masked_fill(inverted_mask, inverted_mask.cast("bool"), float(finfo(dtype).min)) + return masked_fill( + inverted_mask, inverted_mask.cast("bool"), float(finfo(dtype).min) + ) class FusedDropoutImpl(nn.Layer): @@ -510,14 +559,20 @@ def __init__(self, config): def forward(self, hidden_states): if self.config.fuse_rms_norm: - return fused_rms_norm_ext(hidden_states, self.weight, self.variance_epsilon)[0] + return fused_rms_norm_ext( + hidden_states, self.weight, self.variance_epsilon + )[0] if paddle.in_dynamic_mode(): with paddle.amp.auto_cast(False): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) - hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) else: variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) - hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) if self.weight.dtype in [paddle.float16, paddle.bfloat16]: hidden_states = paddle.cast(hidden_states, self.weight.dtype) @@ -529,7 +584,9 @@ def __init__(self, dim, max_position_embeddings=4096, base=10000): super().__init__() self.base = base self.max_position_embeddings = max_position_embeddings - inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim)) + inv_freq = 1.0 / ( + base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim) + ) t = paddle.arange(max_position_embeddings, dtype="float32") freqs = paddle.einsum("i,j->ij", t, inv_freq.cast("float32")) @@ -567,8 +624,12 @@ def apply_rotary_pos_emb(cls, q, k, cos, sin, offset: int = 0, position_ids=None cos = cos[:, offset : q.shape[1] + offset, None, :] sin = sin[:, offset : q.shape[1] + offset, None, :] - q_embed = paddle.add(paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin)) - k_embed = paddle.add(paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin)) + q_embed = 
paddle.add( + paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin) + ) + k_embed = paddle.add( + paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin) + ) q_embed = q_embed.astype(q.dtype) k_embed = k_embed.astype(k.dtype) return q_embed, k_embed @@ -592,8 +653,12 @@ def forward(self, seq_length, position_ids=None): else: position_ids = position_ids / self.compression_ratio seq_length = position_ids.shape[-1] - sinusoid_inp = position_ids.unsqueeze(-1).astype("float32") * indices.unsqueeze(0) - pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) + sinusoid_inp = position_ids.unsqueeze(-1).astype( + "float32" + ) * indices.unsqueeze(0) + pos_emb = paddle.concat( + [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1 + ) pos_emb = paddle.reshape(pos_emb, (-1, 1, seq_length, self.head_dim)) pos_emb.stop_gradient = True return pos_emb @@ -642,7 +707,9 @@ def apply_rotary_3d(self, rp, q, k, position_ids): :, 1 : self.head_dim // 2 - self.freq_allocation : 2, ] - sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape(sin_h.shape[:-1] + [sin_h.shape[-1] * 2]) + sin_hw = paddle.stack([sin_h, sin_w], axis=-1).reshape( + sin_h.shape[:-1] + [sin_h.shape[-1] * 2] + ) sin_thw = paddle.concat([sin_hw, sin_t], axis=-1) cos_t = cos[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] @@ -658,7 +725,9 @@ def apply_rotary_3d(self, rp, q, k, position_ids): :, 1 : self.head_dim // 2 - self.freq_allocation : 2, ] - cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape(cos_h.shape[:-1] + [cos_h.shape[-1] * 2]) + cos_hw = paddle.stack([cos_h, cos_w], axis=-1).reshape( + cos_h.shape[:-1] + [cos_h.shape[-1] * 2] + ) cos_thw = paddle.concat([cos_hw, cos_t], axis=-1) sin_pos = paddle.reshape( @@ -690,12 +759,18 @@ def apply_rotary_3d(self, rp, q, k, position_ids): def forward_single(self, position_ids): batch_size, seq_length = position_ids.shape[:2] - rope_emb = paddle.zeros((2, batch_size, seq_length, 1, self.head_dim), dtype="float32") - inv_freq = self.base ** (-paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim) + rope_emb = paddle.zeros( + (2, batch_size, seq_length, 1, self.head_dim), dtype="float32" + ) + inv_freq = self.base ** ( + -paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim + ) position_ids = position_ids.cast("float32") position_ids = position_ids / self.compression_ratio freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) - emb = paddle.stack([freqs, freqs], axis=-1).reshape((batch_size, seq_length, self.head_dim)) + emb = paddle.stack([freqs, freqs], axis=-1).reshape( + (batch_size, seq_length, self.head_dim) + ) emb = paddle.unsqueeze(emb, 2) rope_emb[0] = paddle.cos(emb) @@ -720,16 +795,29 @@ def __init__(self, config): self.fuse_ffn = config.fuse_attn_ffn if config.tensor_parallel_degree > 1: - ColumnLN = ColumnSequenceParallelLinear if config.sequence_parallel else ColumnParallelLinear - RowLN = RowSequenceParallelLinear if config.sequence_parallel else RowParallelLinear + ColumnLN = ( + ColumnSequenceParallelLinear + if config.sequence_parallel + else ColumnParallelLinear + ) + RowLN = ( + RowSequenceParallelLinear + if config.sequence_parallel + else RowParallelLinear + ) column_ln_configs = ( - {"use_rr": config.use_recompute and config.skip_recompute_ops.get("mlp_column_ln", False)} + { + "use_rr": config.use_recompute + and config.skip_recompute_ops.get("mlp_column_ln", False) + } if config.sequence_parallel and get_env_device() == "gpu" else {} ) if 
config.sequence_parallel and get_env_device() == "gpu": - column_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + column_ln_configs["use_tpsp_comm_overlap"] = ( + config.use_tpsp_comm_overlap + ) if config.fuse_attn_ffn: self.up_gate_proj = ColumnLN( self.hidden_size, @@ -757,7 +845,9 @@ def __init__(self, config): **column_ln_configs, ) else: - LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + LinearFN = ( + paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + ) if config.fuse_attn_ffn: self.up_gate_proj = LinearFN( self.hidden_size, @@ -765,12 +855,19 @@ def __init__(self, config): bias_attr=config.use_bias, ) else: - self.gate_proj = LinearFN(self.hidden_size, self.intermediate_size, bias_attr=config.use_bias) - self.up_proj = LinearFN(self.hidden_size, self.intermediate_size, bias_attr=config.use_bias) + self.gate_proj = LinearFN( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) + self.up_proj = LinearFN( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) if config.tensor_parallel_degree > 1: row_ln_configs = ( - {"use_rr": config.use_recompute and config.skip_recompute_ops.get("mlp_row_ln", False)} + { + "use_rr": config.use_recompute + and config.skip_recompute_ops.get("mlp_row_ln", False) + } if config.sequence_parallel and get_env_device() == "gpu" else {} ) @@ -785,8 +882,12 @@ def __init__(self, config): **row_ln_configs, ) else: - LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear - self.down_proj = LinearFN(self.intermediate_size, self.hidden_size, bias_attr=config.use_bias) + LinearFN = ( + paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + ) + self.down_proj = LinearFN( + self.intermediate_size, self.hidden_size, bias_attr=config.use_bias + ) self.fuse_swiglu = config.fuse_swiglu if self.fuse_swiglu: @@ -799,7 +900,9 @@ def forward(self, x): and self.config.use_fp8_mlp and not self.config.use_bias ): - return MemEfficientFp8FusedMlpFunc.apply(x, self.up_gate_proj.weight, self.down_proj.weight) + return MemEfficientFp8FusedMlpFunc.apply( + x, self.up_gate_proj.weight, self.down_proj.weight + ) if self.fuse_swiglu: if self.fuse_ffn: @@ -844,15 +947,24 @@ def __init__(self, config, layer_idx=0): self.fuse_attn = config.fuse_attn_ffn self.use_recompute_attn = config.use_recompute_attn logger.info(f"using recompute attn={self.use_recompute_attn}") - self.is_gqa = config.num_key_value_heads is not None and config.num_key_value_heads != self.num_heads + self.is_gqa = ( + config.num_key_value_heads is not None + and config.num_key_value_heads != self.num_heads + ) if config.fuse_rope: assert fused_rope is not None, "fused_rope is not supported" self.fuse_rope = config.fuse_rope self.rope_3d = config.rope_3d if self.rope_3d: - assert not self.fuse_rope, "does not support fuse rope when rope_3d is on for now." - assert not config.rope_reorder, "does not support rope_reorder when rope_3d is on for now." - assert config.freq_allocation is not None, "freq_allocation must be provided if rope_3d is on." + assert ( + not self.fuse_rope + ), "does not support fuse rope when rope_3d is on for now." + assert ( + not config.rope_reorder + ), "does not support rope_reorder when rope_3d is on for now." + assert ( + config.freq_allocation is not None + ), "freq_allocation must be provided if rope_3d is on." 
if config.tensor_parallel_degree > 1: assert ( @@ -863,9 +975,13 @@ def __init__(self, config, layer_idx=0): assert ( self.num_key_value_heads % config.tensor_parallel_degree == 0 ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" - self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + self.num_key_value_heads = ( + self.num_key_value_heads // config.tensor_parallel_degree + ) if self.is_gqa: - logger.info(f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}") + logger.info( + f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" + ) assert ( self.num_heads % self.num_key_value_heads == 0 ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" @@ -875,15 +991,28 @@ def __init__(self, config, layer_idx=0): q_hidden_size = kv_hidden_size = self.head_dim * config.num_attention_heads if config.tensor_parallel_degree > 1: - ColumnLN = ColumnSequenceParallelLinear if config.sequence_parallel else ColumnParallelLinear - RowLN = RowSequenceParallelLinear if config.sequence_parallel else RowParallelLinear + ColumnLN = ( + ColumnSequenceParallelLinear + if config.sequence_parallel + else ColumnParallelLinear + ) + RowLN = ( + RowSequenceParallelLinear + if config.sequence_parallel + else RowParallelLinear + ) column_ln_configs = ( - {"use_rr": config.use_recompute and config.skip_recompute_ops.get("attention_column_ln", False)} + { + "use_rr": config.use_recompute + and config.skip_recompute_ops.get("attention_column_ln", False) + } if config.sequence_parallel and get_env_device() == "gpu" else {} ) if config.sequence_parallel and get_env_device() == "gpu": - column_ln_configs["use_tpsp_comm_overlap"] = config.use_tpsp_comm_overlap + column_ln_configs["use_tpsp_comm_overlap"] = ( + config.use_tpsp_comm_overlap + ) if config.fuse_attn_ffn: self.qkv_proj = ColumnLN( @@ -920,7 +1049,9 @@ def __init__(self, config, layer_idx=0): **column_ln_configs, ) else: - LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + LinearFN = ( + paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + ) if config.fuse_attn_ffn: self.qkv_proj = LinearFN( self.hidden_size, @@ -946,7 +1077,10 @@ def __init__(self, config, layer_idx=0): if config.tensor_parallel_degree > 1: row_ln_configs = ( - {"use_rr": config.use_recompute and config.skip_recompute_ops.get("attention_row_ln", False)} + { + "use_rr": config.use_recompute + and config.skip_recompute_ops.get("attention_row_ln", False) + } if config.sequence_parallel and get_env_device() == "gpu" else {} ) @@ -962,7 +1096,9 @@ def __init__(self, config, layer_idx=0): **row_ln_configs, ) else: - LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + LinearFN = ( + paddle.incubate.nn.FusedLinear if config.fuse_linear else NativeLinear + ) self.o_proj = LinearFN( q_hidden_size, self.hidden_size, @@ -1004,7 +1140,11 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: if self.config.sequence_parallel: if not self.config.using_dynamic_sequence_length: - bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // self.config.seqlen + bsz = ( + hidden_states.shape[0] + * self.config.tensor_parallel_degree + // self.config.seqlen + ) q_len = self.config.seqlen else: assert ( @@ -1012,7 +1152,9 @@ def forward( ), "micro_batch_size should be set when using dygramic sequence 
length." bsz = self.config.micro_batch_size - q_len = hidden_states.shape[0] * self.config.tensor_parallel_degree // bsz + q_len = ( + hidden_states.shape[0] * self.config.tensor_parallel_degree // bsz + ) else: bsz, q_len, _ = hidden_states.shape query_states = key_states = value_states = mix_layer = None @@ -1030,9 +1172,13 @@ def forward( ) mix_layer = None else: - mix_layer = mix_layer.reshape([bsz, q_len, self.num_heads, 3 * self.head_dim]) + mix_layer = mix_layer.reshape( + [bsz, q_len, self.num_heads, 3 * self.head_dim] + ) else: - query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) + query_states = self.q_proj(hidden_states).reshape( + shape=[bsz, q_len, self.num_heads, self.head_dim] + ) key_states = self.k_proj(hidden_states).reshape( shape=[ bsz, @@ -1109,7 +1255,9 @@ def rope_attn( if self.rope_3d: assert position_ids is not None, "rope3d requires pos-id" - kv_seq_len = key_states.shape[-3] if not self.rope_3d else position_ids.max() + 1 + kv_seq_len = ( + key_states.shape[-3] if not self.rope_3d else position_ids.max() + 1 + ) offset = 0 if past_key_value is not None: if not self.rope_3d: @@ -1134,10 +1282,14 @@ def rope_attn( else: if offset > 0 or position_ids is not None or not self.fuse_rope: if not self.rope_3d: - cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose([0, 2, 1, 3]) + cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose( + [0, 2, 1, 3] + ) if offset > 0 and position_ids is None: cos_sin = cos_sin[:, offset:] - query_states, key_states = self.rotary_emb.apply_rotary(cos_sin, query_states, key_states) + query_states, key_states = self.rotary_emb.apply_rotary( + cos_sin, query_states, key_states + ) else: cos_sin = self.rotary_emb(kv_seq_len).transpose([0, 2, 1, 3]) @@ -1152,8 +1304,12 @@ def rope_attn( bsz, q_len, num_heads, head_dim = query_states.shape _, kv_seq_len, num_key_value_heads, _ = key_states.shape if num_heads != num_key_value_heads: - query_states, _, _ = fused_rope(query_states, None, None, rotary_emb_base=self.config.rope_theta) - key_states, _, _ = fused_rope(key_states, None, None, rotary_emb_base=self.config.rope_theta) + query_states, _, _ = fused_rope( + query_states, None, None, rotary_emb_base=self.config.rope_theta + ) + key_states, _, _ = fused_rope( + key_states, None, None, rotary_emb_base=self.config.rope_theta + ) else: query_states, key_states, _ = fused_rope( query_states, @@ -1199,8 +1355,12 @@ def __init__(self, config, layer_idx=0): self.input_layernorm = Norm(config) self.post_attention_layernorm = Norm(config) - self.residual_add1 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") - self.residual_add2 = FusedDropoutImpl(config.hidden_dropout_prob, mode="upscale_in_train") + self.residual_add1 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + self.residual_add2 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) self.config = config def forward( @@ -1226,8 +1386,13 @@ def forward( inbatch_pack_offset=inbatch_pack_offset, ) - if self.config.tensor_parallel_degree > 1 and self.config.hidden_dropout_prob > 0.0: - current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) with get_rng_state_tracker().rng_state(current_seed): hidden_states = self.residual_add1(hidden_states, residual) 
else: @@ -1237,8 +1402,13 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.config.tensor_parallel_degree > 1 and self.config.hidden_dropout_prob > 0.0: - current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) with get_rng_state_tracker().rng_state(current_seed): hidden_states = self.residual_add2(hidden_states, residual) else: @@ -1270,7 +1440,9 @@ def _get_name_mappings(cls, config: ErnieMoEConfig) -> StateDictNameMapping: ["norm.weight"], ] for layer_index in range( - config.num_hidden_layers if not config.remove_tail_layer else config.num_hidden_layers - 1 + config.num_hidden_layers + if not config.remove_tail_layer + else config.num_hidden_layers - 1 ): if config.fuse_attn_ffn: layer_mappings = [ @@ -1332,7 +1504,10 @@ def _get_name_mappings(cls, config: ErnieMoEConfig) -> StateDictNameMapping: mapping[1] = "ernie." + mapping[1] model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) - mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + mappings = [ + StateDictNameMapping(*mapping, index=index) + for index, mapping in enumerate(model_mappings) + ] return mappings @classmethod @@ -1347,7 +1522,10 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True): num_attention_heads=config.num_attention_heads, ) - if config.num_key_value_heads is not None and config.num_key_value_heads != config.num_attention_heads: + if ( + config.num_key_value_heads is not None + and config.num_key_value_heads != config.num_attention_heads + ): if is_split: qkv_fn = partial( gqa_qkv_split_func, @@ -1372,8 +1550,12 @@ def get_tensor_parallel_split_mappings(num_layers): if config.fuse_attn_ffn: base_actions = { "layers.0.self_attn.qkv_proj.weight": qkv_fn, - "layers.0.mlp.up_gate_proj.weight": partial(fn, is_column=True, is_naive_2fuse=True), - "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "layers.0.mlp.up_gate_proj.weight": partial( + fn, is_column=True, is_naive_2fuse=True + ), + "lm_head.weight": partial( + fn, is_column=not config.tie_word_embeddings + ), "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), @@ -1382,7 +1564,9 @@ def get_tensor_parallel_split_mappings(num_layers): base_actions.update( { "layers.0.self_attn.qkv_proj.bias": qkv_fn, - "layers.0.mlp.up_gate_proj.bias": partial(fn, is_column=True, is_naive_2fuse=True), + "layers.0.mlp.up_gate_proj.bias": partial( + fn, is_column=True, is_naive_2fuse=True + ), "lm_head.bias": partial(fn, is_column=True), } ) @@ -1393,7 +1577,9 @@ def get_tensor_parallel_split_mappings(num_layers): "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True), "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), - "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "lm_head.weight": partial( + fn, is_column=not config.tie_word_embeddings + ), "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), @@ -1401,9 +1587,15 @@ def 
get_tensor_parallel_split_mappings(num_layers): if config.use_bias: base_actions.update( { - "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True), - "layers.0.self_attn.k_proj.bias": partial(fn, is_column=True), - "layers.0.self_attn.v_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.q_proj.bias": partial( + fn, is_column=True + ), + "layers.0.self_attn.k_proj.bias": partial( + fn, is_column=True + ), + "layers.0.self_attn.v_proj.bias": partial( + fn, is_column=True + ), "layers.0.mlp.gate_proj.bias": partial(fn, is_column=True), "layers.0.mlp.up_proj.bias": partial(fn, is_column=True), "lm_head.bias": partial(fn, is_column=True), @@ -1418,7 +1610,9 @@ def get_tensor_parallel_split_mappings(num_layers): return final_actions mappings = get_tensor_parallel_split_mappings( - config.num_hidden_layers if not config.remove_tail_layer else config.num_hidden_layers - 1 + config.num_hidden_layers + if not config.remove_tail_layer + else config.num_hidden_layers - 1 ) return mappings @@ -1448,7 +1642,9 @@ def _init_weights(self, layer): dtype = paddle.get_default_dtype() paddle.set_default_dtype("float32") layer.weight.set_value( - paddle.randn(layer.weight.shape, dtype=dtype).scale(self.config.initializer_range) + paddle.randn(layer.weight.shape, dtype=dtype).scale( + self.config.initializer_range + ) ) paddle.set_default_dtype(dtype) logger.info( @@ -1459,7 +1655,9 @@ def _init_weights(self, layer): elif isinstance(layer, RotaryEmbedding): head_dim = self.config.hidden_size // self.config.num_attention_heads - inv_freq = 1.0 / (layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + inv_freq = 1.0 / ( + layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) t = np.arange(layer.max_position_embeddings, dtype="float32") freqs = np.einsum("i,j->ij", t, inv_freq) @@ -1494,7 +1692,9 @@ def __init__(self, config: ErnieMoEConfig): layers_list = [ ErnieDecoderLayer(config, layer_idx) for layer_idx in range( - config.num_hidden_layers - 1 if config.remove_tail_layer else config.num_hidden_layers + config.num_hidden_layers - 1 + if config.remove_tail_layer + else config.num_hidden_layers ) ] @@ -1512,7 +1712,9 @@ def set_input_embeddings(self, value): self.embed_tokens = value @classmethod - def _prepare_decoder_attention_mask(cls, attention_mask, input_shape, past_key_values_length, dtype): + def _prepare_decoder_attention_mask( + cls, attention_mask, input_shape, past_key_values_length, dtype + ): combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( @@ -1520,9 +1722,13 @@ def _prepare_decoder_attention_mask(cls, attention_mask, input_shape, past_key_v ) if attention_mask is not None: - expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + expanded_attn_mask = _expand_mask( + attention_mask, dtype, tgt_length=input_shape[-1] + ) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask ) combined_attention_mask = paddle.maximum( combined_attention_mask.astype(dtype), @@ -1576,22 +1782,34 @@ def forward( inbatch_pack_offset=None, **kwargs, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states 
= ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) if past_key_values is None: past_key_values = tuple([None] * len(self.layers)) @@ -1611,7 +1829,9 @@ def forward( inputs_embeds = ScatterOp.apply(inputs_embeds) can_use_fa = self.config.use_flash_attn and flash_attention is not None - can_mem_eff_attn = self.config.use_mem_eff_attn and inbatch_pack_offset is not None + can_mem_eff_attn = ( + self.config.use_mem_eff_attn and inbatch_pack_offset is not None + ) if can_use_fa or can_mem_eff_attn: if attention_mask is not None: attention_mask = None @@ -1621,7 +1841,9 @@ def forward( f"attention_mask is not None = {attention_mask is not None}" ) elif attention_mask is None: - attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + attention_mask = paddle.ones( + (batch_size, seq_length_with_past), dtype=paddle.bool + ) if attention_mask is not None: attention_mask = self._prepare_decoder_attention_mask( @@ -1640,7 +1862,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) has_gradient = not hidden_states.stop_gradient if self.config.use_recompute and has_gradient: @@ -1686,7 +1910,11 @@ def forward( next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1725,7 +1953,9 @@ def forward( if ctx.tensor_parallel_degree > 1: ctx.mp_group = ( - fleet.get_hybrid_communicate_group().get_model_parallel_group() if mp_group is None else mp_group + fleet.get_hybrid_communicate_group().get_model_parallel_group() + if mp_group is None + else mp_group ) ctx.rank = ctx.mp_group.rank ctx.world_size = ctx.mp_group.nranks @@ -1758,7 +1988,9 @@ def forward( [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]], dtype=hidden_states.dtype, ) - labels_recv = paddle.empty([ctx.num_tokens_per_rank[idx]], dtype=labels.dtype) + labels_recv = paddle.empty( + [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype + ) if ctx.tensor_parallel_degree > 1: dist.stream.broadcast( @@ -1766,7 +1998,9 @@ def forward( src=ctx.mp_group.ranks[idx], 
group=ctx.mp_group, ) - dist.stream.broadcast(labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group) + dist.stream.broadcast( + labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group + ) seq_len = hidden_states_recv.shape[0] num_chunk = (seq_len + ctx.seq_chunk_size - 1) // ctx.seq_chunk_size @@ -1821,7 +2055,9 @@ def backward(ctx, loss_all_grad, labels_all_grad): hidden_states, weight, bias, labels = ctx.saved_tensor() - loss_all_grad_list = paddle.split(loss_all_grad, ctx.loss_concat_sections, axis=0) + loss_all_grad_list = paddle.split( + loss_all_grad, ctx.loss_concat_sections, axis=0 + ) def detach_variable(inp): if inp is None: @@ -1857,14 +2093,18 @@ def detach_variable(inp): [ctx.num_tokens_per_rank[idx], hidden_states.shape[-1]], dtype=hidden_states.dtype, ) - labels_recv = paddle.empty([ctx.num_tokens_per_rank[idx]], dtype=labels.dtype) + labels_recv = paddle.empty( + [ctx.num_tokens_per_rank[idx]], dtype=labels.dtype + ) if ctx.tensor_parallel_degree > 1: dist.stream.broadcast( hidden_states_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group, ) - dist.stream.broadcast(labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group) + dist.stream.broadcast( + labels_recv, src=ctx.mp_group.ranks[idx], group=ctx.mp_group + ) hidden_states_recv.stop_gradient = False seq_len = hidden_states_recv.shape[0] @@ -1873,7 +2113,9 @@ def detach_variable(inp): for chunk_idx in range(num_chunk): start = chunk_idx * ctx.seq_chunk_size end = min(start + ctx.seq_chunk_size, seq_len) - hidden_states_chunk = hidden_states_recv.slice(axes=[0], starts=[start], ends=[end]) + hidden_states_chunk = hidden_states_recv.slice( + axes=[0], starts=[start], ends=[end] + ) labels_chunk = labels_recv._slice(start, end) loss_grad_chunk = loss_all_grad_list[idx]._slice(start, end) @@ -1897,10 +2139,12 @@ def detach_variable(inp): ignore_index=ctx.ignore_index, ) else: - loss_chunk = paddle.nn.functional.softmax_with_cross_entropy( - logits.cast("float32"), - labels_chunk.unsqueeze(-1), - ignore_index=ctx.ignore_index, + loss_chunk = ( + paddle.nn.functional.softmax_with_cross_entropy( + logits.cast("float32"), + labels_chunk.unsqueeze(-1), + ignore_index=ctx.ignore_index, + ) ) with paddle.amp.auto_cast(enable=False): @@ -1915,7 +2159,9 @@ def detach_variable(inp): if idx == ctx.rank: hidden_states_grad = hidden_states_recv.grad - hidden_states_grad = hidden_states_grad.reshape(ctx.hidden_states_shape) + hidden_states_grad = hidden_states_grad.reshape( + ctx.hidden_states_shape + ) if weight_main_grad is not None: weight_main_grad = weight_main_grad.astype(weight.dtype) @@ -1942,7 +2188,9 @@ def __init__(self, config, return_tuple=True): self.ignored_index = getattr(config, "ignored_index", -100) self.config = config self.return_tuple = return_tuple - self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + self.enable_parallel_cross_entropy = ( + config.tensor_parallel_degree > 1 and config.tensor_parallel_output + ) if self.enable_parallel_cross_entropy: self.loss_func = fleet.meta_parallel.ParallelCrossEntropy() @@ -1958,19 +2206,27 @@ def forward(self, prediction_scores, masked_lm_labels): hidden_states, outlinear_weight, outlinear_bias = prediction_scores if self.config.sequence_parallel: - masked_lm_labels, sparse_label_idx = sequence_parallel_sparse_mask_labels( - masked_lm_labels, self.ignored_index + masked_lm_labels, sparse_label_idx = ( + sequence_parallel_sparse_mask_labels( + masked_lm_labels, self.ignored_index + ) ) sparse_label_idx = 
sparse_label_idx.reshape([-1, 1]) hidden_states = paddle.gather(hidden_states, sparse_label_idx, axis=0) hidden_states = AllGatherVarlenOp.apply(hidden_states) else: masked_lm_labels = masked_lm_labels.flatten() - sparse_label_idx = paddle.nonzero(masked_lm_labels != self.ignored_index).flatten() - masked_lm_labels = paddle.take_along_axis(masked_lm_labels, sparse_label_idx, axis=0) + sparse_label_idx = paddle.nonzero( + masked_lm_labels != self.ignored_index + ).flatten() + masked_lm_labels = paddle.take_along_axis( + masked_lm_labels, sparse_label_idx, axis=0 + ) hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) - hidden_states = paddle.take_along_axis(hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0) + hidden_states = paddle.take_along_axis( + hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0 + ) if self.config.use_recompute_loss_fn: offload_kwargs = {} @@ -1995,9 +2251,13 @@ def forward(self, prediction_scores, masked_lm_labels): res = self.forward_impl(logits, masked_lm_labels) elif self.config.use_recompute_loss_fn: if self.config.use_fused_head_loss_fn: - res = self.forward_impl_with_fused_head_loss_fn(masked_lm_labels, *prediction_scores) + res = self.forward_impl_with_fused_head_loss_fn( + masked_lm_labels, *prediction_scores + ) else: - assert isinstance(prediction_scores, tuple) and len(prediction_scores) in [3, 4], prediction_scores + assert isinstance(prediction_scores, tuple) and len( + prediction_scores + ) in [3, 4], prediction_scores res = recompute( self.forward_impl_with_calc_logits, masked_lm_labels, @@ -2008,7 +2268,9 @@ def forward(self, prediction_scores, masked_lm_labels): return res - def forward_impl_with_fused_head_loss_fn(self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias): + def forward_impl_with_fused_head_loss_fn( + self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias + ): masked_lm_labels.stop_gradient = True masked_lm_loss, masked_lm_labels_all = FusedHeadParallelCrossEntropy.apply( hidden_states, @@ -2024,13 +2286,17 @@ def forward_impl_with_fused_head_loss_fn(self, masked_lm_labels, hidden_states, ) lossmask = masked_lm_labels_all != self.ignored_index if (~lossmask).all(): - logger.warning(f"encounter empty span when calculate loss, ignored_index={self.ignored_index}") + logger.warning( + f"encounter empty span when calculate loss, ignored_index={self.ignored_index}" + ) loss = paddle.mean(masked_lm_loss) * 0.0 loss_sum = masked_lm_loss.sum().detach() else: lossmask = lossmask.reshape([-1]).cast(paddle.float32) - masked_lm_loss = paddle.sum(masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask) + masked_lm_loss = paddle.sum( + masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask + ) loss = masked_lm_loss / lossmask.sum() if self.token_balance_loss: _loss = masked_lm_loss / self.config.token_balance_seqlen @@ -2043,7 +2309,9 @@ def forward_impl_with_fused_head_loss_fn(self, masked_lm_labels, hidden_states, return loss_sum return loss, loss_sum - def forward_impl_with_calc_logits(self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias): + def forward_impl_with_calc_logits( + self, masked_lm_labels, hidden_states, outlinear_weight, outlinear_bias + ): logits = calc_lm_head_logits( self.config, @@ -2057,7 +2325,9 @@ def forward_impl_with_calc_logits(self, masked_lm_labels, hidden_states, outline def loss_impl(self, prediction_scores, masked_lm_labels): prediction_scores = prediction_scores.cast("float32") - masked_lm_loss = 
self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(-1)) + masked_lm_loss = self.loss_func( + prediction_scores, masked_lm_labels.unsqueeze(-1) + ) return masked_lm_loss @@ -2070,9 +2340,9 @@ def forward_impl(self, prediction_scores, masked_lm_labels): with paddle.amp.auto_cast(False): prediction_scores_dims = len(prediction_scores.shape) - if prediction_scores_dims == 2 and prediction_scores.shape[0] > self.config.get( - "loss_subbatch_seqlen", 32768 - ): + if prediction_scores_dims == 2 and prediction_scores.shape[ + 0 + ] > self.config.get("loss_subbatch_seqlen", 32768): sb_loss_func = subbatch( self.loss_impl, [0, 1], @@ -2081,9 +2351,9 @@ def forward_impl(self, prediction_scores, masked_lm_labels): 0, ) masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels) - elif prediction_scores_dims == 3 and prediction_scores.shape[1] > self.config.get( - "loss_subbatch_seqlen", 32768 - ): + elif prediction_scores_dims == 3 and prediction_scores.shape[ + 1 + ] > self.config.get("loss_subbatch_seqlen", 32768): sb_loss_func = subbatch( self.loss_impl, [0, 1], @@ -2097,12 +2367,16 @@ def forward_impl(self, prediction_scores, masked_lm_labels): lossmask = masked_lm_labels != self.ignored_index if (~lossmask).all(): - logger.warning(f"encounter empty span when calculate loss, ignored_index={self.ignored_index}") + logger.warning( + f"encounter empty span when calculate loss, ignored_index={self.ignored_index}" + ) loss = paddle.mean(masked_lm_loss) * 0.0 loss_sum = masked_lm_loss.sum().detach() else: lossmask = lossmask.reshape([-1]).cast(paddle.float32) - masked_lm_loss = paddle.sum(masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask) + masked_lm_loss = paddle.sum( + masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask + ) loss = masked_lm_loss / lossmask.sum() if self.token_balance_loss: _loss = masked_lm_loss / self.config.token_balance_seqlen @@ -2127,27 +2401,41 @@ def __init__(self, config): self.weight = self.create_parameter( shape=( - [vocab_size, config.hidden_size] if config.tie_word_embeddings else [config.hidden_size, vocab_size] + [vocab_size, config.hidden_size] + if config.tie_word_embeddings + else [config.hidden_size, vocab_size] ), dtype=paddle.get_default_dtype(), ) - logger.info(f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}") + logger.info( + f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}" + ) if config.weight_share_add_bias and config.use_bias: self.bias = self.create_parameter( shape=[vocab_size], dtype=paddle.get_default_dtype(), - attr=paddle.ParamAttr(initializer=paddle.nn.initializer.constant.Constant(0.0)), + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.constant.Constant(0.0) + ), ) else: self.bias = None - self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + self.weight.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) if config.weight_share_add_bias and config.use_bias: - self.bias.is_distributed = True if (vocab_size != config.vocab_size) else False + self.bias.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) if self.weight.is_distributed: self.weight.split_axis = 1 - if config.weight_share_add_bias and config.use_bias and self.bias.is_distributed: + if ( + config.weight_share_add_bias + and config.use_bias + and self.bias.is_distributed + ): self.bias.split_axis = 0 if self.config.use_recompute_loss_fn: @@ -2175,6 +2463,16 @@ def forward(self, 
hidden_states, tensor_parallel_output=None): training=self.training, ) + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + axis = 0 if self.config.tie_word_embeddings else 1 + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict( + state_dict, {"weight": axis, "bias": 0}, structured_name_prefix + ) + class ErnieForCausalLM(ErniePretrainedModel): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] @@ -2196,7 +2494,9 @@ def __init__(self, config): ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" new_initializer_range = math.sqrt(0.3333 / config.hidden_size) - logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}") + logger.info( + f"change initializer-range from {config.initializer_range} to {new_initializer_range}" + ) config.initializer_range = new_initializer_range self.config = config @@ -2283,16 +2583,30 @@ def prepare_inputs_for_generation( return model_inputs - def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False): - if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + def update_model_kwargs_for_generation( + self, outputs, model_kwargs, is_encoder_decoder=False + ): + if ( + isinstance(outputs, tuple) + and len(outputs) > 1 + and not isinstance(outputs[1], paddle.Tensor) + ): model_kwargs["past_key_values"] = outputs[1] - if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + if ( + isinstance(outputs, CausalLMOutputWithCrossAttentions) + and "past_key_values" in outputs + ): model_kwargs["past_key_values"] = outputs.past_key_values - if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: + if ( + "token_type_ids" in model_kwargs + and model_kwargs["token_type_ids"] is not None + ): token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1) + model_kwargs["token_type_ids"] = paddle.concat( + [token_type_ids, token_type_ids[:, -1:]], axis=-1 + ) if not is_encoder_decoder: if "attention_mask" in model_kwargs: @@ -2306,10 +2620,14 @@ def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_d ) if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: role_ids = model_kwargs["role_ids"] - model_kwargs["role_ids"] = paddle.concat([role_ids, role_ids[:, -1:]], axis=-1) + model_kwargs["role_ids"] = paddle.concat( + [role_ids, role_ids[:, -1:]], axis=-1 + ) if self.config.rope_3d: - assert "position_ids" in model_kwargs, "position_ids must be provided if rope_3d is on" + assert ( + "position_ids" in model_kwargs + ), "position_ids must be provided if rope_3d is on" position_ids = model_kwargs["position_ids"] model_kwargs["position_ids"] = paddle.concat( [ @@ -2339,11 +2657,19 @@ def forward( inbatch_pack_offset=None, loss_mask=None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.ernie( input_ids, @@ -2379,3 +2705,29 @@ def forward( assert labels is not None loss, loss_sum = self.criterion(logits, labels) return loss, loss_sum + + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + + import re + + def increment_expert_number(s, increment): + def replace(match): + original_number = int(match.group(0)) + new_number = original_number + increment + return str(new_number) + + return re.sub(r"(?<=experts\.)\d+", replace, s) + + renamed_sharded_state_dict = {} + for k, v in sharded_state_dict.items(): + global_expert_id_offset = getattr(v, "global_expert_id_offset", None) + if global_expert_id_offset is not None: + new_key = increment_expert_number(k, global_expert_id_offset) + v.key = new_key + delattr(v, "global_expert_id_offset") + renamed_sharded_state_dict[new_key] = v + else: + renamed_sharded_state_dict[k] = v + + return renamed_sharded_state_dict diff --git a/examples/pre-training/models/ernie/modeling_moe.py b/examples/pre-training/models/ernie/modeling_moe.py index eff1fb6b7..a5e7d3527 100644 --- a/examples/pre-training/models/ernie/modeling_moe.py +++ b/examples/pre-training/models/ernie/modeling_moe.py @@ -900,7 +900,6 @@ def __init__(self, config, layer_idx): ) self.use_rms_qkv_recompute = config.use_rms_qkv_recompute if config.use_rms_qkv_recompute is True: - assert config.use_rmsnorm is True and config.fuse_rms_norm is True assert config.fuse_linear is True and config.use_bias is False @@ -1512,6 +1511,13 @@ def get_tensor_parallel_split_mappings(num_layers): "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + # add mtp block actions + "mtp_block.0.self_attn.qkv_proj.weight": qkv_fn, + "mtp_block.0.mlp.up_gate_proj.weight": partial( + fn, is_column=True, is_naive_2fuse=True + ), + "mtp_block.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "mtp_block.0.mlp.down_proj.weight": partial(fn, is_column=False), } if config.use_bias: base_actions.update( @@ -1522,6 +1528,12 @@ def get_tensor_parallel_split_mappings(num_layers): ), "layers.0.mlp.down_proj.bias": lambda x: x, "lm_head.bias": partial(fn, is_column=True), + # add mtp block bias actions + "mtp_block.0.self_attn.qkv_proj.bias": qkv_fn, + "mtp_block.0.mlp.up_gate_proj.bias": partial( + fn, is_column=True, is_naive_2fuse=True + ), + "mtp_block.0.mlp.down_proj.bias": lambda x: x, } ) else: @@ -1534,6 +1546,14 @@ def get_tensor_parallel_split_mappings(num_layers): "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + # add mtp block actions + "mtp_block.0.self_attn.q_proj.weight": partial(fn, is_column=True), + "mtp_block.0.self_attn.k_proj.weight": partial(fn, is_column=True), + "mtp_block.0.self_attn.v_proj.weight": partial(fn, is_column=True), + "mtp_block.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "mtp_block.0.mlp.up_proj.weight": partial(fn, is_column=True), + "mtp_block.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "mtp_block.0.mlp.down_proj.weight": partial(fn, is_column=False), } if config.use_bias: base_actions.update( @@ -1551,18 +1571,34 @@ def get_tensor_parallel_split_mappings(num_layers): "layers.0.mlp.up_proj.bias": partial(fn, 
is_column=True), "layers.0.mlp.down_proj.bias": lambda x: x, "lm_head.bias": partial(fn, is_column=True), + # add mtp block bias actions + "mtp_block.0.self_attn.q_proj.bias": partial( + fn, is_column=True + ), + "mtp_block.0.self_attn.k_proj.bias": partial( + fn, is_column=True + ), + "mtp_block.0.self_attn.v_proj.bias": partial( + fn, is_column=True + ), + "mtp_block.0.mlp.gate_proj.bias": partial(fn, is_column=True), + "mtp_block.0.mlp.up_proj.bias": partial(fn, is_column=True), + "mtp_block.0.mlp.down_proj.bias": lambda x: x, } ) - moe_in_mp = config.moe_group in {"mp", "model", "tp", "mpdp"} + + moe_in_mp = config.moe_group_name in {"mp", "model", "tp", "mpdp"} + for key, action in base_actions.items(): if "layers.0." in key: for i in range(num_layers): newkey = key.replace("layers.0.", f"layers.{i}.") - if config.moe_group in {"mpdp"}: + if config.moe_group_name in {"mpdp"}: final_actions[newkey] = lambda x: x else: final_actions[newkey] = action - if "mlp" in key and (i + 1) % config.moe_layer_interval == 0: + # only expand experts for non-MTP layers + if key.startswith("layers.0.mlp") and (i + 1) % config.moe_layer_interval == 0: moe_num_experts = config.moe_num_experts if moe_num_experts > 0: for expert_id in range(moe_num_experts): @@ -1602,6 +1638,11 @@ def get_tensor_parallel_split_mappings(num_layers): final_actions[key.replace("layers.0.", f"layers.{i}.")] = ( action ) + elif "mtp_block.0." in key: + depth = getattr(config, "multi_token_pred_depth", 0) or 0 + for d in range(depth): + newkey = key.replace("mtp_block.0.", f"mtp_block.{d}.") + final_actions[newkey] = action else: final_actions[key] = action return final_actions @@ -1695,13 +1736,12 @@ def _init_weights(self, layer): @register_base_model class ErnieModel(ErniePretrainedModel): - def __init__(self, config: ErnieMoEConfig): + def __init__(self, config: ErnieMoEConfig): if config.moe_group in {"mp", "model", "tp", "mpdp"}: logger.info( f"disable FFN tensor model parallel, moe-group={config.moe_group}" ) config.disable_ffn_model_parallel = True - config.moe_group = _parse_moe_group(config.moe_group) config.moe_world_size = dist.get_world_size(config.moe_group) @@ -2214,7 +2254,7 @@ def __init__(self, config): self.lm_head = ErnieMoELMHead(config) self.criterion = ErniePretrainingCriterion(config) - self.tie_weights() + # self.tie_weights() if self.config.fuse_rms_norm: logger.info("Use fusedRMSNorm") @@ -2471,3 +2511,30 @@ def forward( router_loss = None assert labels is not None return self.criterion(logits, labels, router_loss, mtp_logits) + + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + + + + import re + def increment_expert_number(s, increment): + def replace(match): + original_number = int(match.group(0)) + new_number = original_number + increment + return str(new_number) + return re.sub(r'(?<=experts\.)\d+', replace, s) + + + renamed_sharded_state_dict = {} + for k, v in sharded_state_dict.items(): + global_expert_id_offset = getattr(v, 'global_expert_id_offset', None) + if global_expert_id_offset is not None: + new_key = increment_expert_number(k, global_expert_id_offset) + v.key = new_key + delattr(v, 'global_expert_id_offset') + renamed_sharded_state_dict[new_key] = v + else: + renamed_sharded_state_dict[k] = v + + return renamed_sharded_state_dict \ No newline at end of file diff --git a/examples/pre-training/models/ernie/modeling_pp.py b/examples/pre-training/models/ernie/modeling_pp.py index 6a21cceea..615e95396 100644 --- 
a/examples/pre-training/models/ernie/modeling_pp.py +++ b/examples/pre-training/models/ernie/modeling_pp.py @@ -154,7 +154,9 @@ def forward(self, args): axis=1, ) if self.sequence_parallel: - inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) + inputs_embeds_mtp = inputs_embeds_mtp.reshape( + [-1, inputs_embeds_mtp.shape[-1]] + ) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) mtp_emb_res.append(inputs_embeds_mtp) res = paddle.concat(mtp_emb_res) @@ -178,7 +180,10 @@ def forward(self, args): ret += (position_ids.clone(),) if inbatch_pack_offset is not None: ret += (inbatch_pack_offset.clone(),) - if self.config.multi_token_pred_depth > 0 and not self.config.enable_mtp_magic_send: + if ( + self.config.multi_token_pred_depth > 0 + and not self.config.enable_mtp_magic_send + ): ret += (input_ids,) assert len(ret) == 2, "mtp only support one input which is input_ids" if len(ret) == 1: @@ -203,7 +208,9 @@ def forward(self, args): assert len(input_ids_for_mtp) > 0, "input_ids for mtp is empty" hidden_states = args[0] input_ids = input_ids_for_mtp.popleft() - input_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype) + input_embeds = self.embed_tokens(input_ids).astype( + self.embed_tokens.weight.dtype + ) return (hidden_states, input_embeds) @@ -225,7 +232,10 @@ def __init__(self, config, layer_idx, use_full_recompute=False): self.use_mem_eff_attn = config.use_mem_eff_attn def forward(self, args): - if self.config.multi_token_pred_depth > 0 and not self.config.enable_mtp_magic_send: + if ( + self.config.multi_token_pred_depth > 0 + and not self.config.enable_mtp_magic_send + ): res = args[0] tensor_list = paddle.split(res, self.config.multi_token_pred_depth + 1) inputs_embeds = tensor_list[-self.config.multi_token_pred_depth :] @@ -267,9 +277,13 @@ def forward(self, args): if "mod" == setting_type: assert isinstance(offload_value, (list, tuple)) v1, v2 = offload_value - offload_kwargs["offload_indices"] = [0] if self.layer_idx % v1 == v2 else [] + offload_kwargs["offload_indices"] = ( + [0] if self.layer_idx % v1 == v2 else [] + ) elif "layer_idxs" == setting_type: - offload_kwargs["offload_indices"] = [0] if self.layer_idx in offload_value else [] + offload_kwargs["offload_indices"] = ( + [0] if self.layer_idx in offload_value else [] + ) if offload_kwargs.get("offload_indices", []) and res is not None: inplace_offload(res) @@ -326,13 +340,17 @@ def __init__(self, config): def forward(self, args): if self.config.multi_token_pred_depth > 0: if self.config.enable_mtp_magic_send: - assert len(args) == self.config.multi_token_pred_depth + 1, "the length is not valid in mtp" + assert ( + len(args) == self.config.multi_token_pred_depth + 1 + ), "the length is not valid in mtp" mtp_outputs = [] for hidden_states in args: mtp_outputs.append(super().forward(hidden_states)) return mtp_outputs else: - tensor_list = paddle.split(args[0], self.config.multi_token_pred_depth + 1) + tensor_list = paddle.split( + args[0], self.config.multi_token_pred_depth + 1 + ) mtp_outputs = [] for hidden_states in tensor_list: mtp_outputs.append(super().forward(hidden_states)) @@ -370,16 +388,27 @@ def __init__(self, config): self.config = config if self.config.use_recompute_mtp: self.config.use_recompute = False - assert self.config.multi_token_pred_depth > 0, "Adding MTPLayer must assign value to multi_token_pred_depth" + assert ( + self.config.multi_token_pred_depth > 0 + ), "Adding MTPLayer must assign value to multi_token_pred_depth" self.mtp_block = 
paddle.nn.LayerList( - [ErnieDecoderLayer(config, layer_idx) for layer_idx in range(self.config.multi_token_pred_depth)] + [ + ErnieDecoderLayer(config, layer_idx) + for layer_idx in range(self.config.multi_token_pred_depth) + ] ) Norm = RMSNorm - self.mtp_hidden_norm = paddle.nn.LayerList([Norm(config) for _ in range(self.config.multi_token_pred_depth)]) - self.mtp_emb_norm = paddle.nn.LayerList([Norm(config) for _ in range(self.config.multi_token_pred_depth)]) + self.mtp_hidden_norm = paddle.nn.LayerList( + [Norm(config) for _ in range(self.config.multi_token_pred_depth)] + ) + self.mtp_emb_norm = paddle.nn.LayerList( + [Norm(config) for _ in range(self.config.multi_token_pred_depth)] + ) - LinearFN = paddle.incubate.nn.FusedLinear if config.fuse_linear else paddle.nn.Linear + LinearFN = ( + paddle.incubate.nn.FusedLinear if config.fuse_linear else paddle.nn.Linear + ) self.mtp_linear_proj = paddle.nn.LayerList( [ LinearFN( @@ -409,7 +438,9 @@ def forward_impl(self, *args): if self.config.enable_mtp_magic_send: assert isinstance(args, tuple), "Input for MTPLayer must be tuple" hidden_states, inputs_embeds = args - inputs_embeds_extra = inputs_embeds[:, -self.config.multi_token_pred_depth :, :] + inputs_embeds_extra = inputs_embeds[ + :, -self.config.multi_token_pred_depth :, : + ] inputs_embeds = inputs_embeds[:, : -self.config.multi_token_pred_depth, :] inputs_embeds_ori = inputs_embeds else: @@ -430,16 +461,22 @@ def forward_impl(self, *args): ) if self.config.sequence_parallel: - inputs_embeds_cur_depth = inputs_embeds_cur_depth.reshape([-1, inputs_embeds_cur_depth.shape[-1]]) + inputs_embeds_cur_depth = inputs_embeds_cur_depth.reshape( + [-1, inputs_embeds_cur_depth.shape[-1]] + ) inputs_embeds_cur_depth = ScatterOp.apply(inputs_embeds_cur_depth) else: inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] - inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth) + inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth]( + inputs_embeds_cur_depth + ) hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states) inputs_embeds_cur_depth = self.mtp_linear_proj[depth]( - paddle.concat([inputs_embeds_cur_depth_norm, hidden_states_norm], axis=-1) + paddle.concat( + [inputs_embeds_cur_depth_norm, hidden_states_norm], axis=-1 + ) ) decoder_layer = self.mtp_block[depth] @@ -499,24 +536,38 @@ def init(self, config, *args, **kwargs): self._pp_to_single_mapping = None def add_sequential_layer(self, layer_desc, name_prefix=""): - self._sequential_layers.append({"layer": layer_desc, "name_prefix": name_prefix}) + self._sequential_layers.append( + {"layer": layer_desc, "name_prefix": name_prefix} + ) def get_sequential_layers(self): return [x["layer"] for x in self._sequential_layers] def get_sequential_name_prefixs(self): - return {str(index): x["name_prefix"] for index, x in enumerate(self._sequential_layers)} + return { + str(index): x["name_prefix"] + for index, x in enumerate(self._sequential_layers) + } def get_shardlayer_prefix(self, name_splited): - shared_layer_names = {s.layer_name for s in self._layers_desc if isinstance(s, SharedLayerDesc)} - assert name_splited[1] in shared_layer_names, f"The shared layer name {name_splited[1]} must be in prefixes!" + shared_layer_names = { + s.layer_name for s in self._layers_desc if isinstance(s, SharedLayerDesc) + } + assert ( + name_splited[1] in shared_layer_names + ), f"The shared layer name {name_splited[1]} must be in prefixes!" 
shared_layer_key = name_splited[1] for idx, layer in enumerate(self._layers_desc): - if isinstance(layer, SharedLayerDesc) and layer.layer_name == shared_layer_key: + if ( + isinstance(layer, SharedLayerDesc) + and layer.layer_name == shared_layer_key + ): if self.get_stage_from_index(idx) == self._stage_id: return self.get_sequential_name_prefixs()[str(idx)] - raise ValueError(f"The shared layer {shared_layer_key} must be in the current stage!") + raise ValueError( + f"The shared layer {shared_layer_key} must be in the current stage!" + ) def _set_pipeline_name_mapping(self, mappings=None): if mappings is not None: @@ -636,7 +687,9 @@ def _init_weights(self, layer): dtype = paddle.get_default_dtype() paddle.set_default_dtype("float32") layer.weight.set_value( - paddle.randn(layer.weight.shape, dtype=dtype).scale(self.config.initializer_range) + paddle.randn(layer.weight.shape, dtype=dtype).scale( + self.config.initializer_range + ) ) paddle.set_default_dtype(dtype) @@ -651,7 +704,9 @@ def _init_weights(self, layer): moe_num_experts = moe_num_experts[0] if self.config.moe_group_experts: layer.weight.set_value( - paddle.randn(layer.weight.shape, dtype=layer.weight.dtype).scale(self.config.initializer_range) + paddle.randn( + layer.weight.shape, dtype=layer.weight.dtype + ).scale(self.config.initializer_range) ) else: layer.weight.set_value( @@ -664,15 +719,17 @@ def _init_weights(self, layer): for i in range(1, len(self.config.moe_num_experts)): layer_weight = getattr(layer, f"weight_{i}") layer_weight.set_value( - paddle.randn(layer_weight.shape, dtype=layer_weight.dtype).scale( - self.config.initializer_range - ) + paddle.randn( + layer_weight.shape, dtype=layer_weight.dtype + ).scale(self.config.initializer_range) ) paddle.set_default_dtype(dtype) elif isinstance(layer, RotaryEmbedding): head_dim = self.config.hidden_size // self.config.num_attention_heads - inv_freq = 1.0 / (layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + inv_freq = 1.0 / ( + layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) t = np.arange(layer.max_position_embeddings, dtype="float32") freqs = np.einsum("i,j->ij", t, inv_freq) @@ -683,6 +740,46 @@ def _init_weights(self, layer): layer.cos_cached.set_value(cos_cached) layer.sin_cached.set_value(sin_cached) + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + if self._pipeline_name_mapping is None: + self._set_pipeline_name_mapping() + + logger.info("====> debug PipelinePretrainedModel sharded state dict") + for k, v in self._pp_to_single_mapping.items(): + logger.info(f"==> {k}:{v}") + + for k in list(sharded_state_dict.keys()): + v = sharded_state_dict.pop(k) + v.key = self._pp_to_single_mapping[k] + sharded_state_dict[self._pp_to_single_mapping[k]] = v + + import re + + def increment_expert_number(s, increment): + def replace(match): + original_number = int(match.group(0)) + new_number = original_number + increment + return str(new_number) + + return re.sub(r"(?<=experts\.)\d+", replace, s) + + renamed_sharded_state_dict = {} + for k, v in sharded_state_dict.items(): + global_expert_id_offset = getattr(v, "global_expert_id_offset", None) + if global_expert_id_offset is not None: + new_key = increment_expert_number(k, global_expert_id_offset) + v.key = new_key + delattr(v, "global_expert_id_offset") + renamed_sharded_state_dict[new_key] = v + else: + renamed_sharded_state_dict[k] = v + + logger.info("====> debug PipelinePretrainedModel renamed sharded 
state dict") + for k, v in renamed_sharded_state_dict.items(): + logger.info(f"==> {k}:{v}") + return renamed_sharded_state_dict + def get_pp_vp_split_layers(config): hcg = fleet.get_hybrid_communicate_group() @@ -701,14 +798,19 @@ def get_pp_vp_split_layers(config): ) chunk_size = layer_num // (pp_size * vp_size) - chunk_list = [list(range(i * chunk_size, (i + 1) * chunk_size)) for i in range(pp_size * vp_size)] + chunk_list = [ + list(range(i * chunk_size, (i + 1) * chunk_size)) + for i in range(pp_size * vp_size) + ] stage_chunk_list = [[] for _ in range(pp_size)] for i in range(pp_size * vp_size): stage_chunk_list[i % pp_size].append(chunk_list[i]) if config.use_recompute_attn: - logger.error("selective recompute only support full recompute now, please set use_recompute_attn to False") + logger.error( + "selective recompute only support full recompute now, please set use_recompute_attn to False" + ) for i in range(pp_size): no_recompute_layer_num.extend(stage_chunk_list[i][-selective_no_recompute_num:]) @@ -764,7 +866,9 @@ def __init__( if config.moe_group in {"mp", "model", "tp", "mpdp"}: assert config.sequence_parallel - logger.info(f"disable FFN tensor model parallel, moe-group={config.moe_group}") + logger.info( + f"disable FFN tensor model parallel, moe-group={config.moe_group}" + ) config.disable_ffn_model_parallel = True config.moe_group = _parse_moe_group(config.moe_group) @@ -801,12 +905,18 @@ def _need_full_recompute(layer_idx): insert_empty_layer = config.insert_empty_layer if len(insert_empty_layer) > 0: - assert min(insert_empty_layer) >= 0, "cannot insert empty layer as first layer of the model" - assert max(insert_empty_layer) < config.num_hidden_layers, "empty layers location exceed the num layers" + assert ( + min(insert_empty_layer) >= 0 + ), "cannot insert empty layer as first layer of the model" + assert ( + max(insert_empty_layer) < config.num_hidden_layers + ), "empty layers location exceed the num layers" logger.info(f"use insert_empty_layer: {insert_empty_layer}") if config.multi_token_pred_depth == 0: - self.add_sequential_layer(LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie") + self.add_sequential_layer( + LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie" + ) else: if config.enable_mtp_magic_send: self.add_sequential_layer( @@ -819,9 +929,13 @@ def _need_full_recompute(layer_idx): "ernie.embed", ) else: - self.add_sequential_layer(LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie") + self.add_sequential_layer( + LayerDesc(self.ErnieEmbeddingPipeClass, config=config), "ernie" + ) - num_empty_layers = config.remove_tail_layer if isinstance(config.remove_tail_layer, int) else 1 + num_empty_layers = ( + config.remove_tail_layer if isinstance(config.remove_tail_layer, int) else 1 + ) for i in range(config.num_hidden_layers - num_empty_layers): self.add_sequential_layer( LayerDesc( @@ -851,7 +965,9 @@ def _need_full_recompute(layer_idx): ), "embed_share", ) - self.add_sequential_layer(LayerDesc(self.MTPLayerClass, config=config), "ernie") + self.add_sequential_layer( + LayerDesc(self.MTPLayerClass, config=config), "ernie" + ) num_empty_layers = num_empty_layers - config.multi_token_pred_depth if config.remove_tail_layer: @@ -888,14 +1004,22 @@ def _need_full_recompute(layer_idx): "ernie.norm", ) - self.add_sequential_layer(LayerDesc(self.ErnieMoELMHeadPipeClass, config=config), "lm_head") + self.add_sequential_layer( + LayerDesc(self.ErnieMoELMHeadPipeClass, config=config), "lm_head" + ) recompute_interval = 0 seg_method = 
"layer:ErnieDecoderLayer|EmptyLayer|MTPLayer" - if config.num_hidden_layers % fleet.get_hybrid_communicate_group().topology().get_dim_size("pipe") != 0: + if ( + config.num_hidden_layers + % fleet.get_hybrid_communicate_group().topology().get_dim_size("pipe") + != 0 + ): seg_method = "uniform" - logger.info(f"using recompute_interval={recompute_interval}, seg_method={seg_method}") + logger.info( + f"using recompute_interval={recompute_interval}, seg_method={seg_method}" + ) PipelineLayer.__init__( self, @@ -925,7 +1049,9 @@ def rename_model_params(self, func): def fp8_quant_weight(self): with paddle.no_grad(): for i, layer in self._sub_layers.items(): - if isinstance(layer, ErnieDecoderLayer) and hasattr(layer, "fp8_quant_weight"): + if isinstance(layer, ErnieDecoderLayer) and hasattr( + layer, "fp8_quant_weight" + ): layer.fp8_quant_weight() def _post_init(self, original_init, *args, **kwargs): @@ -981,12 +1107,17 @@ def set_state_dict(self, state_dict, *args, **kwargs): if k not in self._pipeline_name_mapping: continue state_dict[self._pipeline_name_mapping[k]] = v - missing_keys, mismatch_keys = super().set_state_dict(state_dict, *args, **kwargs) + missing_keys, mismatch_keys = super().set_state_dict( + state_dict, *args, **kwargs + ) missing_shared_keys = self._check_shared_model_state() tmp_missing_keys = [] for key in missing_keys: - if key in missing_shared_keys and missing_shared_keys[key] not in missing_keys: + if ( + key in missing_shared_keys + and missing_shared_keys[key] not in missing_keys + ): continue tmp_missing_keys.append(key) missing_keys = tmp_missing_keys diff --git a/examples/pre-training/models/moe/moe_layer.py b/examples/pre-training/models/moe/moe_layer.py index 498c577eb..4d26a35d9 100644 --- a/examples/pre-training/models/moe/moe_layer.py +++ b/examples/pre-training/models/moe/moe_layer.py @@ -490,6 +490,7 @@ def __init__( p.no_sync = not (self.is_mp_moe or is_dummy_moe) logger.info(f"expert no-sync={p.no_sync}-{p.name}") if self.is_mp_moe or self.is_ep_moe: + p.mp_moe = True p.is_distributed = True expert_color = None @@ -498,11 +499,11 @@ def __init__( fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() ) expert_color = {"color": "moe_expert", "group": moe_grad_group} - elif ( - self.config.offline_quant_expert_weight - and self.config.clear_origin_weight_when_offline_quant - ): - expert_color = {"color": "moe_expert"} + # elif ( + # self.config.offline_quant_expert_weight + # and self.config.clear_origin_weight_when_offline_quant + # ): + # expert_color = {"color": "moe_expert"} if expert_color is not None: for p in self.experts.parameters(): @@ -1133,7 +1134,17 @@ def forward( orig_shape[:-1] + [combined_output.shape[-1]] ) return combined_output, combine_weights, router_loss2, gate_logits - + + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + sharded_state_dict = super().sharded_state_dict(structured_name_prefix) + global_expert_id_offset = self.group.rank * self.num_local_experts + for k,v in sharded_state_dict.items(): + v.global_expert_id_offset = global_expert_id_offset + sharded_state_dict[k] = v + return sharded_state_dict class FP8FusedWLCHFunc(paddle.autograd.PyLayer): @staticmethod diff --git a/examples/pre-training/models/sequence_parallel_utils.py b/examples/pre-training/models/sequence_parallel_utils.py index dfad16368..4a230f857 100644 --- a/examples/pre-training/models/sequence_parallel_utils.py +++ b/examples/pre-training/models/sequence_parallel_utils.py @@ -31,7 +31,7 @@ from 
paddle.incubate.tensor.manipulation import create_async_load from paddle.nn import functional as F from paddle.nn.layer.layers import Layer - +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import build_sharded_state_dict try: from paddle.nn.functional import all_gather_gemm, flux, gemm_reduce_scatter except ImportError: @@ -449,7 +449,14 @@ def forward(self, x, use_comm=True): output = self.linear(input_parallel, self.weight, self.bias) return output - + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict( + state_dict, {"weight": 1}, structured_name_prefix + ) class MPScale(PyLayer): @staticmethod def forward(ctx, x, mp_degree): @@ -588,3 +595,12 @@ def forward(self, x): else: output = self.linear(input_parallel, self.weight, self.bias) return output + + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict( + state_dict, {"weight": 0}, structured_name_prefix + ) \ No newline at end of file
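
Note on the expert renaming introduced by this patch: MOELayer.sharded_state_dict tags every local expert weight with global_expert_id_offset = group.rank * num_local_experts, and the sharded_state_dict overrides in modeling.py, modeling_moe.py and modeling_pp.py then rewrite the digits after "experts." in each key so local expert indices become global ones before the flex-checkpoint save. Below is a minimal, standalone sketch of that renaming pass; FakeShardedWeight and the sample keys are illustrative placeholders, not the real flex-checkpoint sharded-weight type.

import re


class FakeShardedWeight:
    """Illustrative stand-in for a flex-checkpoint sharded weight; only the
    fields touched by the renaming pass are modeled here."""

    def __init__(self, key, global_expert_id_offset=None):
        self.key = key
        if global_expert_id_offset is not None:
            self.global_expert_id_offset = global_expert_id_offset


def increment_expert_number(name, offset):
    # Rewrite the digits right after "experts." so a local expert index
    # becomes a global one, e.g. experts.3 -> experts.11 when offset is 8.
    return re.sub(r"(?<=experts\.)\d+", lambda m: str(int(m.group(0)) + offset), name)


def rename_local_experts(sharded_state_dict):
    renamed = {}
    for key, value in sharded_state_dict.items():
        offset = getattr(value, "global_expert_id_offset", None)
        if offset is None:
            renamed[key] = value  # non-expert weights pass through unchanged
            continue
        new_key = increment_expert_number(key, offset)
        value.key = new_key  # keep the object's own key in sync with the dict key
        delattr(value, "global_expert_id_offset")
        renamed[new_key] = value
    return renamed


if __name__ == "__main__":
    # Rank 1 of an expert-parallel group holding 8 local experts -> offset 8,
    # mirroring offset = group.rank * num_local_experts in MOELayer.
    state = {
        "ernie.layers.2.mlp.experts.3.up_gate_proj.weight": FakeShardedWeight(
            "ernie.layers.2.mlp.experts.3.up_gate_proj.weight", 8
        ),
        "lm_head.weight": FakeShardedWeight("lm_head.weight"),
    }
    for k in rename_local_experts(state):
        print(k)  # -> ...experts.11.up_gate_proj.weight, lm_head.weight unchanged

Non-expert weights are untouched by this pass; their sharding is described only by the axis map passed to build_sharded_state_dict in the new per-layer methods (for example {"weight": 1} vs. {"weight": 0} for the two sequence-parallel linears above, and axis 0 vs. 1 in the LM head depending on tie_word_embeddings), which records the tensor-parallel split dimension of each parameter.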