 import lr
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
 
 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
     "gpt-cn": (GPTForPretraining, GPTChineseTokenizer),
 }
 
 
-def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
+def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
     assert args.device != "cpu"
 
-    random.seed(basic_seed + dp_rank)
-    np.random.seed(basic_seed + dp_rank)
-    paddle.seed(basic_seed + dp_rank)
+    random.seed(basic_seed + data_world_rank)
+    np.random.seed(basic_seed + data_world_rank)
+    paddle.seed(basic_seed + data_world_rank)
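+    # NOTE: seeding with the combined data-parallel rank (dp x sharding) gives each
+    # data shard its own random stream, while mp/pp ranks of the same shard stay in sync.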
 
     # local_seed / global_seed is used to control dropout in ModelParallel
     local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
-    global_seed = basic_seed + dp_rank
+    global_seed = basic_seed + data_world_rank
     tracker = get_rng_state_tracker()
     tracker.add('global_seed', global_seed)
     tracker.add('local_seed', local_seed)
@@ -92,14 +93,18 @@ def do_train(args):
     strategy.hybrid_configs = {
         "dp_degree": args.dp_degree,
         "mp_degree": args.mp_degree,
-        "pp_degree": args.pp_degree
+        "pp_degree": args.pp_degree,
+        "sharding_degree": args.sharding_degree
     }
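+    # NOTE: dp_degree * mp_degree * pp_degree * sharding_degree is expected to
+    # equal the total number of ranks launched for training.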
 
     strategy.pipeline_configs = {
         "accumulate_steps": args.local_batch_size // args.micro_batch_size,
         "micro_batch_size": args.micro_batch_size
     }
 
+    # seed control in tensor parallel
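+    # (tensor_init_seed fixes the seed used to initialize tensor-parallel parameters,
+    # keeping weight initialization reproducible across runs)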
+    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}
+
     fleet.init(is_collective=True, strategy=strategy)
 
     # obtain rank information of hybrid parallel
@@ -108,10 +113,15 @@ def do_train(args):
     mp_rank = hcg.get_model_parallel_rank()
     pp_rank = hcg.get_stage_id()
     dp_rank = hcg.get_data_parallel_rank()
+    sharding_rank = hcg.get_sharding_parallel_rank()
+
+    sharding_size = hcg.get_sharding_parallel_world_size()
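+    # dp and sharding groups both consume distinct shards of the data, so the
+    # effective data-parallel rank/size below span dp_degree * sharding_degree ranks.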
+    data_world_rank = dp_rank * sharding_size + sharding_rank
+    data_world_size = args.dp_degree * args.sharding_degree
     local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
 
     # seed control in hybrid parallel
-    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)
+    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)
 
     default_global_tokens_num = args.global_batch_size * args.max_seq_len
 
@@ -183,15 +193,31 @@ def do_train(args):
         if not any(nd in n for nd in ["bias", "norm"])
     ]
 
-    optimizer = paddle.optimizer.AdamW(
-        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
-        beta1=args.adam_beta1,
-        beta2=args.adam_beta2,
-        epsilon=args.adam_epsilon,
-        parameters=model.parameters(),
-        weight_decay=args.weight_decay,
-        grad_clip=clip,
-        apply_decay_param_fun=lambda x: x in decay_params)
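+    # When sharding is enabled, DygraphShardingOptimizer partitions the optimizer
+    # states across the sharding group and delegates the update to the inner AdamW.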
+    if args.sharding_degree > 1:
+        optimizer = DygraphShardingOptimizer(
+            hcg=fleet.get_hybrid_communicate_group(),
+            user_defined_strategy=strategy,
+            params=model.parameters(),
+            inner_optimizer_class=paddle.optimizer.AdamW,
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)
+    else:
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            parameters=model.parameters(),
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)
 
     if paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
@@ -227,8 +253,8 @@ def do_train(args):
                 args,
                 data_file,
                 local_rank=local_rank,
-                data_world_size=args.dp_degree,
-                data_world_rank=dp_rank,
+                data_world_size=data_world_size,
+                data_world_rank=data_world_rank,
                 eos_id=tokenizer.eos_token_id)
             # Bug fix: if valid_data_loader is not called here, the enumerate below will
             # call valid_data_loader many times and start a new random dataloader.
@@ -309,6 +335,7 @@ def do_train(args):
                                  args.eval_iters, log_writer, global_step,
                                  epoch, "valid")
 
+                # TODO: 1. merge parameters when saving the model. 2. ensure the model is saved and loaded correctly.
                 # only ranks with dp_rank == 0 save the model
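+                # With sharding/mp/pp enabled, every such rank writes its own optimizer
+                # shard, so the filenames below encode mp_rank, sharding_rank (and pp_rank).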
                 if (global_step % args.save_steps == 0 or
                         global_step >= args.max_steps) and dp_rank == 0:
@@ -322,24 +349,25 @@ def do_train(args):
                     logger.info("Save model to %s" % output_dir)
 
                     if args.pp_degree > 1:
-                        model_to_save.save_state_dict(output_dir)
-                        if mp_rank * pp_rank == 1:
+                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                             tokenizer.save_pretrained(output_dir)
+                        model_to_save.save_state_dict(output_dir)
                         paddle.save(
                             optimizer.state_dict(),
                             os.path.join(
                                 output_dir,
-                                "model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
-                                format(mp_rank, pp_rank)))
+                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
+                                format(mp_rank, sharding_rank, pp_rank)))
                     else:
-                        path = os.path.join(output_dir,
-                                            'model_{:0>2d}'.format(mp_rank))
-                        os.makedirs(path, exist_ok=True)
-                        model_to_save.save_pretrained(path)
-
-                        paddle.save(optimizer.state_dict(),
-                                    os.path.join(path, "model_state.pdopt"))
-                        tokenizer.save_pretrained(path)
+                        if mp_rank == 0 and sharding_rank == 0:
+                            tokenizer.save_pretrained(output_dir)
+                        model_to_save.save_pretrained(output_dir)
+                        paddle.save(
+                            optimizer.state_dict(),
+                            os.path.join(
+                                output_dir,
+                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
+                                format(mp_rank, sharding_rank)))
 
                 if global_step >= args.max_steps:
                     run_evaluate(args, test_data_loader, model, criterion,