@@ -103,20 +103,19 @@ def __init__(self,
         self._export = False
 
     def reinitialize_detection_head(self, num_classes):
-        # Create new classification head
-        del self.class_embed
-        self.add_module("class_embed", nn.Linear(self.transformer.d_model, num_classes))
+        base = self.class_embed.weight.shape[0]
+        num_repeats = int(math.ceil(num_classes / base))
+        self.class_embed.weight.data = self.class_embed.weight.data.repeat(num_repeats, 1)
+        self.class_embed.weight.data = self.class_embed.weight.data[:num_classes]
+        self.class_embed.bias.data = self.class_embed.bias.data.repeat(num_repeats)
+        self.class_embed.bias.data = self.class_embed.bias.data[:num_classes]
 
-        # Initialize with focal loss bias adjustment
-        prior_prob = 0.01
-        bias_value = -math.log((1 - prior_prob) / prior_prob)
-        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
-
         if self.two_stage:
-            del self.transformer.enc_out_class_embed
-            self.transformer.add_module("enc_out_class_embed", nn.ModuleList(
-                [copy.deepcopy(self.class_embed) for _ in range(self.group_detr)]))
-
+            for enc_out_class_embed in self.transformer.enc_out_class_embed:
+                enc_out_class_embed.weight.data = enc_out_class_embed.weight.data.repeat(num_repeats, 1)
+                enc_out_class_embed.weight.data = enc_out_class_embed.weight.data[:num_classes]
+                enc_out_class_embed.bias.data = enc_out_class_embed.bias.data.repeat(num_repeats)
+                enc_out_class_embed.bias.data = enc_out_class_embed.bias.data[:num_classes]
 
     def export(self):
         self._export = True
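For context: the new `reinitialize_detection_head` no longer rebuilds `class_embed` with a fresh `nn.Linear` and a focal-loss prior bias. Instead it tiles the pretrained classifier weights and biases until they cover `num_classes`, then truncates to exactly that many rows, and applies the same treatment to each encoder output head when `two_stage` is set. Below is a minimal standalone sketch of that repeat-and-truncate step; `d_model=256`, the class counts, and the `head` variable (standing in for `self.class_embed`) are illustrative assumptions, not this repo's configuration.

```python
# Standalone sketch of the repeat-and-truncate reinitialization in the diff
# above. Values are illustrative assumptions, not taken from the repo.
import math

import torch
import torch.nn as nn

d_model, pretrained_classes, num_classes = 256, 91, 200
head = nn.Linear(d_model, pretrained_classes)  # stands in for self.class_embed

base = head.weight.shape[0]                       # rows currently in the head
num_repeats = int(math.ceil(num_classes / base))  # tiles needed to cover target
# Tile the pretrained rows, then cut back to exactly num_classes rows, so the
# new classes start from pretrained weight statistics rather than fresh init.
head.weight.data = head.weight.data.repeat(num_repeats, 1)[:num_classes]
head.bias.data = head.bias.data.repeat(num_repeats)[:num_classes]

logits = head(torch.randn(2, d_model))
assert logits.shape == (2, num_classes)
# Note: head.out_features still reports the old count; like the diff, this
# swaps .data in place instead of constructing a new nn.Linear, which keeps
# the module registered under its existing name in the parent model.
```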