@@ -151,7 +151,7 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
151151 self .net .load_model (self .pretrained_model , device = self .device )
152152
153153
154- def eval (self , x , batch_size = 8 , resample = None , channels = None , channel_axis = None ,
154+ def eval (self , x , batch_size = 8 , resample = True , channels = None , channel_axis = None ,
155155 z_axis = None , normalize = True , invert = False , rescale = None , diameter = None ,
156156 flow_threshold = 0.4 , cellprob_threshold = 0.0 , do_3D = False , anisotropy = None ,
157157 flow3D_smooth = 0 , stitch_threshold = 0.0 ,
@@ -165,7 +165,6 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
165165 batch_size (int, optional): number of 256x256 patches to run simultaneously on the GPU
166166 (can make smaller or bigger depending on GPU memory usage). Defaults to 64.
167167 resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries).
168- deprecated in v4.0.1+, resample is not used
169168 channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
170169 if None, channels dimension is attempted to be automatically determined. Defaults to None.
171170 z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
@@ -327,7 +326,9 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
327326
328327 if resample :
329328 # upsample flows before computing them:
330- raise NotImplementedError
329+ dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 )
330+ cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
331+
331332
332333 if compute_masks :
333334 niter0 = 200
@@ -343,6 +344,10 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
343344
344345 # undo resizing:
345346 if image_scaling is not None or anisotropy is not None :
347+
348+ dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 ) # works for 2 or 3D:
349+ cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
350+
346351 if do_3D :
347352 if compute_masks :
348353 # Rescale xy then xz:
@@ -351,29 +356,96 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
351356 masks = transforms .resize_image (masks , Ly = Lz_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
352357 masks = masks .transpose (1 , 0 , 2 )
353358
354- # cellprob is the same
355- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
356- cellprob = cellprob .transpose (1 , 0 , 2 )
357- cellprob = transforms .resize_image (cellprob , Ly = Lz_0 , Lx = Lx_0 , no_channels = True )
358- cellprob = cellprob .transpose (1 , 0 , 2 )
359-
360- # dP has gradients that can be treated as channels:
361- dP = dP .transpose (1 , 2 , 3 , 0 ) # move gradients last:
362- dP = transforms .resize_image (dP , Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
363- dP = dP .transpose (1 , 0 , 2 , 3 ) # switch axes to resize again
364- dP = transforms .resize_image (dP , Ly = Lz_0 , Lx = Lx_0 , no_channels = False )
365- dP = dP .transpose (3 , 1 , 0 , 2 ) # undo transposition
366-
367359 else :
368360 # 2D or 3D stitching case:
369361 if compute_masks :
370362 masks = transforms .resize_image (masks , Ly = Ly_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
371- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
372- dP = np .moveaxis (dP , 0 , - 1 ) # Put gradients last
373- dP = transforms .resize_image (dP , Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
374- dP = np .moveaxis (dP , - 1 , 0 ) # Put gradients first
375363
376364 return masks , [plot .dx_to_circ (dP ), dP , cellprob ], styles
365+
366+
367+ def _resize_cellprob (self , prob : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
368+ """
369+ Resize cellprob array to specified dimensions for either 2D or 3D.
370+
371+ Parameters:
372+ prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
373+ to_y_size (int): The target size along the Y-axis.
374+ to_x_size (int): The target size along the X-axis.
375+ to_z_size (int, optional): The target size along the Z-axis. Required
376+ for 3D cellprobs.
377+
378+ Returns:
379+ numpy.ndarray: The resized cellprobs array with the same number of dimensions
380+ as the input.
381+
382+ Raises:
383+ ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
384+ """
385+ prob_shape = prob .shape
386+ prob = prob .squeeze ()
387+ squeeze_happened = prob .shape != prob_shape
388+ prob_shape = np .array (prob_shape )
389+
390+ if prob .ndim == 2 :
391+ # 2D case:
392+ prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
393+ if squeeze_happened :
394+ prob = np .expand_dims (prob , int (np .argwhere (prob_shape == 1 ))) # add back empty axis for compatibility
395+ elif prob .ndim == 3 :
396+ # 3D case:
397+ prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
398+ prob = prob .transpose (1 , 0 , 2 )
399+ prob = transforms .resize_image (prob , Ly = to_z_size , Lx = to_x_size , no_channels = True )
400+ prob = prob .transpose (1 , 0 , 2 )
401+ else :
402+ raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: { prob .shape } ' )
403+
404+ return prob
405+
406+
407+ def _resize_gradients (self , grads : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
408+ """
409+ Resize gradient arrays to specified dimensions for either 2D or 3D gradients.
410+
411+ Parameters:
412+ grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
413+ to_y_size (int): The target size along the Y-axis.
414+ to_x_size (int): The target size along the X-axis.
415+ to_z_size (int, optional): The target size along the Z-axis. Required
416+ for 3D gradients.
417+
418+ Returns:
419+ numpy.ndarray: The resized gradient array with the same number of dimensions
420+ as the input.
421+
422+ Raises:
423+ ValueError: If the input gradient array does not have 3 or 4 dimensions.
424+ """
425+ grads_shape = grads .shape
426+ grads = grads .squeeze ()
427+ squeeze_happened = grads .shape != grads_shape
428+ grads_shape = np .array (grads_shape )
429+
430+ if grads .ndim == 3 :
431+ # 2D case, with XY flows in 2 channels:
432+ grads = np .moveaxis (grads , 0 , - 1 ) # Put gradients last
433+ grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
434+ grads = np .moveaxis (grads , - 1 , 0 ) # Put gradients first
435+
436+ if squeeze_happened :
437+ grads = np .expand_dims (grads , int (np .argwhere (grads_shape == 1 ))) # add back empty axis for compatibility
438+ elif grads .ndim == 4 :
439+ # dP has gradients that can be treated as channels:
440+ grads = grads .transpose (1 , 2 , 3 , 0 ) # move gradients last:
441+ grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
442+ grads = grads .transpose (1 , 0 , 2 , 3 ) # switch axes to resize again
443+ grads = transforms .resize_image (grads , Ly = to_z_size , Lx = to_x_size , no_channels = False )
444+ grads = grads .transpose (3 , 1 , 0 , 2 ) # undo transposition
445+ else :
446+ raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: { grads .shape } ' )
447+
448+ return grads
377449
378450
379451 def _run_net (self , x ,
0 commit comments