
 from colossalai.interface import ModelWrapper
 from colossalai.utils import get_non_persistent_buffers_set
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from contextlib import contextmanager

 from .index_file import CheckpointIndexFile
 from .utils import (
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
 SHARD_META_SUFFIX = ".index.json"
+UNSHARD_META_SUFFIX = ".json"


-def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        destination[prefix + name] = param
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            destination[prefix + name] = buffer
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        destination[extra_state_key] = extra_state
-    return destination
-
-
-def load_state_dict_into_dist_model(
-    model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
-):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        with torch.no_grad():
-            param.copy_(state_dict[prefix + name])
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            with torch.no_grad():
-                buf.copy_(state_dict[prefix + name])
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        with torch.no_grad():
-            extra_state.copy_(state_dict[extra_state_key])
-    return destination
+@contextmanager
+def RestoreDefaultStateDictBehavior(model):
+    original_methods = {}
+    for name, module in model.named_modules():
+        if isinstance(module, ParallelModule):
+            original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
+            module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
+            module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
+    try:
+        yield model
+    finally:
+        for module, original_method in original_methods.items():
+            module._save_to_state_dict, module._load_from_state_dict = original_method
+


 def create_model_metadata(
-    model: nn.Module,
+    model: ModelWrapper,
     prefix: str = "",
-    tp_size=None,
-    tp_rank=None,
+    tp_size: int = None,
+    tp_rank: int = None,
+    zero_size: int = None,
+    zero_rank: int = None,
 ):
     param_origin_shape = model.param_origin_shape
     model = model.unwrap()
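
Note on the new context manager: ParallelModule customizes _save_to_state_dict/_load_from_state_dict to gather and re-split parameter shards along the tensor-parallel group (see the comment removed from save_dist_unshard_model below). RestoreDefaultStateDictBehavior temporarily rebinds the plain nn.Module hooks so that state_dict() and load_state_dict() operate only on each rank's local shards. A minimal usage sketch, not part of this diff; the wrapped-model variable is hypothetical:

    # hypothetical caller: `wrapped` is a ModelWrapper, e.g. a hybrid-parallel wrapped model
    module = wrapped.unwrap()
    with RestoreDefaultStateDictBehavior(module):
        local_state_dict = module.state_dict()  # per-rank shards, no TP gather
    with RestoreDefaultStateDictBehavior(module):
        module.load_state_dict(local_state_dict, strict=False)
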
@@ -105,7 +72,7 @@ def create_model_metadata(
         tp_partition_dim = search_tp_partition_dim(
             current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
         )
-        model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+        model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
         model_metadata[prefix + name]["lengths"] = list(param.shape)
         model_metadata[prefix + name]["global_shape"] = list(original_shape)
         if tp_partition_dim is not None:
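
With offsets built as a plain Python list instead of a torch tensor, each metadata entry is directly JSON-serializable for the *-meta-dist-* files written by save_metadata. An illustrative entry with made-up shapes, assuming tp_size=2 sharding along dim 0 (the branch that fills in non-zero offsets lies outside this hunk):

    # hypothetical entry for a weight of global shape [1024, 4096] sharded along dim 0
    model_metadata["module.linear.weight"] = {
        "offsets": [0, 0],            # presumably adjusted to [512, 0] on tp_rank 1 when tp_partition_dim == 0
        "lengths": [512, 4096],       # local shard shape on this rank
        "global_shape": [1024, 4096],
    }
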
@@ -257,119 +224,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
     return False


-def dist_model_sharder(
-    model: nn.Module,
-    prefix: str = "",
-    keep_vars: bool = False,
-    size_per_shard: int = 1024,
-    pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-) -> Iterator[Tuple[OrderedDict, int]]:
-    # An internal method that breaks state_dict of model into shards within limited size.
-
-    state_dict_sharder = StateDictSharder(size_per_shard)
-
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        if pinned_state_dicts is not None:
-            if (prefix + name) not in pinned_state_dicts:
-                pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
-            pinned_state_dicts[prefix + name].copy_(param)
-            param = pinned_state_dicts[prefix + name]
-        block, block_size = state_dict_sharder.append_param(prefix + name, param)
-        if block is not None:
-            yield block, block_size
-
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            if pinned_state_dicts is not None:
-                if (prefix + name) not in pinned_state_dicts:
-                    pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                pinned_state_dicts[prefix + name].copy_(buffer)
-                buffer = pinned_state_dicts[prefix + name]
-            block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
-            if block is not None:
-                yield block, block_size
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        if pinned_state_dicts is not None:
-            if extra_state_key not in pinned_state_dicts:
-                pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-            pinned_state_dicts[extra_state_key].copy_(extra_state)
-            extra_state = pinned_state_dicts[extra_state_key]
-        block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
-        if block is not None:
-            yield block, block_size
-
-    # Return the last block in sharder.
-    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
-
-def save_dist_unshard_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    use_safetensors: bool,
-    use_async: bool = False,
-    dist_id=0,
-    pinned_state_dicts=None,
-):
-    """
-    Save model state dict to a single file with given checkpointing path.
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
-        gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    # The logic of collecting parameter shards along tp degree
-    # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-    state_dict = dist_model_state_dict(model)
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
-    if use_async:
-        file_name = file_name.replace(".bin", ".safetensors")
-    checkpoint_file = os.path.join(checkpoint, file_name)
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
-    save_metadata(model_metadata, metadata_file, file_name)
-
-    if use_async:
-        from colossalai.utils.safetensors import save
-
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-        for name, param in state_dict.items():
-            pinned_state_dicts[id(model)][name].copy_(param)
-            state_dict[name] = pinned_state_dicts[id(model)][name]
-        writer = save(path=checkpoint_file, state_dict=state_dict)
-        return writer
-    else:
-        save_state_dict(state_dict, checkpoint_file, use_safetensors)
-        return None
-
-
 def load_dist_model(
-    model: ModelWrapper,
     model_metadata: Dict,
     checkpoint: str,
-    low_cpu_mem_mode: bool = True,
-    num_threads: int = 1,
 ):
     """
     Load model from a single file with the given path of checkpoint.
@@ -380,10 +237,6 @@ def load_dist_model(
         strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                  This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
     """
-
-    model_before_wrapping = model
-    model = model.unwrap()
-
     metadata_loaded = load_metadata(checkpoint)

     load_files = {}
@@ -420,92 +273,14 @@ def load_dist_model(
         )
         state_dict[key] = state

-    if not low_cpu_mem_mode:
-        state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-
-    load_state_dict_into_dist_model(model=model, state_dict=state_dict)
-
-    # Update master params if mixed-precision training is enabled.
-    model_before_wrapping.update_master_params()
-
+    return state_dict

-def save_dist_sharded_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    prefix: Optional[str] = None,
-    size_per_shard: int = 1024,
-    use_safetensors: bool = False,
-    use_async: bool = False,
-    dist_id: int = 0,
-    pinned_state_dicts=None,
-) -> None:
-    """
-    Save sharded model checkpoint under the given checkpointing path.
-    The following files will be created under the path:
-    - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
-    - Multiple files that store state tensors of models.
-      If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
-      If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
-
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a directory path.
-        gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
-        prefix (str, optional): Prefix of file to save. Defaults to None.
-        size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    if os.path.isfile(checkpoint):
-        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
-        return
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    # Devices along the same dp_group share the same copies of model.
-    # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.
-
-    if use_async:
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = {}
-        pinned_state_dicts = pinned_state_dicts[id(model)]
-    else:
-        pinned_state_dicts = None
-    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
-    weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
-    index_file = CheckpointIndexFile(checkpoint)
-
-    # Manage filenames of sharded weights and index file for each pipeline stage.
+def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    async_writers = []
-    if use_async:
-        total_size, writers = async_save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            state_preprocess=False,
-        )
-        async_writers.extend(writers)
-    else:
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-            use_pp_format=True,
-        )
-    for k, _ in model_metadata.items():
-        model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+    return weights_name

-    save_metadata(model_metadata, metadata_file, total_size=total_size)
-    return async_writers
+def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
+    if use_safetensors:
+        return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
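
Taken together, the replacement helpers only compute file names; the sharded save/load plumbing now lives in the callers. A hedged sketch of how they might compose (checkpoint_dir, model_metadata, `wrapped`, and the surrounding checkpoint-IO code are assumptions, not shown in this diff):

    # hypothetical caller-side composition for dist_id=3 with safetensors shards
    weights_name, _ = get_model_base_filenames(None, use_safetensors=True)
    shard_name = get_dist_files_name(weights_name, dist_id=3)
    # e.g. "<base>-dist-00003-shard.safetensors"; the base depends on get_model_base_filenames
    meta_file = get_dist_meta_file_name(checkpoint_dir, dist_id=3, use_safetensors=True)
    # -> os.path.join(checkpoint_dir, "pytorch_model-meta-dist-00003.index.json")

    # load_dist_model now returns a state dict instead of mutating the model;
    # presumably the caller applies it under the restored default state-dict behavior
    state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint_dir)
    with RestoreDefaultStateDictBehavior(wrapped.unwrap()):
        wrapped.unwrap().load_state_dict(state_dict, strict=False)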