| import tensorflow as tf | |
| from absl import app, flags, logging | |
| from absl.flags import FLAGS | |
| from core.yolov4 import YOLO, decode, filter_boxes | |
| import core.utils as utils | |
| from core.config import cfg | |
# Command-line flags for the export script; save_tf() reads them via FLAGS.

# Input Darknet weights and destination of the exported model.
flags.DEFINE_string('weights', default='./data/yolov4.weights', help='path to weights file')
flags.DEFINE_string('output', default='./checkpoints/yolov4-416', help='path to output')

# Architecture selection and graph-construction parameters.
flags.DEFINE_boolean('tiny', default=False, help='is yolo-tiny or not')
flags.DEFINE_integer('input_size', default=416, help='define input size of export model')
flags.DEFINE_float('score_thres', default=0.2, help='define score threshold')
flags.DEFINE_string('framework', default='tf', help='define what framework do you want to convert (tf, trt, tflite)')
flags.DEFINE_string('model', default='yolov4', help='yolov3 or yolov4')
def _decode_heads(feature_maps, grid_divisors, input_size, framework,
                  num_class, strides, anchors, xyscale):
    """Decode every YOLO head and collect its box and probability tensors.

    Args:
        feature_maps: raw head outputs returned by ``YOLO``, in head order.
        grid_divisors: per-head divisor; head ``i`` is decoded with output
            size ``input_size // grid_divisors[i]``. Heads beyond the end of
            this sequence reuse its last entry.
        input_size: square model input size in pixels.
        framework: framework string forwarded to ``decode``.
        num_class, strides, anchors, xyscale: forwarded unchanged to
            ``decode``.

    Returns:
        ``(bbox_tensors, prob_tensors)``: lists holding element 0 and
        element 1 of each ``decode`` result, in head order.
    """
    bbox_tensors = []
    prob_tensors = []
    last = len(grid_divisors) - 1
    for i, fm in enumerate(feature_maps):
        # min() keeps the "every remaining head uses the last divisor"
        # fall-through of an if/elif/else dispatch.
        output_size = input_size // grid_divisors[min(i, last)]
        output_tensors = decode(fm, output_size, num_class, strides, anchors,
                                i, xyscale, framework)
        bbox_tensors.append(output_tensors[0])
        prob_tensors.append(output_tensors[1])
    return bbox_tensors, prob_tensors


def save_tf():
    """Build the YOLO inference graph, load Darknet weights, and save it.

    All configuration comes from absl ``FLAGS``: ``model``, ``tiny``,
    ``input_size``, ``framework``, ``score_thres``, ``weights`` and
    ``output`` (plus whatever ``utils.load_config`` reads).

    When ``FLAGS.framework == 'tflite'`` the model has two outputs: the
    concatenated box tensor and the concatenated probability tensor. For any
    other framework, the two are passed through ``filter_boxes`` (with
    ``FLAGS.score_thres``) and concatenated on the last axis into a single
    output.

    Side effects: prints ``model.summary()`` and writes the model to
    ``FLAGS.output`` via ``model.save``.
    """
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3])
    feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny)

    # Head i is decoded on an (input_size // divisor) grid. Tiny models use
    # 16 then 32; full models use 8, 16, then 32.
    grid_divisors = (16, 32) if FLAGS.tiny else (8, 16, 32)
    bbox_tensors, prob_tensors = _decode_heads(
        feature_maps, grid_divisors, FLAGS.input_size, FLAGS.framework,
        NUM_CLASS, STRIDES, ANCHORS, XYSCALE)

    pred_bbox = tf.concat(bbox_tensors, axis=1)
    pred_prob = tf.concat(prob_tensors, axis=1)
    if FLAGS.framework == 'tflite':
        # TFLite export keeps boxes and scores as separate outputs and does
        # not apply filter_boxes in the graph.
        pred = (pred_bbox, pred_prob)
    else:
        boxes, pred_conf = filter_boxes(
            pred_bbox, pred_prob,
            score_threshold=FLAGS.score_thres,
            input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size]))
        pred = tf.concat([boxes, pred_conf], axis=-1)

    model = tf.keras.Model(input_layer, pred)
    utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)
    model.summary()
    model.save(FLAGS.output)
def main(_argv):
    """absl entry point: export the model described by the parsed flags.

    ``_argv`` is the remaining argv that absl passes in; it is not used.
    """
    del _argv  # Unused; save_tf() reads its configuration from FLAGS.
    save_tf()
if __name__ == '__main__':
    # absl's app.run() finishes by calling sys.exit(), which raises
    # SystemExit. Swallow only a clean exit (presumably so interactive
    # sessions do not report it as an error). Re-raise a failure, e.g. a
    # flag-parsing error or a raised exit code, so shells and CI see a
    # non-zero status.
    try:
        app.run(main)
    except SystemExit as exc:
        if exc.code not in (None, 0):
            raise