Skip to content

Commit 136fb66

Browse files
authored
Merge pull request #81 from ar90n/add-yolov3-spp-support
Add yolov3-spp support
2 parents 74b2f46 + 142150c commit 136fb66

File tree

5 files changed

+53
-6
lines changed

5 files changed

+53
-6
lines changed

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ To run demo type this in the command line:
1919
1. Download binary file with desired weights:
2020
1. Full weights: `wget https://pjreddie.com/media/files/yolov3.weights`
2121
1. Tiny weights: `wget https://pjreddie.com/media/files/yolov3-tiny.weights`
22+
1. SPP weights: `wget https://pjreddie.com/media/files/yolov3-spp.weights`
2223
2. Run `python ./convert_weights.py` and `python ./convert_weights_pb.py`
2324
3. Run `python ./demo.py --input_img <path-to-image> --output_img <name-of-output-image> --frozen_model <path-to-frozen-model>`
2425

@@ -33,7 +34,9 @@ To run demo type this in the command line:
3334
1. `NCHW` (gpu only) or `NHWC`
3435
4. `--tiny`
3536
1. Use yolov3-tiny
36-
5. `--ckpt_file`
37+
5. `--spp`
38+
1. Use yolov3-spp
39+
6. `--ckpt_file`
3740
1. Output checkpoint file
3841
2. convert_weights_pb.py:
3942
1. `--class_names`
@@ -44,7 +47,9 @@ To run demo type this in the command line:
4447
1. `NCHW` (gpu only) or `NHWC`
4548
4. `--tiny`
4649
1. Use yolov3-tiny
47-
5. `--output_graph`
50+
5. `--spp`
51+
1. Use yolov3-spp
52+
6. `--output_graph`
4853
1. Location to write the output .pb graph to
4954
3. demo.py
5055
1. `--class_names`
@@ -62,4 +67,4 @@ To run demo type this in the command line:
6267
7. `--iou_threshold`
6368
1. Desired iou threshold
6469
8. `--gpu_memory_fraction`
65-
1. Fraction of gpu memory to work with
70+
1. Fraction of gpu memory to work with

convert_weights.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
'data_format', 'NCHW', 'Data format: NCHW (gpu only) / NHWC')
1818
tf.app.flags.DEFINE_bool(
1919
'tiny', False, 'Use tiny version of YOLOv3')
20+
tf.app.flags.DEFINE_bool(
21+
'spp', False, 'Use SPP version of YOLOv3')
2022
tf.app.flags.DEFINE_string(
2123
'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file')
2224

2325

2426
def main(argv=None):
2527
if FLAGS.tiny:
2628
model = yolo_v3_tiny.yolo_v3_tiny
29+
elif FLAGS.spp:
30+
model = yolo_v3.yolo_v3_spp
2731
else:
2832
model = yolo_v3.yolo_v3
2933

convert_weights_pb.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
tf.app.flags.DEFINE_bool(
2323
'tiny', False, 'Use tiny version of YOLOv3')
24+
tf.app.flags.DEFINE_bool(
25+
'spp', False, 'Use SPP version of YOLOv3')
2426
tf.app.flags.DEFINE_integer(
2527
'size', 416, 'Image size')
2628

@@ -29,6 +31,8 @@
2931
def main(argv=None):
3032
if FLAGS.tiny:
3133
model = yolo_v3_tiny.yolo_v3_tiny
34+
elif FLAGS.spp:
35+
model = yolo_v3.yolo_v3_spp
3236
else:
3337
model = yolo_v3.yolo_v3
3438

demo.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
'frozen_model', '', 'Frozen tensorflow protobuf model')
3030
tf.app.flags.DEFINE_bool(
3131
'tiny', False, 'Use tiny version of YOLOv3')
32+
tf.app.flags.DEFINE_bool(
33+
'spp', False, 'Use SPP version of YOLOv3')
3234

3335
tf.app.flags.DEFINE_integer(
3436
'size', 416, 'Image size')
@@ -71,6 +73,8 @@ def main(argv=None):
7173
else:
7274
if FLAGS.tiny:
7375
model = yolo_v3_tiny.yolo_v3_tiny
76+
elif FLAGS.spp:
77+
model = yolo_v3.yolo_v3_spp
7478
else:
7579
model = yolo_v3.yolo_v3
7680

yolo_v3.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def _darknet53_block(inputs, filters):
6363
return inputs
6464

6565

66+
def _spp_block(inputs, data_format='NCHW'):
67+
return tf.concat([slim.max_pool2d(inputs, 13, 1, 'SAME'),
68+
slim.max_pool2d(inputs, 9, 1, 'SAME'),
69+
slim.max_pool2d(inputs, 5, 1, 'SAME'),
70+
inputs],
71+
axis=1 if data_format == 'NCHW' else 3)
72+
73+
6674
@tf.contrib.framework.add_arg_scope
6775
def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
6876
"""
@@ -95,10 +103,15 @@ def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs):
95103
return padded_inputs
96104

97105

98-
def _yolo_block(inputs, filters):
106+
def _yolo_block(inputs, filters, data_format='NCHW', with_spp=False):
99107
inputs = _conv2d_fixed_padding(inputs, filters, 1)
100108
inputs = _conv2d_fixed_padding(inputs, filters * 2, 3)
101109
inputs = _conv2d_fixed_padding(inputs, filters, 1)
110+
111+
if with_spp:
112+
inputs = _spp_block(inputs, data_format)
113+
inputs = _conv2d_fixed_padding(inputs, filters, 1)
114+
102115
inputs = _conv2d_fixed_padding(inputs, filters * 2, 3)
103116
inputs = _conv2d_fixed_padding(inputs, filters, 1)
104117
route = inputs
@@ -187,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'):
187200
return inputs
188201

189202

190-
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
203+
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False):
191204
"""
192205
Creates YOLO v3 model.
193206
@@ -197,6 +210,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
197210
:param is_training: whether is training or not.
198211
:param data_format: data format NCHW or NHWC.
199212
:param reuse: whether or not the network and its variables should be reused.
213+
:param with_spp: whether or not is using spp layer.
200214
:return:
201215
"""
202216
# it will be needed later on
@@ -228,7 +242,8 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
228242
route_1, route_2, inputs = darknet53(inputs)
229243

230244
with tf.variable_scope('yolo-v3'):
231-
route, inputs = _yolo_block(inputs, 512)
245+
route, inputs = _yolo_block(inputs, 512, data_format, with_spp)
246+
232247
detect_1 = _detection_layer(
233248
inputs, num_classes, _ANCHORS[6:9], img_size, data_format)
234249
detect_1 = tf.identity(detect_1, name='detect_1')
@@ -260,3 +275,18 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
260275
detections = tf.concat([detect_1, detect_2, detect_3], axis=1)
261276
detections = tf.identity(detections, name='detections')
262277
return detections
278+
279+
280+
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
281+
"""
282+
Creates YOLO v3 with SPP model.
283+
284+
:param inputs: a 4-D tensor of size [batch_size, height, width, channels].
285+
Dimension batch_size may be undefined. The channel order is RGB.
286+
:param num_classes: number of predicted classes.
287+
:param is_training: whether is training or not.
288+
:param data_format: data format NCHW or NHWC.
289+
:param reuse: whether or not the network and its variables should be reused.
290+
:return:
291+
"""
292+
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True)

0 commit comments

Comments
 (0)