Skip to content

Commit 5faccb5

Browse files
committed
Get input image shape statically in the model construction function
Instead of using input.get_shape() fonction We can pass input image shape to yolo_v3 and yolo_v3_tiny constructor function. Since the input placeholder shape is statically defined (to None, size, size, 3) We can have access to this 'size' when constructing the yolo_v3 or yolo_v3_tiny models. This if more efficient for inference. This commit should not break any anterior codes since it is only adding 1 optional argument to model constructor functions
1 parent 136fb66 commit 5faccb5

File tree

5 files changed

+12
-11
lines changed

5 files changed

+12
-11
lines changed

convert_weights.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
'spp', False, 'Use SPP version of YOLOv3')
2222
tf.app.flags.DEFINE_string(
2323
'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file')
24+
tf.app.flags.DEFINE_integer(
25+
'size', 416, 'Image size')
2426

2527

2628
def main(argv=None):
@@ -39,7 +41,8 @@ def main(argv=None):
3941

4042
with tf.variable_scope('detector'):
4143
detections = model(inputs, len(classes),
42-
data_format=FLAGS.data_format)
44+
data_format=FLAGS.data_format,
45+
img_size=[FLAGS.size, FLAGS.size])
4346
load_ops = load_weights(tf.global_variables(
4447
scope='detector'), FLAGS.weights_file)
4548

convert_weights_pb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'tiny', False, 'Use tiny version of YOLOv3')
2424
tf.app.flags.DEFINE_bool(
2525
'spp', False, 'Use SPP version of YOLOv3')
26+
2627
tf.app.flags.DEFINE_integer(
2728
'size', 416, 'Image size')
2829

@@ -42,7 +43,7 @@ def main(argv=None):
4243
inputs = tf.placeholder(tf.float32, [None, FLAGS.size, FLAGS.size, 3], "inputs")
4344

4445
with tf.variable_scope('detector'):
45-
detections = model(inputs, len(classes), data_format=FLAGS.data_format)
46+
detections = model(inputs, len(classes), data_format=FLAGS.data_format, img_size=[FLAGS.size, FLAGS.size])
4647
load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file)
4748

4849
# Sets the output nodes in the current session

utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def get_boxes_and_inputs(model, num_classes, size, data_format):
2020

2121
with tf.variable_scope('detector'):
2222
detections = model(inputs, num_classes,
23-
data_format=data_format)
23+
data_format=data_format,
24+
img_size=[size, size])
2425

2526
boxes = detections_boxes(detections)
2627

yolo_v3.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'):
200200
return inputs
201201

202202

203-
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False):
203+
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False, img_size=[416, 416]):
204204
"""
205205
Creates YOLO v3 model.
206206
@@ -213,8 +213,6 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
213213
:param with_spp: whether or not is using spp layer.
214214
:return:
215215
"""
216-
# it will be needed later on
217-
img_size = inputs.get_shape().as_list()[1:3]
218216

219217
# transpose the inputs to NCHW
220218
if data_format == 'NCHW':
@@ -277,7 +275,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
277275
return detections
278276

279277

280-
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
278+
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
281279
"""
282280
Creates YOLO v3 with SPP model.
283281
@@ -289,4 +287,4 @@ def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reus
289287
:param reuse: whether or not the network and its variables should be reused.
290288
:return:
291289
"""
292-
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True)
290+
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True, img_size=img_size)

yolo_v3_tiny.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
(81, 82), (135, 169), (344, 319)]
1616

1717

18-
def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
18+
def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
1919
"""
2020
Creates YOLO v3 tiny model.
2121
@@ -27,8 +27,6 @@ def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reu
2727
:param reuse: whether or not the network and its variables should be reused.
2828
:return:
2929
"""
30-
# it will be needed later on
31-
img_size = inputs.get_shape().as_list()[1:3]
3230

3331
# transpose the inputs to NCHW
3432
if data_format == 'NCHW':

0 commit comments

Comments
 (0)