TensorFlow Lite开发系列之python接口解析(一)

本文涉及的产品
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 环境: tensorflow2.x, 一定要使用linux系统,后期转换模型windows会出现bug

API解析


微信图片_20230203192252.png


常用的就这4个类,常用方法介绍:


# model_path:TF-Lite Flatbuffer文件的路径
# model_content:模型的内容
tf.lite.Interpreter(
    model_path=None, model_content=None, experimental_delegates=None
)


allocate_tensors()   # 加载所有的tensor


get_input_details()  # 获取模型输入详细信息


get_output_details()  # 获取模型输出详细信息


 # 获取输入张量的值(获取副本),该值可以从get_output_details中的“索引”字段中获得
get_tensor( 
    tensor_index
)


get_tensor_details()   # 返回值:包含张量信息的字典列表


invoke()  # 进行推理, 在调用它之前,请确保设置输入大小,分配张量和填充值


# tensor_index:要设置的张量的张量索引。该值可以从get_input_details中的“索引”字段中获得
# value:要设置的张量值
set_tensor(
    tensor_index, value
)


# Converting a SavedModel to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Converting a tf.Keras model to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Converting ConcreteFunctions to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_concrete_functions([func])
tflite_model = converter.convert()


tf.lite.Optimize 
类变量
  DEFAULT
  OPTIMIZE_FOR_LATENCY
  OPTIMIZE_FOR_SIZE


完整例子:


1. 训练一个模型用于后面转换TFLite模型与推理


import tensorflow as tf
from tensorflow import keras
# 读取数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = tf.expand_dims(x_train, 3)
y_train = keras.utils.to_categorical(y_train, num_classes=10)
datasets = tf.data.Dataset.from_tensor_slices((x_train, y_train))
datasets = datasets.repeat(1).batch(10)
# 定义模型
img = keras.Input(shape=[28, 28, 1])
x = keras.layers.Conv2D(filters=64, kernel_size=4, strides=1, padding='SAME', activation='relu')(img)
x = keras.layers.AveragePooling2D(pool_size=2, strides=2, padding='SAME')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dropout(0.15)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(512, activation='relu')(x)
y_pred = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=img, outputs=y_pred)
model.compile(optimizer=keras.optimizers.Adam(0.01),
              loss=keras.losses.categorical_crossentropy,
              metrics=['AUC', 'accuracy'])
model.fit(datasets, epochs=1)
model.save('./model.h5')


2. 转换模型


import tensorflow as tf
from tensorflow import keras
model = tf.keras.models.load_model("./model.h5")
model.summary()
converter = tf.lite.TFLiteConverter.from_keras_model(model)  # 生成转化器
tflite_model = converter.convert()  # 进行转换
open('./converted_model.tflite', 'wb').write(tflite_model)  # 写入


3. 推理


import tensorflow as tf
from tensorflow import keras
import numpy as np
interpreter = tf.lite.Interpreter(model_path='./converted_model.tflite')  # 读入并生成interpreter
interpreter.allocate_tensors()  # 加载所有的张量
input_details = interpreter.get_input_details()  # 获取输入的信息
output_details = interpreter.get_output_details()  # 获取输出的信息
# 指定随机数进行预测
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
# 指定输入数据
interpreter.set_tensor(input_details[0]['index'], input_data)
# 调用模型进行推理
interpreter.invoke()
# 根据tensor索引取出推理的结果
tflite_result = interpreter.get_tensor(output_details[0]['index'])
目录
相关文章
|
5天前
|
数据采集 JSON API
深入解析:使用 Python 爬虫获取淘宝店铺所有商品接口
本文介绍如何使用Python结合淘宝开放平台API获取指定店铺所有商品数据。首先需注册淘宝开放平台账号、创建应用并获取API密钥,申请接口权限。接着,通过构建请求、生成签名、调用接口(如`taobao.items.search`和`taobao.item.get`)及处理响应,实现数据抓取。代码示例展示了分页处理和错误处理方法,并强调了调用频率限制、数据安全等注意事项。此技能对开发者和数据分析师极具价值。
|
2月前
|
数据可视化 前端开发 测试技术
接口测试新选择:Postman替代方案全解析
在软件开发中,接口测试工具至关重要。Postman长期占据主导地位,但随着国产工具的崛起,越来越多开发者转向更适合中国市场的替代方案——Apifox。它不仅支持中英文切换、完全免费不限人数,还具备强大的可视化操作、自动生成文档和API调试功能,极大简化了开发流程。
|
22天前
|
存储 索引 Python
Python入门:6.深入解析Python中的序列
在 Python 中,**序列**是一种有序的数据结构,广泛应用于数据存储、操作和处理。序列的一个显著特点是支持通过**索引**访问数据。常见的序列类型包括字符串(`str`)、列表(`list`)和元组(`tuple`)。这些序列各有特点,既可以存储简单的字符,也可以存储复杂的对象。 为了帮助初学者掌握 Python 中的序列操作,本文将围绕**字符串**、**列表**和**元组**这三种序列类型,详细介绍其定义、常用方法和具体示例。
Python入门:6.深入解析Python中的序列
|
22天前
|
存储 Linux iOS开发
Python入门:2.注释与变量的全面解析
在学习Python编程的过程中,注释和变量是必须掌握的两个基础概念。注释帮助我们理解代码的意图,而变量则是用于存储和操作数据的核心工具。熟练掌握这两者,不仅能提高代码的可读性和维护性,还能为后续学习复杂编程概念打下坚实的基础。
Python入门:2.注释与变量的全面解析
|
5天前
|
存储 JSON API
Python测试淘宝店铺所有商品接口的详细指南
本文详细介绍如何使用Python测试淘宝店铺商品接口,涵盖环境搭建、API接入、签名生成、请求发送、数据解析与存储、异常处理等步骤。通过具体代码示例,帮助开发者轻松获取和分析淘宝店铺商品数据,适用于电商运营、市场分析等场景。遵守法规、注意调用频率限制及数据安全,确保应用的稳定性和合法性。
|
6天前
|
机器学习/深度学习 JSON 算法
淘宝拍立淘按图搜索API接口系列的应用与数据解析
淘宝拍立淘按图搜索API接口是阿里巴巴旗下淘宝平台提供的一项基于图像识别技术的创新服务。以下是对该接口系列的应用与数据解析的详细分析
|
21天前
|
存储 人工智能 程序员
通义灵码AI程序员实战:从零构建Python记账本应用的开发全解析
本文通过开发Python记账本应用的真实案例,展示通义灵码AI程序员2.0的代码生成能力。从需求分析到功能实现、界面升级及测试覆盖,AI程序员展现了需求转化、技术选型、测试驱动和代码可维护性等核心价值。文中详细解析了如何使用Python标准库和tkinter库实现命令行及图形化界面,并生成单元测试用例,确保应用的稳定性和可维护性。尽管AI工具显著提升开发效率,但用户仍需具备编程基础以进行调试和优化。
210 9
|
1月前
|
API 文件存储 数据安全/隐私保护
python 群晖nas接口(一)
这段代码展示了如何通过群晖NAS的API获取认证信息(SID)并列出指定文件夹下的所有文件。首先,`get_sid()`函数通过用户名和密码登录NAS,获取会话ID(SID)。接着,`list_file(filePath, sid)`函数使用该SID访问FileStation API,列出给定路径`filePath`下的所有文件。注意需替换`yourip`、`username`和`password`为实际值。
93 18
|
1月前
|
监控 算法 安全
内网桌面监控软件深度解析:基于 Python 实现的 K-Means 算法研究
内网桌面监控软件通过实时监测员工操作,保障企业信息安全并提升效率。本文深入探讨K-Means聚类算法在该软件中的应用,解析其原理与实现。K-Means通过迭代更新簇中心,将数据划分为K个簇类,适用于行为分析、异常检测、资源优化及安全威胁识别等场景。文中提供了Python代码示例,展示如何实现K-Means算法,并模拟内网监控数据进行聚类分析。
42 10
|
1月前
|
API Python
python泛微e9接口开发
通过POST请求向指定IP的API注册设备以获取`secrit`和`spk`。请求需包含`appid`、`loginid`、`pwd`等头信息。响应中包含状态码、消息及`secrit`(注意拼写)、`secret`和`spk`字段。示例代码使用`curl`命令发送请求,成功后返回相关信息。
46 5

热门文章

最新文章