Skip to content

Commit 24ce179

Browse files
authored
Merge pull request #325 from roboflow/better-head-reinitialization
Better head reinitialization
2 parents cf06635 + d0311b6 commit 24ce179

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

rfdetr/detr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def train_from_config(self, config: TrainConfig, **kwargs):
137137
)
138138
self.model.reinitialize_detection_head(num_classes)
139139

140-
141140
train_config = config.dict()
142141
model_config = self.model_config.dict()
143142
model_config.pop("num_classes")

rfdetr/models/lwdetr.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,19 @@ def __init__(self,
103103
self._export = False
104104

105105
def reinitialize_detection_head(self, num_classes):
106-
# Create new classification head
107-
del self.class_embed
108-
self.add_module("class_embed", nn.Linear(self.transformer.d_model, num_classes))
106+
base = self.class_embed.weight.shape[0]
107+
num_repeats = int(math.ceil(num_classes / base))
108+
self.class_embed.weight.data = self.class_embed.weight.data.repeat(num_repeats, 1)
109+
self.class_embed.weight.data = self.class_embed.weight.data[:num_classes]
110+
self.class_embed.bias.data = self.class_embed.bias.data.repeat(num_repeats)
111+
self.class_embed.bias.data = self.class_embed.bias.data[:num_classes]
109112

110-
# Initialize with focal loss bias adjustment
111-
prior_prob = 0.01
112-
bias_value = -math.log((1 - prior_prob) / prior_prob)
113-
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
114-
115113
if self.two_stage:
116-
del self.transformer.enc_out_class_embed
117-
self.transformer.add_module("enc_out_class_embed", nn.ModuleList(
118-
[copy.deepcopy(self.class_embed) for _ in range(self.group_detr)]))
119-
114+
for enc_out_class_embed in self.transformer.enc_out_class_embed:
115+
enc_out_class_embed.weight.data = enc_out_class_embed.weight.data.repeat(num_repeats, 1)
116+
enc_out_class_embed.weight.data = enc_out_class_embed.weight.data[:num_classes]
117+
enc_out_class_embed.bias.data = enc_out_class_embed.bias.data.repeat(num_repeats)
118+
enc_out_class_embed.bias.data = enc_out_class_embed.bias.data[:num_classes]
120119

121120
def export(self):
122121
self._export = True

0 commit comments

Comments
 (0)