@@ -518,7 +518,7 @@ def supports_hdf5() -> bool:
518518 def load_hdf5 (
519519 path : str ,
520520 dataset : str ,
521- dtype : datatype = types . float32 ,
521+ dtype : Optional [ datatype ] = None ,
522522 slices : Optional [Tuple [Optional [slice ], ...]] = None ,
523523 split : Optional [int ] = None ,
524524 device : Optional [str ] = None ,
@@ -534,7 +534,7 @@ def load_hdf5(
534534 dataset : str
535535 Name of the dataset to be read.
536536 dtype : datatype, optional
537- Data type of the resulting array.
537+ Data type of the resulting array, defaults to the loaded dataset's type.
538538 slices : tuple of slice objects, optional
539539 Load only the specified slices of the dataset.
540540 split : int or None, optional
@@ -626,8 +626,6 @@ def load_hdf5(
626626 elif split is not None and not isinstance (split , int ):
627627 raise TypeError (f"split must be None or int, not { type (split )} " )
628628
629- # infer the type and communicator for the loaded array
630- dtype = types .canonical_heat_type (dtype )
631629 # determine the comm and device the data will be placed on
632630 device = devices .sanitize_device (device )
633631 comm = sanitize_comm (comm )
@@ -638,6 +636,9 @@ def load_hdf5(
638636 gshape = data .shape
639637 new_gshape = tuple ()
640638 offsets = [0 ] * len (gshape )
639+ if dtype is None :
640+ dtype = data .dtype
641+ dtype = types .canonical_heat_type (dtype )
641642 if slices is not None :
642643 for i in range (len (gshape )):
643644 if i < len (slices ) and slices [i ]:
@@ -688,7 +689,12 @@ def load_hdf5(
688689 return DNDarray (data , gshape , dtype , split , device , comm , balanced )
689690
690691 def save_hdf5 (
691- data : DNDarray , path : str , dataset : str , mode : str = "w" , ** kwargs : Dict [str , object ]
692+ data : DNDarray ,
693+ path : str ,
694+ dataset : str ,
695+ mode : str = "w" ,
696+ dtype : Optional [datatype ] = None ,
697+ ** kwargs : Dict [str , object ],
692698 ):
693699 """
694700 Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible.
@@ -703,6 +709,8 @@ def save_hdf5(
703709 Name of the dataset the data is saved to.
704710 mode : str, optional
705711 File access mode, one of ``'w', 'a', 'r+'``
712+ dtype : datatype, optional
713+ Data type of the saved data.
706714 kwargs : dict, optional
707715 Additional arguments passed to the created dataset.
708716
@@ -733,16 +741,23 @@ def save_hdf5(
733741 is_split = data .split is not None
734742 _ , _ , slices = data .comm .chunk (data .gshape , data .split if is_split else 0 )
735743
744+ if dtype is None :
745+ dtype = data .dtype
746+ elif type (dtype ) == torch .dtype :
747+ dtype = str (dtype ).split ("." )[- 1 ]
748+ if type (dtype ) is not str :
749+ dtype = dtype .__name__
750+
736751 # attempt to perform parallel I/O if possible
737752 if h5py .get_config ().mpi :
738753 with h5py .File (path , mode , driver = "mpio" , comm = data .comm .handle ) as handle :
739- dset = handle .create_dataset (dataset , data .shape , ** kwargs )
754+ dset = handle .create_dataset (dataset , data .shape , dtype = dtype , ** kwargs )
740755 dset [slices ] = data .larray .cpu () if is_split else data .larray [slices ].cpu ()
741756
742757 # otherwise a single rank only write is performed in case of local data (i.e. no split)
743758 elif data .comm .rank == 0 :
744759 with h5py .File (path , mode ) as handle :
745- dset = handle .create_dataset (dataset , data .shape , ** kwargs )
760+ dset = handle .create_dataset (dataset , data .shape , dtype = dtype , ** kwargs )
746761 if is_split :
747762 dset [slices ] = data .larray .cpu ()
748763 else :
@@ -764,7 +779,7 @@ def save_hdf5(
764779 next_rank = (data .comm .rank + 1 ) % data .comm .size
765780 data .comm .Isend ([None , 0 , MPI .INT ], dest = next_rank )
766781
767- DNDarray .save_hdf5 = lambda self , path , dataset , mode = "w" , ** kwargs : save_hdf5 (
782+ DNDarray .save_hdf5 = lambda self , path , dataset , mode = "w" , dtype = None , ** kwargs : save_hdf5 (
768- self , path , dataset , mode , ** kwargs
783+ self , path , dataset , mode , dtype , ** kwargs
769784 )
770785 DNDarray .save_hdf5 .__doc__ = save_hdf5 .__doc__
0 commit comments