From d875317f4df64005173b0a2a055be1d7ea0813e6 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 18 Jun 2025 14:49:40 +0000 Subject: [PATCH] Simplify and add README Signed-off-by: Olatunji Ruwase --- deepnvme/model_checkpoint/README | 25 +++++++++++++++++ .../model_checkpoint/deepspeed_save_model.py | 28 +++++++++++++------ deepnvme/model_checkpoint/save_model_utils.py | 4 +-- deepnvme/model_checkpoint/torch_save_model.py | 9 ++---- .../model_checkpoint/torch_save_tensor.py | 16 +++-------- deepnvme/model_checkpoint/torch_save_utils.py | 14 +++++++--- 6 files changed, 63 insertions(+), 33 deletions(-) create mode 100644 deepnvme/model_checkpoint/README diff --git a/deepnvme/model_checkpoint/README b/deepnvme/model_checkpoint/README new file mode 100644 index 000000000..3d4c95e9d --- /dev/null +++ b/deepnvme/model_checkpoint/README @@ -0,0 +1,25 @@ +[FastPersist](https://arxiv.org/abs/2406.13768) is an optimization technique that leverages NVMe storage to accelerate model checkpointing. This folder contains micro-benchmarks and instructions for demonstrating FastPersist. + +## Enabling FastPersist Optimizations ## +FastPersist is designed to integrate with torch checkpointing and has been validated with torch version 2.6.0. This requires slight modifications to torch serialization, and for convenience we provide [original](torch/serialization_orig_v2.6.0.py) and [patched](torch/serialization_fast_v2.6.0.py) versions of serialization.py. Thus, to demonstrate FastPersist performance you need to overwrite `torch/serialization.py` in your torch installation with the patched version. + +## Available Micro-benchmarks ## +This folder contains three different micro-benchmarks that are implemented by the following scripts: +1. torch_save_tensor.py: Serialize a raw pytorch tensor to disk using `torch.save()` integration. +2. torch_save_model.py: Serialize a HF model to disk using `torch.save()` integration. +3. deepspeed_save_model.py: Serialize a HF model to disk using DeepSped integration. + +Each script provides a `--help` option to examine the available configurations. The scripts are written for single-process execution and so can be launched using `python`. + +As an example, the performance of using the `torch.save()` integration of checkpointing HF phi-3-mini model from GPU memory can be measured as follows: +``` +python torch_save_model.py --model phi3 --folder /mnt/nvme0 --gpu +``` + +The script executes and reports the performance of the checkpointing workload using different mechanisms including vanilla `torch.save()`, FastPersist with CPU bounce buffer, FastPersist with NVIDIA GDS, etc. You can find the respective performance by searching the generated log for lines similar to the following snippet. For this example, the results below, collected using eight PCI Gen4 NVMes RAID-0 (data-striped), show checkpointing throughputs of 0.69GB/sec and 17.75GB/sec for vanilla `torch.save()` (labelled test_save) and FastPersist with CPU bounce buffer (labelled test_ds_aio_fast_save) respectively. + +```bash +test_save -- 14.23 GB, 20.72 secs, 0.69 GB/s +test_ds_aio_fast_save -- 14.23 GB, 0.80 secs, 17.75 GB/s +``` + diff --git a/deepnvme/model_checkpoint/deepspeed_save_model.py b/deepnvme/model_checkpoint/deepspeed_save_model.py index ea97dd717..bafbe62c7 100644 --- a/deepnvme/model_checkpoint/deepspeed_save_model.py +++ b/deepnvme/model_checkpoint/deepspeed_save_model.py @@ -8,6 +8,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from save_model_utils import get_model, validate_arguments, parse_arguments +from torch_save_utils import load_io_ops def _get_ds_config(args, writer_type, use_gds): ds_config = { @@ -26,7 +27,7 @@ def _get_ds_config(args, writer_type, use_gds): } }, "checkpoint": { - "checkpoint_serialization": not args.legacy + "checkpoint_serialization": args.zipfile }, "aio": { "block_size": 8 * (1024**2), @@ -64,11 +65,9 @@ def _do_optimizer_step(ds_engine): def _free_ds_memory(ds_engine): - ds_engine.optimizer.optimizer = None - ds_engine.optimizer = None - ds_engine.module = None - ds_engine = None + ds_engine.destroy() del ds_engine + ds_engine = None gc.collect() get_accelerator().empty_cache() @@ -80,9 +79,11 @@ def test_save(tag, folder, model, args, writer_type): if args.zero_stage == 0: _do_optimizer_step(ds_engine) + import pdb; pdb.set_trace() st = time.time() ds_engine.save_checkpoint(save_dir=folder, tag=tag) write_sec = time.time() - st + import pdb; pdb.set_trace() _free_ds_memory(ds_engine) return write_sec @@ -107,8 +108,6 @@ def run(model, model_name, ckpt_name, args): folder = os.path.join(args.folder, ckpt_name, tag) if os.path.exists(folder): shutil.rmtree(folder, ignore_errors=True) - # if not os.path.exists(folder): - # os.makedirs(folder, exist_ok=True) write_sec = test_save(tag, folder, model, args, writer_type) ckpt_size = _get_folder_size(folder) gb_size = ckpt_size / (1024**3) @@ -118,19 +117,32 @@ def run(model, model_name, ckpt_name, args): ) print(f'*********************************************') +def init_torch_distributed(): + import torch.distributed as dist + from deepspeed.constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE + os.environ['MASTER_PORT'] = str(TORCH_DISTRIBUTED_DEFAULT_PORT) + os.environ['MASTER_ADDR'] = "localhost" + os.environ['LOCAL_RANK'] = str(0) + os.environ['WORLD_SIZE'] = str(1) + os.environ['CROSS_RANK'] = str(0) + os.environ['CROSS_SIZE'] = str(1) + dist.init_process_group(backend='nccl', rank=0, world_size=1) + + def main(): print( f'Performance test of deepspeed integration of fast model checkpointing.' ) print(f'torch version = {torch.__version__}') + init_torch_distributed() torch.manual_seed(42) np.random.seed(0) random.seed(0) args = parse_arguments() if not validate_arguments(args): quit() - + load_io_ops(args) model, model_name, ckpt_name = get_model(args.model) run(model, model_name, ckpt_name, args) diff --git a/deepnvme/model_checkpoint/save_model_utils.py b/deepnvme/model_checkpoint/save_model_utils.py index faf4fc5d8..be5c4d5bc 100644 --- a/deepnvme/model_checkpoint/save_model_utils.py +++ b/deepnvme/model_checkpoint/save_model_utils.py @@ -67,9 +67,9 @@ def parse_arguments(): default=0, help='Local rank' ) - parser.add_argument('--legacy', + parser.add_argument('--zipfile', action='store_true', - help='Use torch legacy save format') + help='Use torch zipfile save format') parser.add_argument('--optimizer', action='store_true', diff --git a/deepnvme/model_checkpoint/torch_save_model.py b/deepnvme/model_checkpoint/torch_save_model.py index f37d122be..9ac855ca6 100644 --- a/deepnvme/model_checkpoint/torch_save_model.py +++ b/deepnvme/model_checkpoint/torch_save_model.py @@ -2,11 +2,10 @@ import torch from torch.optim import Adam import os -from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save +from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save, load_io_ops from save_model_utils import get_model, validate_arguments, parse_arguments import deepspeed from deepspeed.accelerator import get_accelerator -import deepspeed.comm as dist def run(model, model_name, ckpt_name, args): @@ -23,8 +22,6 @@ def run(model, model_name, ckpt_name, args): continue file = os.path.join(args.folder, f'{tag}_{ckpt_name}.pt') print(f'checkpoint file = {file}') - if os.path.isfile(file): - os.remove(file) st = time.time() write_sec = fn(file, model, args) ckpt_size = os.path.getsize(file) @@ -59,8 +56,7 @@ def main(): args = parse_arguments() if not validate_arguments(args): quit() - - deepspeed.init_distributed() + load_io_ops(args) model, model_name, ckpt_name = get_model(args.model) if args.half: model = model.half() @@ -72,7 +68,6 @@ def main(): else: ckpt_state = {'model': model} run(ckpt_state, model_name, ckpt_name, args) - dist.destroy_process_group() if __name__ == "__main__": diff --git a/deepnvme/model_checkpoint/torch_save_tensor.py b/deepnvme/model_checkpoint/torch_save_tensor.py index 4c73a3b2a..55e5e4544 100644 --- a/deepnvme/model_checkpoint/torch_save_tensor.py +++ b/deepnvme/model_checkpoint/torch_save_tensor.py @@ -2,11 +2,10 @@ import argparse import torch import os -from torch_save_utils import PINNED_BUFFER_MB +from torch_save_utils import PINNED_BUFFER_MB, load_io_ops from torch_save_utils import test_save, test_ds_mock_save, test_ds_py_save, test_ds_aio_fast_save, test_ds_gds_fast_save import deepspeed from deepspeed.accelerator import get_accelerator -import deepspeed.comm as dist import os def run(args): @@ -28,8 +27,6 @@ def run(args): continue file = os.path.join(args.folder, f'{tag}_{args.mb_size}MB.pt') print(f'checkpoint file = {file}') - if os.path.isfile(file): - os.remove(file) st = time.time() write_sec = fn(file, buffer, args) gb_per_sec = args.mb_size / (1024.0 * write_sec) @@ -53,9 +50,9 @@ def parse_arguments(): default=None, required=True, help='Size of tensor to save in MB.') - parser.add_argument('--legacy', + parser.add_argument('--zipfile', action='store_true', - help='Use torch legacy save format') + help='Use torch zipfile save format') parser.add_argument('--gpu', action='store_true', help='Use gpu tensors.') @@ -71,10 +68,6 @@ def parse_arguments(): parser.add_argument('--single_io_buffer', action='store_true', help='Disable double buffering of i/o buffer.') - parser.add_argument('--local_rank', - type=int, - default=0, - help='Local rank' ) args = parser.parse_args() print(f'args = {args}') @@ -89,9 +82,8 @@ def main(): if not os.path.exists(args.folder): print(f'Invalid folder: {args.folder}') quit() - deepspeed.init_distributed() + load_io_ops(args) run(args) - dist.destroy_process_group() if __name__ == "__main__": diff --git a/deepnvme/model_checkpoint/torch_save_utils.py b/deepnvme/model_checkpoint/torch_save_utils.py index cf5f2bba5..56498da0d 100644 --- a/deepnvme/model_checkpoint/torch_save_utils.py +++ b/deepnvme/model_checkpoint/torch_save_utils.py @@ -13,6 +13,12 @@ AIO_OVERLAP_EVENTS = False PINNED_BUFFER_MB = 64 +def load_io_ops(args): + if AsyncIOBuilder().is_compatible(): + AsyncIOBuilder().load(verbose=False) + if args.gpu and GDSBuilder().is_compatible(): + GDSBuilder().load(verbose=False) + def _get_aio_handle(): h = AsyncIOBuilder().load(verbose=False).aio_handle(block_size=AIO_BLOCK_SIZE, @@ -34,7 +40,7 @@ def test_save(file, buffer, args): st = time.time() torch.save(f=file, obj=buffer, - _use_new_zipfile_serialization=not args.legacy) + _use_new_zipfile_serialization=args.zipfile) return time.time() - st @@ -43,7 +49,7 @@ def test_ds_mock_save(file, buffer, args): ds_mock_writer = MockFileWriter(file) torch.save(f=ds_mock_writer, obj=buffer, - _use_new_zipfile_serialization=not args.legacy) + _use_new_zipfile_serialization=args.zipfile) ds_mock_writer.close() # Force flush to storage write_sec = time.time() - st if not args.no_statistics: @@ -56,7 +62,7 @@ def test_ds_py_save(file, buffer, args): ds_py_writer = PyFileWriter(file) torch.save(f=ds_py_writer, obj=buffer, - _use_new_zipfile_serialization=not args.legacy) + _use_new_zipfile_serialization=args.zipfile) ds_py_writer.close() # Force flush to storage write_sec = time.time() - st if not args.no_statistics: @@ -96,7 +102,7 @@ def _test_ds_fast_save(file, buffer, args, use_gds): config=fast_writer_config) torch.save(f=ds_fast_writer, obj=buffer, - _use_new_zipfile_serialization=not args.legacy) + _use_new_zipfile_serialization=args.zipfile) ds_fast_writer.close() # Force flush to storage write_sec = time.time() - st if not args.no_statistics: