简单玩玩TensorFlow的Post Training Quantization

简介: 传送门: 模型量化加速:[https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd](https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd) TF量化训练:[https://www.atatech.o

传送门:
模型量化加速:https://www.atatech.org/articles/132554?spm=ata.13261165.0.0.453bd67dhgqUkd
TF量化训练:https://www.atatech.org/articles/127543?spm=ata.13261165.0.0.453bd67dhgqUkd

量化对端上深度学习模型推理可以说是必选项了,前面的文章已经多次介绍过Quantization-aware training怎么操作,本文简单介绍下Post Training量化。

话说二者有什么区别呢,Quantization-aware training是在训练过程中进行量化,能够更好保持量化后模型的性能。Post training quantization意思是训练玩的模型直接拿来量化,通过在一组sample data推理模型,统计量化所需要的参数[min,max]。通常Post training quantization的精度损失大于Quantization-aware training,所以在以往的工作中我们主要推荐使用Quantization-aware training。

那Post training quantization是不是就毫无用处了呢?显然不是。

  1. Post training quantization在一些相对”重“的模型上,精度损失很小
  2. Quantization-aware training训练速度会慢一些,
  3. Post training quantization只需要模型就可以完成,比如如果模型是从其他训练框架转换而来,这时候就只能使用Post training quantization

根据官方的训练教程:
https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/lite/tutorials/full_integer_post_training_quant.ipynb
我改造了一个更加方便适配各种输入模型的脚本, 直接从文件系统读取Sample图像:

import tensorflow.compat.v1 as tf
#Eager Mode is essential!
tf.enable_eager_execution()

import sys
import glob
if sys.version_info.major >= 3:
    import pathlib
else:
    import pathlib2 as pathlib
import random
import cv2
import numpy as np

tf.logging.set_verbosity(tf.logging.DEBUG)
pb_file='model.pb'
input_arrays=['net_input']
output_arrays=['net_output']
input_shapes=[1, 128, 128, 3]

sample_img_dir='/tmp'


converter =  tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file=pb_file,
                                                       input_arrays=input_arrays,
                                                       output_arrays=output_arrays,
                                                       input_shapes={input_arrays[0]:input_shapes})
converter.allow_custom_ops=True
tflite_model = converter.convert()

tflite_models_dir = pathlib.Path("./tmp/tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

tflite_model_file = tflite_models_dir/"model.tflite"
tflite_model_file.write_bytes(tflite_model)

tf.logging.set_verbosity(tf.logging.DEBUG)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def preprocess_img(img):
    # preprocess
    img = cv2.resize(img, (128, 128))
    # FLOAT32
    img = img.astype(np.float32)
    # img=(img/128.)-1.
    img = img - 128.
    return img

def create_datastream_from_imgs(img_dir):

    img_path_list=glob.glob(img_dir+'/*.jpg')
    random.shuffle(img_path_list)
    img_path_list=img_path_list[:200]
    imgs_list=[]
    for path in img_path_list:
        img=cv2.imread(path)
        img=preprocess_img(img)
        imgs_list.append(img)
    imgs=np.stack(imgs_list,axis=0)
    return tf.data.Dataset.from_tensor_slices((imgs)).batch(1)


ds=create_datastream_from_imgs(sample_img_dir)

def representative_data_gen():
  for input_value in ds.take(20):
    yield [input_value]

converter.representative_dataset = representative_data_gen
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

最初我在tf1.14上测试,遇到段错误的问题,也许某些同学会遇到:

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

这时候,只需要升级到tf-nightly版本就可以解决了。

成功后得到了这样的模型:
Screenshot from 2019-08-21 11-49-28.png
和Quantization-aware training得到的模型略有不同的是,输入输出仍然是Float的,而Quantization-aware training的模型输入输出都是int8的,所以输入Node之后和输出的Node之前,它相应添加了Quantize和DeQuantize的Node。

在一个简单的分类任务上测试了一下,精度下降1.4%,马马虎虎还可以接受

目录
相关文章
|
4月前
|
机器学习/深度学习 测试技术 API
【Python-Keras】Keras搭建神经网络模型的Model解析与使用
这篇文章详细介绍了Keras中搭建神经网络模型的`Model`类及其API方法,包括模型配置、训练、评估、预测等,并展示了如何使用Sequential模型和函数式模型来构建和训练神经网络。
48 1
|
7月前
|
机器学习/深度学习 TensorFlow API
TensorFlow的扩展库:TensorFlow Probability与TensorFlow Quantum
【4月更文挑战第17天】TensorFlow的扩展库TensorFlow Probability和TensorFlow Quantum开辟了机器学习和量子计算新纪元。TensorFlow Probability专注于概率推理和统计分析,集成深度学习,支持贝叶斯推断和变分推断,提供自动微分及丰富的概率模型工具。其Bijector组件允许复杂随机变量转换,增强建模能力。另一方面,TensorFlow Quantum结合量子计算与深度学习,处理量子数据,构建量子-经典混合模型,应用于化学模拟、量子控制等领域,内置量子计算基元和高性能模拟器。
|
7月前
|
机器学习/深度学习 PyTorch TensorFlow
【TensorFlow】TF介绍及代码实践
【4月更文挑战第1天】TF简介及代码示例学习
99 0
|
机器学习/深度学习 数据可视化 数据挖掘
PyTorch Geometric (PyG) 入门教程
PyTorch Geometric是PyTorch1的几何图形学深度学习扩展库。本文旨在通过介绍PyTorch Geometric(PyG)中常用的方法等内容,为新手提供一个PyG的入门教程。
PyTorch Geometric (PyG) 入门教程
|
存储 并行计算 PyTorch
基于Pytorch中安装torch_geometric简单详细完整版
基于Pytorch中安装torch_geometric简单详细完整版
1319 0
基于Pytorch中安装torch_geometric简单详细完整版
|
并行计算 数据可视化 PyTorch
Pytorch教程[09]Tensorboard
Pytorch教程[09]Tensorboard
Pytorch教程[09]Tensorboard
|
并行计算 PyTorch 算法框架/工具
PyTorch Geometric (PyG) 安装教程
以下根据PyTorch和对应的cuda版本来写PyG的安装方式。对应可行的安装时间会对应附上。 由于我在遇到对应情况时才能撰写对应博文,更多情况看以后我会不会遇上吧。
PyTorch Geometric (PyG) 安装教程
|
TensorFlow 算法框架/工具
TensorFlow教程(2)-基本函数使用
本文主要介绍tf.argmax,tf.reduce_mean(),tf.reduce_sum(),tf.equal()的使用
150 0
TensorFlow教程(2)-基本函数使用
|
TensorFlow 算法框架/工具 开发工具
TF学习——TF之TensorFlow Slim:TensorFlow Slim的简介、安装、使用方法之详细攻略
TF学习——TF之TensorFlow Slim:TensorFlow Slim的简介、安装、使用方法之详细攻略
TF学习——TF之TensorFlow Slim:TensorFlow Slim的简介、安装、使用方法之详细攻略
|
移动开发 算法 算法框架/工具
Py之keras-retinanet:keras-retinanet的简介、安装、使用方法之详细攻略
Py之keras-retinanet:keras-retinanet的简介、安装、使用方法之详细攻略
Py之keras-retinanet:keras-retinanet的简介、安装、使用方法之详细攻略