简单玩玩TensorFlow的Post Training Quantization-阿里云开发者社区

开发者社区> 人工智能> 正文
登录阅读全文

简单玩玩TensorFlow的Post Training Quantization

简介: 传送门: 模型量化加速:[https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd](https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd) TF量化训练:[https://www.atatech.o

传送门:
模型量化加速:https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd
TF量化训练:https://www.atatech.org/articles/127543?spm=ata.13261165.0.0.453bd67dhgqUkd

量化对端上深度学习模型推理可以说是必选项了,前面的文章已经多次介绍过Quantization-aware training怎么操作,本文简单介绍下Post Training量化。

话说二者有什么区别呢,Quantization-aware training是在训练过程中进行量化,能够更好保持量化后模型的性能。Post training quantization意思是训练玩的模型直接拿来量化,通过在一组sample data推理模型,统计量化所需要的参数[min,max]。通常Post training quantization的精度损失大于Quantization-aware training,所以在以往的工作中我们主要推荐使用Quantization-aware training。

那Post training quantization是不是就毫无用处了呢?显然不是。

  1. Post training quantization在一些相对”重“的模型上,精度损失很小
  2. Quantization-aware training训练速度会慢一些,
  3. Post training quantization只需要模型就可以完成,比如如果模型是从其他训练框架转换而来,这时候就只能使用Post training quantization

根据官方的训练教程:
https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/lite/tutorials/full_integer_post_training_quant.ipynb
我改造了一个更加方便适配各种输入模型的脚本, 直接从文件系统读取Sample图像:

import tensorflow.compat.v1 as tf
#Eager Mode is essential!
tf.enable_eager_execution()

import sys
import glob
if sys.version_info.major >= 3:
    import pathlib
else:
    import pathlib2 as pathlib
import random
import cv2
import numpy as np

tf.logging.set_verbosity(tf.logging.DEBUG)
pb_file='model.pb'
input_arrays=['net_input']
output_arrays=['net_output']
input_shapes=[1, 128, 128, 3]

sample_img_dir='/tmp'


converter =  tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file=pb_file,
                                                       input_arrays=input_arrays,
                                                       output_arrays=output_arrays,
                                                       input_shapes={input_arrays[0]:input_shapes})
converter.allow_custom_ops=True
tflite_model = converter.convert()

tflite_models_dir = pathlib.Path("./tmp/tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

tflite_model_file = tflite_models_dir/"model.tflite"
tflite_model_file.write_bytes(tflite_model)

tf.logging.set_verbosity(tf.logging.DEBUG)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def preprocess_img(img):
    # preprocess
    img = cv2.resize(img, (128, 128))
    # FLOAT32
    img = img.astype(np.float32)
    # img=(img/128.)-1.
    img = img - 128.
    return img

def create_datastream_from_imgs(img_dir):

    img_path_list=glob.glob(img_dir+'/*.jpg')
    random.shuffle(img_path_list)
    img_path_list=img_path_list[:200]
    imgs_list=[]
    for path in img_path_list:
        img=cv2.imread(path)
        img=preprocess_img(img)
        imgs_list.append(img)
    imgs=np.stack(imgs_list,axis=0)
    return tf.data.Dataset.from_tensor_slices((imgs)).batch(1)


ds=create_datastream_from_imgs(sample_img_dir)

def representative_data_gen():
  for input_value in ds.take(20):
    yield [input_value]

converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

最初我在tf1.14上测试,遇到段错误的问题,也许某些同学会遇到:

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

这时候,只需要升级到tf-nightly版本就可以解决了。

成功后得到了这样的模型:
Screenshot from 2019-08-21 11-49-28.png
和Quantization-aware training得到的模型略有不同的是,输入输出仍然是Float的,而Quantization-aware training的模型输入输出都是int8的,所以输入Node之后和输出的Node之前,它相应添加了Quantize和DeQuantize的Node。

在一个简单的分类任务上测试了一下,精度下降1.4%,马马虎虎还可以接受

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

分享:
人工智能
使用钉钉扫一扫加入圈子
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

其他文章