| import tensorflow as tf | |
| from absl import app, flags, logging | |
| from absl.flags import FLAGS | |
| import numpy as np | |
| import cv2 | |
| from core.yolov4 import YOLOv4, YOLOv3, YOLOv3_tiny, decode | |
| import core.utils as utils | |
| import os | |
| from core.config import cfg | |
# Command-line flags for converting a SavedModel to TensorFlow Lite.
flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to the SavedModel to convert')
flags.DEFINE_string('output', './checkpoints/yolov4-416-fp32.tflite', 'path to output')
flags.DEFINE_integer('input_size', 416, 'square input size in pixels used to preprocess calibration images')
# An enum rejects typos (e.g. 'fp16') at flag-parsing time instead of silently
# falling through to a plain float32 conversion.
flags.DEFINE_enum('quantize_mode', 'float32', ['int8', 'float16', 'float32'],
                  'quantize mode (int8, float16, float32)')
flags.DEFINE_string('dataset', "/Volumes/Elements/data/coco_dataset/coco/5k.txt",
                    'text file of whitespace-separated image paths used for int8 calibration')
def representative_data_gen(num_samples=10):
    """Yield calibration inputs for post-training quantization.

    Reads whitespace-separated tokens from ``FLAGS.dataset`` and considers the
    first ``num_samples`` of them. Each token that is an existing, decodable
    image file is converted BGR->RGB, preprocessed with
    ``utils.image_preprocess`` to ``FLAGS.input_size`` x ``FLAGS.input_size``,
    given a leading batch axis of size 1, cast to float32 and yielded as a
    single-element list (the form ``converter.representative_dataset`` consumes).

    Args:
      num_samples: number of leading dataset entries to consider. Entries that
        do not exist or cannot be decoded are skipped, so fewer inputs than
        ``num_samples`` may be yielded.
    """
    with open(FLAGS.dataset) as dataset_file:
        fimage = dataset_file.read().split()
    # Slicing tolerates a list shorter than num_samples (indexing range(10) did not).
    for image_path in fimage[:num_samples]:
        if not os.path.exists(image_path):
            continue
        original_image = cv2.imread(image_path)
        if original_image is None:
            # cv2.imread signals an unreadable/undecodable file by returning None.
            continue
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
        img_in = image_data[np.newaxis, ...].astype(np.float32)
        print("calibration image {}".format(image_path))
        yield [img_in]
def save_tflite():
    """Convert the SavedModel at ``FLAGS.weights`` to TFLite and write ``FLAGS.output``.

    ``FLAGS.quantize_mode`` selects the conversion:
      * ``float32``: plain conversion, no optimizations.
      * ``float16``: default optimizations with float16 as the supported type.
      * ``int8``: default optimizations calibrated with ``representative_data_gen``.
    The float16 and int8 modes also enable TF Select ops and custom ops.

    Raises:
      ValueError: if ``FLAGS.quantize_mode`` is not one of the modes above.
    """
    mode = FLAGS.quantize_mode
    # Validate before loading the (potentially large) SavedModel.
    if mode not in ('float32', 'float16', 'int8'):
        raise ValueError("unsupported quantize_mode: {!r} (expected int8, float16 or float32)".format(mode))

    converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.weights)
    if mode == 'float16':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # tf.float16 is the same dtype as the deprecated tf.compat.v1.lite.constants.FLOAT16.
        converter.target_spec.supported_types = [tf.float16]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
    elif mode == 'int8':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # Deliberately not restricted to TFLITE_BUILTINS_INT8: ops may fall back
        # to float builtins / TF Select, so the result is not guaranteed to be
        # integer-only.
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
        converter.representative_dataset = representative_data_gen

    tflite_model = converter.convert()
    with open(FLAGS.output, 'wb') as output_file:
        output_file.write(tflite_model)
    logging.info("model saved to: {}".format(FLAGS.output))
def demo():
    """Smoke-test the converted model at ``FLAGS.output`` with random input.

    Loads the TFLite file, prints its input and output tensor details, and
    fills the first input with uniform random values in [0, 1) shaped like
    that tensor. It then runs one inference and prints every output tensor.
    Returns nothing.
    """
    interpreter = tf.lite.Interpreter(model_path=FLAGS.output)
    interpreter.allocate_tensors()
    logging.info('tflite model loaded')
    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)

    first_input = input_details[0]
    # Use the dtype the model declares for this input (float32 for the float
    # and default-quantized conversions). NOTE: for an integer input the [0, 1)
    # samples truncate to zeros, which still exercises invoke() but nothing more.
    input_data = np.random.random_sample(first_input['shape']).astype(first_input['dtype'])
    interpreter.set_tensor(first_input['index'], input_data)
    interpreter.invoke()
    output_data = [interpreter.get_tensor(detail['index']) for detail in output_details]
    print(output_data)
def main(_argv):
    """absl entry point: convert the model, then run the smoke test on it."""
    del _argv  # Positional args from absl are not used.
    save_tflite()
    demo()
if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit as exc:
        # app.run() ends by calling sys.exit(). A clean exit is still swallowed
        # as before, but a nonzero/non-None code (e.g. from a flag-parsing
        # error) is re-raised so the process does not report success.
        if exc.code not in (None, 0):
            raise