Skip to content

Commit 678cd47

Browse files
Set correct dtype when loading and saving hdf5 (#2014)
* fixed load_hdf5 * fixed save_hdf5 * fixed different behavior in tests * test torch dtype for save_hdf5 --------- Co-authored-by: Claudia Comito <[email protected]>
1 parent 4cf0146 commit 678cd47

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

heat/core/io.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def supports_hdf5() -> bool:
518518
def load_hdf5(
519519
path: str,
520520
dataset: str,
521-
dtype: datatype = types.float32,
521+
dtype: Optional[datatype] = None,
522522
slices: Optional[Tuple[Optional[slice], ...]] = None,
523523
split: Optional[int] = None,
524524
device: Optional[str] = None,
@@ -534,7 +534,7 @@ def load_hdf5(
534534
dataset : str
535535
Name of the dataset to be read.
536536
dtype : datatype, optional
537-
Data type of the resulting array.
537+
Data type of the resulting array, defaults to the loaded dataset's type.
538538
slices : tuple of slice objects, optional
539539
Load only the specified slices of the dataset.
540540
split : int or None, optional
@@ -626,8 +626,6 @@ def load_hdf5(
626626
elif split is not None and not isinstance(split, int):
627627
raise TypeError(f"split must be None or int, not {type(split)}")
628628

629-
# infer the type and communicator for the loaded array
630-
dtype = types.canonical_heat_type(dtype)
631629
# determine the comm and device the data will be placed on
632630
device = devices.sanitize_device(device)
633631
comm = sanitize_comm(comm)
@@ -638,6 +636,9 @@ def load_hdf5(
638636
gshape = data.shape
639637
new_gshape = tuple()
640638
offsets = [0] * len(gshape)
639+
if dtype is None:
640+
dtype = data.dtype
641+
dtype = types.canonical_heat_type(dtype)
641642
if slices is not None:
642643
for i in range(len(gshape)):
643644
if i < len(slices) and slices[i]:
@@ -688,7 +689,12 @@ def load_hdf5(
688689
return DNDarray(data, gshape, dtype, split, device, comm, balanced)
689690

690691
def save_hdf5(
691-
data: DNDarray, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]
692+
data: DNDarray,
693+
path: str,
694+
dataset: str,
695+
mode: str = "w",
696+
dtype: Optional[datatype] = None,
697+
**kwargs: Dict[str, object],
692698
):
693699
"""
694700
Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible.
@@ -703,6 +709,8 @@ def save_hdf5(
703709
Name of the dataset the data is saved to.
704710
mode : str, optional
705711
File access mode, one of ``'w', 'a', 'r+'``
712+
dtype : datatype, optional
713+
Data type of the saved data
706714
kwargs : dict, optional
707715
Additional arguments passed to the created dataset.
708716
@@ -733,16 +741,23 @@ def save_hdf5(
733741
is_split = data.split is not None
734742
_, _, slices = data.comm.chunk(data.gshape, data.split if is_split else 0)
735743

744+
if dtype is None:
745+
dtype = data.dtype
746+
elif type(dtype) == torch.dtype:
747+
dtype = str(dtype).split(".")[-1]
748+
if type(dtype) is not str:
749+
dtype = dtype.__name__
750+
736751
# attempt to perform parallel I/O if possible
737752
if h5py.get_config().mpi:
738753
with h5py.File(path, mode, driver="mpio", comm=data.comm.handle) as handle:
739-
dset = handle.create_dataset(dataset, data.shape, **kwargs)
754+
dset = handle.create_dataset(dataset, data.shape, dtype=dtype, **kwargs)
740755
dset[slices] = data.larray.cpu() if is_split else data.larray[slices].cpu()
741756

742757
# otherwise a single rank only write is performed in case of local data (i.e. no split)
743758
elif data.comm.rank == 0:
744759
with h5py.File(path, mode) as handle:
745-
dset = handle.create_dataset(dataset, data.shape, **kwargs)
760+
dset = handle.create_dataset(dataset, data.shape, dtype=dtype, **kwargs)
746761
if is_split:
747762
dset[slices] = data.larray.cpu()
748763
else:
@@ -764,7 +779,7 @@ def save_hdf5(
764779
next_rank = (data.comm.rank + 1) % data.comm.size
765780
data.comm.Isend([None, 0, MPI.INT], dest=next_rank)
766781

767-
DNDarray.save_hdf5 = lambda self, path, dataset, mode="w", **kwargs: save_hdf5(
782+
DNDarray.save_hdf5 = lambda self, path, dataset, mode="w", dtype=None, **kwargs: save_hdf5(
768783
self, path, dataset, mode, **kwargs
769784
)
770785
DNDarray.save_hdf5.__doc__ = save_hdf5.__doc__

heat/core/tests/test_io.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_size_from_slice(self):
117117
def test_load(self):
118118
# HDF5
119119
if ht.io.supports_hdf5():
120-
iris = ht.load(self.HDF5_PATH, dataset="data")
120+
iris = ht.load(self.HDF5_PATH, dataset="data", dtype=ht.float32)
121121
self.assertIsInstance(iris, ht.DNDarray)
122122
# shape invariant
123123
self.assertEqual(iris.shape, self.IRIS.shape)
@@ -602,7 +602,7 @@ def test_load_hdf5(self):
602602
self.skipTest("Requires HDF5")
603603

604604
# default parameters
605-
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET)
605+
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, dtype=ht.float32)
606606
self.assertIsInstance(iris, ht.DNDarray)
607607
self.assertEqual(iris.shape, self.IRIS.shape)
608608
self.assertEqual(iris.dtype, ht.float32)
@@ -613,13 +613,13 @@ def test_load_hdf5(self):
613613
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, split=0)
614614
self.assertIsInstance(iris, ht.DNDarray)
615615
self.assertEqual(iris.shape, self.IRIS.shape)
616-
self.assertEqual(iris.dtype, ht.float32)
616+
self.assertEqual(iris.dtype, ht.float64)
617617
lshape = iris.lshape
618618
self.assertLessEqual(lshape[0], self.IRIS.shape[0])
619619
self.assertEqual(lshape[1], self.IRIS.shape[1])
620620

621621
# negative split axis
622-
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, split=-1)
622+
iris = ht.load_hdf5(self.HDF5_PATH, self.HDF5_DATASET, split=-1, dtype=ht.float32)
623623
self.assertIsInstance(iris, ht.DNDarray)
624624
self.assertEqual(iris.shape, self.IRIS.shape)
625625
self.assertEqual(iris.dtype, ht.float32)
@@ -661,7 +661,7 @@ def test_save_hdf5(self):
661661
# local unsplit data
662662
local_data = ht.arange(100)
663663
ht.save_hdf5(
664-
local_data, self.HDF5_OUT_PATH, self.HDF5_DATASET, dtype=local_data.dtype.char()
664+
local_data, self.HDF5_OUT_PATH, self.HDF5_DATASET, dtype=torch.int32
665665
)
666666
if local_data.comm.rank == 0:
667667
with ht.io.h5py.File(self.HDF5_OUT_PATH, "r") as handle:
@@ -673,7 +673,7 @@ def test_save_hdf5(self):
673673
# distributed data range
674674
split_data = ht.arange(100, split=0)
675675
ht.save_hdf5(
676-
split_data, self.HDF5_OUT_PATH, self.HDF5_DATASET, dtype=split_data.dtype.char()
676+
split_data, self.HDF5_OUT_PATH, self.HDF5_DATASET
677677
)
678678
if split_data.comm.rank == 0:
679679
with ht.io.h5py.File(self.HDF5_OUT_PATH, "r") as handle:

0 commit comments

Comments
 (0)