简单玩玩TensorFlow的Post Training Quantization

简介: 传送门: 模型量化加速:[https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd](https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd) TF量化训练:[https://www.atatech.o

传送门:
模型量化加速:https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd
TF量化训练:https://www.atatech.org/articles/127543?spm=ata.13261165.0.0.453bd67dhgqUkd

量化对端上深度学习模型推理可以说是必选项了,前面的文章已经多次介绍过Quantization-aware training怎么操作,本文简单介绍下Post Training量化。

话说二者有什么区别呢,Quantization-aware training是在训练过程中进行量化,能够更好保持量化后模型的性能。Post training quantization意思是训练玩的模型直接拿来量化,通过在一组sample data推理模型,统计量化所需要的参数[min,max]。通常Post training quantization的精度损失大于Quantization-aware training,所以在以往的工作中我们主要推荐使用Quantization-aware training。

那Post training quantization是不是就毫无用处了呢?显然不是。

  1. Post training quantization在一些相对”重“的模型上,精度损失很小
  2. Quantization-aware training训练速度会慢一些,
  3. Post training quantization只需要模型就可以完成,比如如果模型是从其他训练框架转换而来,这时候就只能使用Post training quantization

根据官方的训练教程:
https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/lite/tutorials/full_integer_post_training_quant.ipynb
我改造了一个更加方便适配各种输入模型的脚本, 直接从文件系统读取Sample图像:

import tensorflow.compat.v1 as tf
#Eager Mode is essential!
tf.enable_eager_execution()

import sys
import glob
if sys.version_info.major >= 3:
    import pathlib
else:
    import pathlib2 as pathlib
import random
import cv2
import numpy as np

tf.logging.set_verbosity(tf.logging.DEBUG)
pb_file='model.pb'
input_arrays=['net_input']
output_arrays=['net_output']
input_shapes=[1, 128, 128, 3]

sample_img_dir='/tmp'


converter =  tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file=pb_file,
                                                       input_arrays=input_arrays,
                                                       output_arrays=output_arrays,
                                                       input_shapes={input_arrays[0]:input_shapes})
converter.allow_custom_ops=True
tflite_model = converter.convert()

tflite_models_dir = pathlib.Path("./tmp/tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

tflite_model_file = tflite_models_dir/"model.tflite"
tflite_model_file.write_bytes(tflite_model)

tf.logging.set_verbosity(tf.logging.DEBUG)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def preprocess_img(img):
    # preprocess
    img = cv2.resize(img, (128, 128))
    # FLOAT32
    img = img.astype(np.float32)
    # img=(img/128.)-1.
    img = img - 128.
    return img

def create_datastream_from_imgs(img_dir):

    img_path_list=glob.glob(img_dir+'/*.jpg')
    random.shuffle(img_path_list)
    img_path_list=img_path_list[:200]
    imgs_list=[]
    for path in img_path_list:
        img=cv2.imread(path)
        img=preprocess_img(img)
        imgs_list.append(img)
    imgs=np.stack(imgs_list,axis=0)
    return tf.data.Dataset.from_tensor_slices((imgs)).batch(1)


ds=create_datastream_from_imgs(sample_img_dir)

def representative_data_gen():
  for input_value in ds.take(20):
    yield [input_value]

converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

最初我在tf1.14上测试,遇到段错误的问题,也许某些同学会遇到:

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

这时候,只需要升级到tf-nightly版本就可以解决了。

成功后得到了这样的模型:
Screenshot from 2019-08-21 11-49-28.png
和Quantization-aware training得到的模型略有不同的是,输入输出仍然是Float的,而Quantization-aware training的模型输入输出都是int8的,所以输入Node之后和输出的Node之前,它相应添加了Quantize和DeQuantize的Node。

在一个简单的分类任务上测试了一下,精度下降1.4%,马马虎虎还可以接受

目录
相关文章
|
Rust Kubernetes 负载均衡
Sentinel 2.0:流量治理全面升级
Sentinel 诞生于 2012 年,诞生之初主要为支撑双 11 的限流降级等稳定性场景,后续逐渐在阿里集团内部迅速发展成为基础模块,覆盖了所有流量稳定性的核心场景。
Sentinel 2.0:流量治理全面升级
|
存储 缓存 负载均衡
【2022持续更新】大数据最全知识点整理-HBase篇
【2022持续更新】大数据最全知识点整理-HBase篇
1441 0
【2022持续更新】大数据最全知识点整理-HBase篇
|
Java 应用服务中间件 Maven
Springboot项目将jar包修改为war包操作步骤
Springboot项目将jar包修改为war包操作步骤 文章目录 Springboot项目将jar包修改为war包操作步骤 1.修改jar为war包形式 2.去除Spring Boot内置Tomcat 3.增加Tomcat启动插件 4.使用maven编译程序
908 0
Springboot项目将jar包修改为war包操作步骤
|
新零售 数据采集 分布式计算
6000字干货分享:数据中台项目管理实践分享
本文总结了企业级数据中台项目的实践经验,希望能够为正在规划或者已在实施数据中台类项目的企业和个人提供经验。
6000字干货分享:数据中台项目管理实践分享
|
资源调度
如何科学地预估工时?
PERT(Program Evaluation and Review Technique)即计划评审技术,最早是由美国海军在计划和控制北极星导弹的研制时发展起来的。PERT技术使原先估计的、研制北极星潜艇的时间缩短了两年。
如何科学地预估工时?
|
人工智能 安全 物联网
惊爆!!!一条命令,瞬间让普通用户提权为root,赶紧修复!
该漏洞非常容易利用,允许任何未经授权的用户通过在其默认配置中利用此漏洞,来获得易受攻击主机上的完全root权限。
589 0
惊爆!!!一条命令,瞬间让普通用户提权为root,赶紧修复!
|
Java
IntelliJ IDEA - 如何找到 Class 文件的 Java 源码文件进行 Debug?
IntelliJ IDEA - 如何找到 Class 文件的 Java 源码文件进行 Debug?
866 0
IntelliJ IDEA - 如何找到 Class 文件的 Java 源码文件进行 Debug?
|
新零售 机器学习/深度学习 运维
数字化转型的本质、路径、阶段和挑战
企业数字化转型需要协同企业战略,而不是追求眼前效益的战术。
709 0
数字化转型的本质、路径、阶段和挑战
|
自然语言处理
【NLP最佳实践】Huggingface Transformers实战教程
【NLP最佳实践】Huggingface Transformers实战教程
893 0
【NLP最佳实践】Huggingface Transformers实战教程
|
安全 前端开发 API
登录新体验!极光认证,了解一下!
登录新体验!极光认证,了解一下!
登录新体验!极光认证,了解一下!