3131logging .getLogger ("p2ch14.dsets" ).setLevel (logging .WARNING )
3232
3333def print_confusion (label , confusions , do_mal ):
34+ row_labels = ['Non-Nodules' , 'Benign' , 'Malignant' ]
35+
3436 if do_mal :
35- col_labels = ['' , 'Complete Miss' , 'Filtered' , 'Benign' , 'Malignant' ]
36- row_labels = ['Non-Nodules' , 'Benign' , 'Malignant' ]
37+ col_labels = ['' , 'Complete Miss' , 'Filtered Out' , 'Pred. Benign' , 'Pred. Malignant' ]
3738 else :
38- col_labels = ['' , 'Complete Miss' , 'Filtered' , 'Detected' ]
39- row_labels = ['Non-Nodules' , 'Nodules' ]
40- confusions [- 2 ] += confusions [- 1 ]
39+ col_labels = ['' , 'Complete Miss' , 'Filtered Out' , 'Pred. Nodule' ]
4140 confusions [:, - 2 ] += confusions [:, - 1 ]
42- confusions = confusions [:- 1 , :- 1 ]
43- cell_width = 14
41+ confusions = confusions [:, :- 1 ]
42+ cell_width = 16
4443 f = '{:>' + str (cell_width ) + '}'
4544 print (label )
4645 print (' | ' .join ([f .format (s ) for s in col_labels ]))
@@ -72,7 +71,7 @@ def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
7271 confusion = np .zeros ((3 , 4 ), dtype = np .int )
7372 if len (detected_xyz ) == 0 :
7473 for tn in true_nodules :
75- confusiion [2 if tn .isMal_bool else 1 , 0 ] += 1
74+ confusion [2 if tn .isMal_bool else 1 , 0 ] += 1
7675 elif len (truth_xyz ) == 0 :
7776 for dc in detected_classes :
7877 confusion [0 , dc ] += 1
@@ -124,7 +123,7 @@ def __init__(self, sys_argv=None):
124123 parser .add_argument ('--segmentation-path' ,
125124 help = "Path to the saved segmentation model" ,
126125 nargs = '?' ,
127- default = None ,
126+ default = 'data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state' ,
128127 )
129128
130129 parser .add_argument ('--cls-model' ,
@@ -135,13 +134,14 @@ def __init__(self, sys_argv=None):
135134 parser .add_argument ('--classification-path' ,
136135 help = "Path to the saved classification model" ,
137136 nargs = '?' ,
138- default = None ,
137+ default = 'data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state' ,
139138 )
140139
141140 parser .add_argument ('--malignancy-model' ,
142141 help = "What to model class name to use for the malignancy classifier." ,
143142 action = 'store' ,
144- default = 'ModifiedLunaModel' ,
143+ default = 'LunaModel' ,
144+ # default='ModifiedLunaModel',
145145 )
146146 parser .add_argument ('--malignancy-path' ,
147147 help = "Path to the saved malignancy classification model" ,
@@ -303,7 +303,6 @@ def main(self):
303303 val_list = sorted (series_set & val_set )
304304
305305
306- candidateInfo_list = []
307306 candidateInfo_dict = getCandidateInfoDict ()
308307 series_iter = enumerateWithEstimate (
309308 val_list + train_list ,
@@ -314,10 +313,8 @@ def main(self):
314313 ct = getCt (series_uid )
315314 mask_a = self .segmentCt (ct , series_uid )
316315
317- candidateInfo_list = self .clusterSegmentationOutput (
318- series_uid ,
319- ct ,
320- mask_a ,
316+ candidateInfo_list = self .groupSegmentationOutput (
317+ series_uid , ct , mask_a
321318 )
322319 classifications_list = self .classifyCandidates (ct , candidateInfo_list )
323320
@@ -339,7 +336,6 @@ def main(self):
339336 print_confusion ("Total" , all_confusion , self .malignancy_model is not None )
340337
341338
342-
343339 def classifyCandidates (self , ct , candidateInfo_list ):
344340 cls_dl = self .initClassificationDl (candidateInfo_list )
345341 classifications_list = []
@@ -348,49 +344,50 @@ def classifyCandidates(self, ct, candidateInfo_list):
348344
349345 input_g = input_t .to (self .device )
350346 with torch .no_grad ():
351- _ , probability_g = self .cls_model (input_g )
347+ _ , probability_nodule_g = self .cls_model (input_g )
352348 if self .malignancy_model is not None :
353349 _ , probability_mal_g = self .malignancy_model (input_g )
354350 else :
355- probability_mal_g = torch .zeros_like (probability_g )
351+ probability_mal_g = torch .zeros_like (probability_nodule_g )
356352
357- for center_irc , prob , prob_mal in zip (center_list ,
358- probability_g [:,1 ].tolist (),
359- probability_mal_g [:,1 ].tolist ()
360- ):
353+ zip_iter = zip (
354+ center_list ,
355+ probability_nodule_g [:,1 ].tolist (),
356+ probability_mal_g [:,1 ].tolist (),
357+ )
358+ for center_irc , prob_nodule , prob_mal in zip_iter :
361359 center_xyz = irc2xyz (
362360 center_irc ,
363361 direction_a = ct .direction_a ,
364362 origin_xyz = ct .origin_xyz ,
365- vxSize_xyz = ct .vxSize_xyz )
366- classifications_list .append ((prob , prob_mal , center_xyz , center_irc ))
363+ vxSize_xyz = ct .vxSize_xyz ,
364+ )
365+ cls_tup = (prob_nodule , prob_mal , center_xyz , center_irc )
366+ classifications_list .append (cls_tup )
367367 return classifications_list
368368
369369 def segmentCt (self , ct , series_uid ):
370370 with torch .no_grad ():
371371 output_a = np .zeros_like (ct .hu_a , dtype = np .float32 )
372372 seg_dl = self .initSegmentationDl (series_uid )
373373 for batch_tup in seg_dl :
374- input_t = batch_tup [0 ]
375- ndx_list = batch_tup [4 ]
374+ input_t , label_t , series_list , slice_ndx_list = batch_tup
376375
377376 input_g = input_t .to (self .device )
378377 prediction_g = self .seg_model (input_g )
379378
380- for i , sample_ndx in enumerate (ndx_list ):
381- output_a [sample_ndx ] = prediction_g [i ].cpu ().numpy ()
379+ for i , slice_ndx in enumerate (slice_ndx_list ):
380+ output_a [slice_ndx ] = prediction_g [i ].cpu ().numpy ()
382381
383- # mask_a = output_a > 0.25
384382 mask_a = output_a > 0.5
385- # mask_a = morphology.binary_erosion(mask_a, iterations=1)
386- # mask_a = morphology.binary_dilation(mask_a, iterations=2)
383+ mask_a = morphology .binary_erosion (mask_a , iterations = 1 )
387384
388385 return mask_a
389386
390- def clusterSegmentationOutput (self , series_uid , ct , clean_a ):
387+ def groupSegmentationOutput (self , series_uid , ct , clean_a ):
391388 candidateLabel_a , candidate_count = measurements .label (clean_a )
392389 centerIrc_list = measurements .center_of_mass (
393- ct .hu_a + 1001 ,
390+ ct .hu_a . clip ( - 1000 , 1000 ) + 1001 ,
394391 labels = candidateLabel_a ,
395392 index = np .arange (1 , candidate_count + 1 ),
396393 )
0 commit comments