从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)

从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

Tensorflow官方提供的Tensorboard可以可视化神经网络结构图,但是说实话,我几乎从来不用。主要是因为Tensorboard中查看到的图结构太混乱了,包含了网络中所有的计算节点(读取数据节点、网络节点、loss计算节点等等)。更可怕的是,如果一个计算节点是由多个基础计算(如加减乘除等)构成,那么在Tensorboard中会将基础计算节点显示而不是作为一个整体显示(典型的如Squeeze计算节点)。最近为了排查网络结构BUG花费一周时间,因此,狠下心来决定自己写一个工具,将Tensorflow中的图以最简单的方式显示最关键的网络结构。

1 Tensor对象与Operation对象

Tensorflow中,Tensor对象主要用于存储数据如常量和变量(训练参数),Operation对象是计算节点,如卷积计算、反卷积计算、ReLU等等。每一个Operation对象均有输入和输出Tensor,同理,每个Tensor对象均有对应生成该Tensor的Operation对象和使用该Tensor对象作为输入的Operation对象。Tensor和Operation对象内均有相关属性和函数来获取其关联的Operation和Tensor对象,相关属性如下所示。

Tensor对象的op属性指向生成该Tensor的Operation对象。

Tensor对象的consumers()函数获取使用该Tensor对象作为输入的Operation对象。

Operation对象的inputs属性指向该计算节点的输入Tensor对象。

Operation对象的outputs属性执行该计算节点的输出Tensor对象。

如下图所示的网络结构中,调用Tensor_2对象的consumers()函数,返回的是[op_1,op_2]。Tensor_3的op属性指向的是op_1。op_1的inputs属性指向的是[Tensor_1,Tensor_2],op_1的output属性指向的是[Tensor_3]。

image.png

有了Tensor与Operation对应在图中的关联关系,就可以将网络结构给画出来。

2 提取pb文件中的网络结构图

pb文件是将模型参数固化到图文件中,并合并了一些基础计算和删除了反向传播相关计算得到的protobuf协议文件。如果读者还不懂如何将CKPT模型文件转pb文件,请参考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1节部分。有了pb模型文件后,接下来是加载模型,加载pb模型示例代码如下所示。

def read_graph_from_pb(tf_model_path ,input_names,output_name):  
    with open(tf_model_path, 'rb') as f:
        serialized = f.read() 
    tf.reset_default_graph()
    gdef = tf.GraphDef()
    gdef.ParseFromString(serialized) 
    with tf.Graph().as_default() as g:
        tf.import_graph_def(gdef, name='') 
    with tf.Session(graph=g) as sess: 
        OPS=get_ops_from_pb(g,input_names,output_name)
    return OPS

其中,倒数第2行调用到的函数get_ops_from_pb()用于获取网络结构图中指定输入节点和指定输出节点之间的计算节点。之所以要指定输入和输出,是为了将输入之前的计算节点(如加载数据队列等相关计算节点)和输出之后的计算节点(如计算loss等相关计算节点)去除,免得碍眼。函数get_ops_from_pb()实现代码如下。

def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
    if save_ori_network:
        with open('ori_network.txt','w+') as w: 
            OPS=graph.get_operations()
            for op in OPS:
                txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
                w.write(txt+'\n') 
    inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
    output_tf =graph.get_tensor_by_name(output_name) 
    OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] ) 
    with open('network.txt','w+') as w: 
        for op in OPS:
            txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
            w.write(txt+'\n') 
    OPS = sort_ops(OPS)
    OPS = merge_layers(OPS)
    return OPS

在裁剪网络结构(即只保留input_names和output_name之间节点)之前,先将原始的网络结构写入到ori_network.txt中,文件中,每一行写入:输入Tensor---->op---->输出Tensor。接下来调用函数get_ops_from_inputs_outputs获取指定节点之间的节点。并调用sort_ops函数对所有的节点排序,以保证被依赖的节点总是出现在相关节点之前。最后调用merge_layers函数,将一些可以合并的计算合并成一个独立的节点,例如,Squeeze计算相关节点合并成一个单独的Squeeze节点,又如const-->identity两个计算节点可以直接忽略(即删除)。

注意:篇幅有限,这里不再将函数get_ops_from_inputs_outputs、sort_ops、merge_layers贴出,相关代码请前往文尾提供的源码地址中阅读。

3 绘制网络结构

考虑到SVG绘制图形的简单易用优点,将排好序的网络计算节点和相关Tensor对象数据以Javascript字符串的形式写入到HTML中,使用<line>标签绘制箭头,使用<rect>标签绘制矩形,使用<ellipse>标签绘制椭圆,使用<text>标签显示文字。绘制类似于如下所示图像

image.png

注意:篇幅有限,这里不再介绍Javascript代码解析模型结构和SVG显示相关的原理,相关代码请前往文尾提供的源码地址中阅读。

4 测试模型显示

以《MobileNet V1官方预训练模型的使用》文中介绍的MobileNet V1网络结构为例,下载MobileNet_v1_1.0_192文件并压缩后,得到mobilenet_v1_1.0_192_frozen.pb文件。我们还需要知道mobilenet_v1_1.0_192_frozen.pb模型对应的输入和输出Tensor对象的名称,好在MobileNet_v1_1.0_192压缩包中包含文件mobilenet_v1_1.0_192_info.txt。通过该文件可知,输入Tensor的名称为:input:0,输出Tensor名称为:MobilenetV1/Predictions/Reshape_1:0。有了这些信息后,调用函数read_graph_from_pb得到静态图的节点列表对象ops,调用函数gen_graph(ops,"save/path/graph.html")后,在目录save/path中得到graph.html文件,打开graph.html后,显示结果如下。

显示网络结构分两种模式:合并模式和展开模式,分别如下图所示。

5 源码地址

https://github.com/huachao1001/CNNGraph

相关文章
|
7天前
|
机器学习/深度学习 编解码 自动驾驶
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV1,用于移动视觉应用的高效卷积神经网络
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV1,用于移动视觉应用的高效卷积神经网络
31 3
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV1,用于移动视觉应用的高效卷积神经网络
|
7天前
|
机器学习/深度学习 移动开发 测试技术
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
27 1
RT-DETR改进策略【模型轻量化】| 替换骨干网络为MoblieNetV2,含模型详解和完整配置步骤
|
7天前
RT-DETR改进策略【模型轻量化】| 替换骨干网络为 GhostNet V3 2024华为的重参数轻量化模型
RT-DETR改进策略【模型轻量化】| 替换骨干网络为 GhostNet V3 2024华为的重参数轻量化模型
29 2
RT-DETR改进策略【模型轻量化】| 替换骨干网络为 GhostNet V3 2024华为的重参数轻量化模型
|
4天前
|
人工智能 自然语言处理 算法
DeepSeek模型的突破:性能超越R1满血版的关键技术解析
上海AI实验室周伯文团队的最新研究显示,7B版本的DeepSeek模型在性能上超越了R1满血版。该成果强调了计算最优Test-Time Scaling的重要性,并提出了一种创新的“弱到强”优化监督机制的研究思路,区别于传统的“从强到弱”策略。这一方法不仅提升了模型性能,还为未来AI研究提供了新方向。
175 5
|
7天前
|
机器学习/深度学习 文件存储 异构计算
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
16 1
|
2月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
349 55
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
337 5
|
3月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
137 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
3月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
179 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
171 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型

热门文章

最新文章

推荐镜像

更多