TensorFlow Lite开发系列之python接口解析(一)

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: 环境: tensorflow2.x, 一定要使用linux系统,后期转换模型windows会出现bug

API解析


微信图片_20230203192252.png


常用的就这4个类,常用方法介绍:


# model_path:TF-Lite Flatbuffer文件的路径
# model_content:模型的内容
tf.lite.Interpreter(
    model_path=None, model_content=None, experimental_delegates=None
)


allocate_tensors()   # 加载所有的tensor


get_input_details()  # 获取模型输入详细信息


get_output_details()  # 获取模型输出详细信息


 # 获取输入张量的值(获取副本),该值可以从get_output_details中的“索引”字段中获得
get_tensor( 
    tensor_index
)


get_tensor_details()   # 返回值:包含张量信息的字典列表


invoke()  # 进行推理, 在调用它之前,请确保设置输入大小,分配张量和填充值


# tensor_index:要设置的张量的张量索引。该值可以从get_input_details中的“索引”字段中获得
# value:要设置的张量值
set_tensor(
    tensor_index, value
)


# Converting a SavedModel to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Converting a tf.Keras model to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Converting ConcreteFunctions to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_concrete_functions([func])
tflite_model = converter.convert()


tf.lite.Optimize 
类变量
  DEFAULT
  OPTIMIZE_FOR_LATENCY
  OPTIMIZE_FOR_SIZE


完整例子:


1. 训练一个模型用于后面转换TFLite模型与推理


import tensorflow as tf
from tensorflow import keras
# 读取数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = tf.expand_dims(x_train, 3)
y_train = keras.utils.to_categorical(y_train, num_classes=10)
datasets = tf.data.Dataset.from_tensor_slices((x_train, y_train))
datasets = datasets.repeat(1).batch(10)
# 定义模型
img = keras.Input(shape=[28, 28, 1])
x = keras.layers.Conv2D(filters=64, kernel_size=4, strides=1, padding='SAME', activation='relu')(img)
x = keras.layers.AveragePooling2D(pool_size=2, strides=2, padding='SAME')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dropout(0.15)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(512, activation='relu')(x)
y_pred = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=img, outputs=y_pred)
model.compile(optimizer=keras.optimizers.Adam(0.01),
              loss=keras.losses.categorical_crossentropy,
              metrics=['AUC', 'accuracy'])
model.fit(datasets, epochs=1)
model.save('./model.h5')


2. 转换模型


import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model("./model.h5")
model.summary()
converter = tf.lite.TFLiteConverter.from_keras_model(model)  # 生成转化器
tflite_model = converter.convert()  # 进行转换
open('./converted_model.tflite', 'wb').write(tflite_model)  # 写入


3. 推理


import tensorflow as tf
from tensorflow import keras
import numpy as np
interpreter = tf.lite.Interpreter(model_path='./converted_model.tflite')  # 读入并生成interpreter
interpreter.allocate_tensors()  # 加载所有的张量
input_details = interpreter.get_input_details()  # 获取输入的信息
output_details = interpreter.get_output_details()  # 获取输出的信息
# 指定随机数进行预测
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
# 指定输入数据
interpreter.set_tensor(input_details[0]['index'], input_data)
# 调用模型进行推理
interpreter.invoke()
# 根据tensor索引取出推理的结果
tflite_result = interpreter.get_tensor(output_details[0]['index'])
目录
相关文章
|
24天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
240 55
|
5天前
|
前端开发 搜索推荐 编译器
【01】python开发之实例开发讲解-如何获取影视网站中经过保护后的视频-用python如何下载无法下载的视频资源含m3u8-python插件之dlp-举例几种-详解优雅草央千澈
【01】python开发之实例开发讲解-如何获取影视网站中经过保护后的视频-用python如何下载无法下载的视频资源含m3u8-python插件之dlp-举例几种-详解优雅草央千澈
【01】python开发之实例开发讲解-如何获取影视网站中经过保护后的视频-用python如何下载无法下载的视频资源含m3u8-python插件之dlp-举例几种-详解优雅草央千澈
|
16天前
|
IDE 测试技术 开发工具
10个必备Python调试技巧:从pdb到单元测试的开发效率提升指南
在Python开发中,调试是提升效率的关键技能。本文总结了10个实用的调试方法,涵盖内置调试器pdb、breakpoint()函数、断言机制、logging模块、列表推导式优化、IPython调试、警告机制、IDE调试工具、inspect模块和单元测试框架的应用。通过这些技巧,开发者可以更高效地定位和解决问题,提高代码质量。
126 8
10个必备Python调试技巧:从pdb到单元测试的开发效率提升指南
|
1月前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
168 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
5天前
|
人工智能 编译器 Python
python已经安装有其他用途如何用hbuilerx配置环境-附带实例demo-python开发入门之hbuilderx编译器如何配置python环境—hbuilderx配置python环境优雅草央千澈
python已经安装有其他用途如何用hbuilerx配置环境-附带实例demo-python开发入门之hbuilderx编译器如何配置python环境—hbuilderx配置python环境优雅草央千澈
python已经安装有其他用途如何用hbuilerx配置环境-附带实例demo-python开发入门之hbuilderx编译器如何配置python环境—hbuilderx配置python环境优雅草央千澈
|
30天前
|
存储 API 数据库
使用Python开发获取商品销量详情API接口
本文介绍了使用Python开发获取商品销量详情的API接口方法,涵盖API接口概述、技术选型(Flask与FastAPI)、环境准备、API接口创建及调用淘宝开放平台API等内容。通过示例代码,详细说明了如何构建和调用API,以及开发过程中需要注意的事项,如数据库连接、API权限、错误处理、安全性和性能优化等。
94 5
|
2月前
|
机器学习/深度学习 人工智能 关系型数据库
Python开发
Python开发
43 7
|
2月前
|
JSON API 数据格式
如何使用Python开发天猫获得淘宝买家秀API接口?
本文介绍了如何使用Python开发天猫和淘宝买家秀API接口,包括注册开放平台账号、创建应用获取API权限、构建请求URL、发送请求获取响应及解析数据等步骤,帮助开发者高效获取和处理商品信息与用户评价数据。
44 0
|
测试技术 TensorFlow 算法框架/工具
Python 数据科学入门教程:TensorFlow 目标检测
TensorFlow 目标检测 原文:TensorFlow Object Detection 译者:飞龙 协议:CC BY-NC-SA 4.0 一、引言 你好,欢迎阅读 TensorFlow 目标检测 API 迷你系列。
1698 0
|
1月前
|
人工智能 数据可视化 数据挖掘
探索Python编程:从基础到高级
在这篇文章中,我们将一起深入探索Python编程的世界。无论你是初学者还是有经验的程序员,都可以从中获得新的知识和技能。我们将从Python的基础语法开始,然后逐步过渡到更复杂的主题,如面向对象编程、异常处理和模块使用。最后,我们将通过一些实际的代码示例,来展示如何应用这些知识解决实际问题。让我们一起开启Python编程的旅程吧!