1- # type: ignore
2- # ruff: noqa
3-
41# Random selected images for final result
52# Training: [tensor([28]), tensor([46]), tensor([60]), tensor([63]), tensor([90])]
63# Validation: [tensor([10]), tensor([18]), tensor([35]), tensor([57]), tensor([79]]]
1411# mask against GT if available, Difference of Mask with GT if available.
1512
1613from collections import OrderedDict
17- from typing import Any
14+ from typing import Any , Protocol
1815
1916import torch
20-
2117import wandb
18+
2219from tipi .abstractions import PipelineProcess
2320
2421
22+ class ProgressTaskCallable (Protocol ):
23+ """Protocol for functions decorated with progress_task."""
24+
25+ def __call__ (self , total : int ) -> None : ...
26+
27+
2528class ResultProcess (PipelineProcess ):
2629 """
2730 ResultProcess is a class that handles the logging of images and their corresponding masks during
@@ -76,26 +79,27 @@ def __init__(self, controller: Any, force: bool, selected_images: dict[str, list
7679 self .valset = self .datasets .segnet_dataset_val
7780 self .testset = self .datasets .segnet_dataset_test
7881
79- def execute (self ):
82+ def execute (self ) -> None :
8083 train_image_log = self ._get_log_train_images ()
8184 val_image_log = self ._get_log_val_images ()
8285 test_image_log = self ._get_log_test_images ()
8386
84- train_image_log (len (self .train_images_indices ))
87+ if self .train_images_indices :
88+ train_image_log (len (self .train_images_indices ))
8589 if self .datasets .val_available () and self .val_images_indices :
8690 val_image_log (len (self .val_images_indices ))
8791 if self .datasets .test_available () and self .test_images_indices :
8892 test_image_log (len (self .test_images_indices ))
8993
90- def _get_log_train_images (self ) -> callable :
94+ def _get_log_train_images (self ) -> ProgressTaskCallable :
9195 @self .progress_manager .progress_task ("result" , visible = False )
92- def _inner_log_image (total , task_id , progress ) :
96+ def _inner_log_image (total : int , task_id : int , progress : Any ) -> None :
9397 image_stack = []
9498 mask_stack = []
9599 pred_mask_stack = []
96100 for idx in range (total ):
97101 progress .advance (task_id )
98- selected_images = self .train_images_indices [idx ]
102+ selected_images = self .train_images_indices [idx ] # type: ignore[index]
99103 image , mask = self .trainset [selected_images ]
100104 image_stack .append (image )
101105 mask_stack .append (mask )
@@ -106,17 +110,17 @@ def _inner_log_image(total, task_id, progress):
106110 pred_masks = torch .cat (pred_mask_stack , 1 )
107111 self ._log_image (images , masks , pred_masks , "train" )
108112
109- return _inner_log_image
113+ return _inner_log_image # type: ignore[no-any-return]
110114
111- def _get_log_val_images (self ) -> callable :
115+ def _get_log_val_images (self ) -> ProgressTaskCallable :
112116 @self .progress_manager .progress_task ("result" , visible = False )
113- def _inner_log_image (total , task_id , progress ) :
117+ def _inner_log_image (total : int , task_id : int , progress : Any ) -> None :
114118 image_stack = []
115119 mask_stack = []
116120 pred_mask_stack = []
117121 for idx in range (total ):
118122 progress .advance (task_id )
119- selected_images = self .val_images_indices [idx ]
123+ selected_images = self .val_images_indices [idx ] # type: ignore[index]
120124 image , mask = self .valset [selected_images ]
121125 image_stack .append (image )
122126 mask_stack .append (mask )
@@ -127,17 +131,17 @@ def _inner_log_image(total, task_id, progress):
127131 pred_masks = torch .cat (pred_mask_stack , 1 )
128132 self ._log_image (images , masks , pred_masks , "val" )
129133
130- return _inner_log_image
134+ return _inner_log_image # type: ignore[no-any-return]
131135
132- def _get_log_test_images (self ) -> callable :
136+ def _get_log_test_images (self ) -> ProgressTaskCallable :
133137 @self .progress_manager .progress_task ("result" , visible = False )
134- def _inner_log_image (total , task_id , progress ) :
138+ def _inner_log_image (total : int , task_id : int , progress : Any ) -> None :
135139 image_stack = []
136140 mask_stack = []
137141 pred_mask_stack = []
138142 for idx in range (total ):
139143 progress .advance (task_id )
140- selected_images = self .test_images_indices [idx ]
144+ selected_images = self .test_images_indices [idx ] # type: ignore[index]
141145 image , mask = self .testset [selected_images ]
142146 image_stack .append (image )
143147 mask_stack .append (mask )
@@ -148,26 +152,26 @@ def _inner_log_image(total, task_id, progress):
148152 pred_masks = torch .cat (pred_mask_stack , 1 )
149153 self ._log_image (images , masks , pred_masks , "test" )
150154
151- return _inner_log_image
155+ return _inner_log_image # type: ignore[no-any-return]
152156
153- def _inference_model (self , image ) :
157+ def _inference_model (self , image : Any ) -> Any :
154158 self .model .eval ()
155159 with torch .no_grad ():
156160 return self .model (image .unsqueeze (0 ).to (self .device ))
157161
158- def _get_pred_mask (self , pred ) :
162+ def _get_pred_mask (self , pred : Any ) -> Any :
159163 if isinstance (pred , OrderedDict ):
160164 pred = pred ["out" ]
161165 return pred .argmax (dim = 1 ).squeeze (0 ).cpu ()
162166
163- def _get_mask_difference (self , mask , pred_mask ) :
167+ def _get_mask_difference (self , mask : Any , pred_mask : Any ) -> Any :
164168 mask [mask == 255 ] = 0
165169 mask_difference = mask - pred_mask
166170 if mask_difference .min () < 0 :
167171 mask_difference = mask_difference + mask_difference .min ().abs ()
168172 return mask_difference .to (torch .uint8 )
169173
170- def _log_image (self , image , mask , pred_mask , dataset ) :
174+ def _log_image (self , image : Any , mask : Any , pred_mask : Any , dataset : str ) -> None :
171175 class_labels = self ._swap_labels (self .datasets .data_container .classes )
172176 just_image = wandb .Image (image , caption = f"{ dataset } images" )
173177 image_with_mask = wandb .Image (
@@ -186,8 +190,8 @@ def _log_image(self, image, mask, pred_mask, dataset):
186190 wandb .log ({f"{ dataset } _images_with_mask" : image_with_mask })
187191 wandb .log ({f"{ dataset } _mask_difference" : mask_difference })
188192
189- def _swap_labels (self , labels ) :
193+ def _swap_labels (self , labels : dict [ Any , Any ]) -> dict [ Any , Any ] :
190194 return {v : k for k , v in labels .items ()}
191195
192- def skip (self ):
196+ def skip (self ) -> bool :
193197 return False
0 commit comments