@@ -63,6 +63,14 @@ def _darknet53_block(inputs, filters):
6363 return inputs
6464
6565
66+ def _spp_block (inputs , data_format = 'NCHW' ):
67+ return tf .concat ([slim .max_pool2d (inputs , 13 , 1 , 'SAME' ),
68+ slim .max_pool2d (inputs , 9 , 1 , 'SAME' ),
69+ slim .max_pool2d (inputs , 5 , 1 , 'SAME' ),
70+ inputs ],
71+ axis = 1 if data_format == 'NCHW' else 3 )
72+
73+
6674@tf .contrib .framework .add_arg_scope
6775def _fixed_padding (inputs , kernel_size , * args , mode = 'CONSTANT' , ** kwargs ):
6876 """
@@ -95,10 +103,15 @@ def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
95103 return padded_inputs
96104
97105
98- def _yolo_block (inputs , filters ):
106+ def _yolo_block (inputs , filters , data_format = 'NCHW' , with_spp = False ):
99107 inputs = _conv2d_fixed_padding (inputs , filters , 1 )
100108 inputs = _conv2d_fixed_padding (inputs , filters * 2 , 3 )
101109 inputs = _conv2d_fixed_padding (inputs , filters , 1 )
110+
111+ if with_spp :
112+ inputs = _spp_block (inputs , data_format )
113+ inputs = _conv2d_fixed_padding (inputs , filters , 1 )
114+
102115 inputs = _conv2d_fixed_padding (inputs , filters * 2 , 3 )
103116 inputs = _conv2d_fixed_padding (inputs , filters , 1 )
104117 route = inputs
@@ -187,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'):
187200 return inputs
188201
189202
190- def yolo_v3 (inputs , num_classes , is_training = False , data_format = 'NCHW' , reuse = False ):
203+ def yolo_v3 (inputs , num_classes , is_training = False , data_format = 'NCHW' , reuse = False , with_spp = False ):
191204 """
192205 Creates YOLO v3 model.
193206
@@ -197,6 +210,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
197210 :param is_training: whether is training or not.
198211 :param data_format: data format NCHW or NHWC.
199212 :param reuse: whether or not the network and its variables should be reused.
213+ :param with_spp: whether or not is using spp layer.
200214 :return:
201215 """
202216 # it will be needed later on
@@ -228,7 +242,8 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
228242 route_1 , route_2 , inputs = darknet53 (inputs )
229243
230244 with tf .variable_scope ('yolo-v3' ):
231- route , inputs = _yolo_block (inputs , 512 )
245+ route , inputs = _yolo_block (inputs , 512 , data_format , with_spp )
246+
232247 detect_1 = _detection_layer (
233248 inputs , num_classes , _ANCHORS [6 :9 ], img_size , data_format )
234249 detect_1 = tf .identity (detect_1 , name = 'detect_1' )
@@ -260,3 +275,18 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
260275 detections = tf .concat ([detect_1 , detect_2 , detect_3 ], axis = 1 )
261276 detections = tf .identity (detections , name = 'detections' )
262277 return detections
278+
279+
280+ def yolo_v3_spp (inputs , num_classes , is_training = False , data_format = 'NCHW' , reuse = False ):
281+ """
282+ Creates YOLO v3 with SPP model.
283+
284+ :param inputs: a 4-D tensor of size [batch_size, height, width, channels].
285+ Dimension batch_size may be undefined. The channel order is RGB.
286+ :param num_classes: number of predicted classes.
287+ :param is_training: whether is training or not.
288+ :param data_format: data format NCHW or NHWC.
289+ :param reuse: whether or not the network and its variables should be reused.
290+ :return:
291+ """
292+ return yolo_v3 (inputs , num_classes , is_training = is_training , data_format = data_format , reuse = reuse , with_spp = True )
0 commit comments