@@ -724,9 +724,30 @@ def load_from_hdf5(self, file, index, task=None, func=None):
724724 if task is None :
725725 task = self .name
726726 dset = file ['tasks' ][task ]
727- if not np .all (dset .attrs ['grid_space' ]):
728- raise ValueError ("Can only load data from grid space" )
729- self .load_from_global_grid_data (dset , pre_slices = (index ,), func = func )
727+ if np .all (dset .attrs ['grid_space' ]):
728+ self .load_from_global_grid_data (dset , pre_slices = (index ,), func = func )
729+ elif np .all (~ dset .attrs ['grid_space' ]):
730+ self .load_from_global_coeff_data (dset , pre_slices = (index ,), func = func )
731+ else :
732+ raise ValueError ("Can only load global data from pure grid or coeff space" )
733+
734+ def load_from_global_coeff_data (self , global_data , pre_slices = tuple (), func = None ):
735+ """Load local coeff data from array-like global coeff data."""
736+ dim = self .dist .dim
737+ layout = self .dist .coeff_layout
738+ # Check shapes
739+ data_shape = global_data .shape [- dim :]
740+ self_shape = layout .global_shape (self .domain , scales = 1 )
741+ if data_shape != self_shape :
742+ raise ValueError ("Cannot change global shape when loading coeff data." )
743+ # Extract local data from global data
744+ component_slices = tuple (slice (None ) for cs in self .tensorsig )
745+ spatial_slices = layout .slices (self .domain , scales = 1 )
746+ local_slices = pre_slices + component_slices + spatial_slices
747+ if func is None :
748+ self [layout ] = global_data [local_slices ]
749+ else :
750+ self [layout ] = func (global_data [local_slices ])
730751
731752 def load_from_global_grid_data (self , global_data , pre_slices = tuple (), func = None ):
732753 """Load local grid data from array-like global grid data."""
0 commit comments