MobileNet V1官方预训练模型的使用

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: MobileNet V1官方预训练模型的使用

MobileNet V1官方预训练模型的使用


最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。

1. 下载网络结构及模型

1.1 下载MobileNet V1定义网络结构的文件

MobileNet V1的网络结构可以直接从官方Github库中下载定义网络结构的文件,地址为:https://raw.githubusercontent.com/tensorflow/models/master/research/slim/nets/mobilenet_v1.py

1.2 下载MobileNet V1预训练模型

MobileNet V1预训练的模型文在如下地址中下载:

https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md

打开以上网址,可以看到MobileNet V1官方预训练的模型,官方提供了不同输入尺寸和不同网络中通道数的多个模型,并且提供了每个模型对应的精度。可以根据实际的需要下载对应的模型,如下图所示。

微信图片_20221214204621.png

这里以选择MobileNet_v1_1.0_192为例,表示网络中的所有卷积后的通道数为标准通道数(即1.0倍),输入图像尺寸为192X192。

2. 构建网络结构及加载模型参数

2.1 构建网络结构

在1.1小节中下载mobilenet_v1.py文件后,使用其中的mobilenet_v1函数构建网络结构静态图,如下代码所示。

import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
slim = tf.contrib.slim
def build_model(inputs):   
    with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
        logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
    scores = end_points['Predictions']
    print(scores)
    #取概率最大的3个类别及其对应概率
    output = tf.nn.top_k(scores, k=3, sorted=True)
    #indices为类别索引,values为概率值
    return output.indices,output.values

上面代码中,使用函数tf.nn.top_k取概率最大的3个类别机器对应概率。

2.2 加载模型参数

CKPT = 'mobilenet_v1_1.0_192.ckpt' 
def load_model(sess):
    loader = tf.train.Saver()
    loader.restore(sess,CKPT)
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3))
classes_tf,scores_tf = build_model(inputs) 
with tf.Session() as sess:
    load_model(sess)

先定义placeholder输入inputs,再通过函数build_model完成静态图的定义。接下来传入tf.Session对象到load_model函数中完成模型加载。

3. 模型测试

3.1 加载Label

网络输出结果为类别的索引值,需要将索引值转为对应的类别字符串。先从官网下载label数据,需要注意的是MobileNet V1使用的是ILSVRC-2012-CLS数据,因此需要下载对应的Label信息(本文后面附件中会提供)。解析Label数据代码如下。

def load_label():
    label=['其他']
    with open('label.txt','r',encoding='utf-8') as r:
        lines = r.readlines()
        for l in lines:
            l = l.strip()
            arr = l.split(',')
            label.append(arr[1])
    return label

3.2 测试结果

使用如下图片进行测试。微信图片_20221214204628.png执行inference.py后,控制台输出结果如下所示。

识别 test_images/test1.png 结果如下:
        No. 0 类别: 军用飞机 概率: 0.9363691
        No. 1 类别: 飞机翅膀 概率: 0.032617383
        No. 2 类别: 炮弹 概率: 0.01853972
识别 test_images/test2.png 结果如下:
        No. 0 类别: 小儿床 概率: 0.9455737
        No. 1 类别: 摇篮 概率: 0.044925883
        No. 2 类别: 板架 概率: 0.007288801

4 完整代码

inference.py完整的代码如下所示。

import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
import cv2
import os
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt' 
dir_path = 'test_images'
def build_model(inputs):   
    with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
        logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
    scores = end_points['Predictions']
    print(scores)
    #取概率最大的5个类别及其对应概率
    output = tf.nn.top_k(scores, k=3, sorted=True)
    #indices为类别索引,values为概率值
    return output.indices,output.values
def load_model(sess):
    loader = tf.train.Saver()
    loader.restore(sess,CKPT)
def get_data(path_list,idx): 
    img_path = images_path[idx]
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,(192,192))
    img = np.expand_dims(img,axis=0)
    img = (img/255.0-0.5)*2.0
    return img_path,img
def load_label():
    label=['其他']
    with open('label.txt','r',encoding='utf-8') as r:
        lines = r.readlines()
        for l in lines:
            l = l.strip()
            arr = l.split(',')
            label.append(arr[1])
    return label
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3))
classes_tf,scores_tf = build_model(inputs) 
images_path =[dir_path+'/'+n for n in os.listdir(dir_path)]
label=load_label()
with tf.Session() as sess:
    load_model(sess)
    for i in range(len(images_path)):
        path,img = get_data(images_path,i)
        classes,scores = sess.run([classes_tf,scores_tf],feed_dict={inputs:img})
        print('\n识别',path,'结果如下:')
        for j in range(3):#top 3
            idx = classes[0][j]
            score=scores[0][j]
            print('\tNo.',j,'类别:',label[idx],'概率:',score) 

5. 附件下载

https://download.csdn.net/download/huachao1001/10737491

相关文章
|
6月前
|
机器学习/深度学习 PyTorch 测试技术
|
6月前
|
机器学习/深度学习 异构计算 Python
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
对于深度学习初学者来说,JupyterNoteBook的脚本运行形式显然更加友好,依托Python语言的跨平台特性,JupyterNoteBook既可以在本地线下环境运行,也可以在线上服务器上运行。GoogleColab作为免费GPU算力平台的执牛耳者,更是让JupyterNoteBook的脚本运行形式如虎添翼。 本次我们利用Bert-vits2的最终版Bert-vits2-v2.3和JupyterNoteBook的脚本来复刻生化危机6的人气角色艾达王(ada wong)。
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
|
6月前
|
机器学习/深度学习 数据采集 PyTorch
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
223 1
|
机器学习/深度学习 存储 自然语言处理
使用QLoRA对Llama 2进行微调的详细笔记
使用QLoRA对Llama 2进行微调是我们常用的一个方法,但是在微调时会遇到各种各样的问题,所以在本文中,将尝试以详细注释的方式给出一些常见问题的答案。这些问题是特定于代码的,大多数注释都是针对所涉及的开源库以及所使用的方法和类的问题。
587 0
|
机器学习/深度学习 人工智能 PyTorch
ResNet详解:网络结构解读与PyTorch实现教程
ResNet详解:网络结构解读与PyTorch实现教程
1630 0
|
Ubuntu TensorFlow 算法框架/工具
ResNet实战:tensorflow2.X版本,ResNet50图像分类任务(小数据集)
本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,今天我和大家一起实现tensorflow2.X版本图像分类任务,分类的模型使用ResNet50。 通过这篇文章你可以学到: 1、如何加载图片数据,并处理数据。 2、如果将标签转为onehot编码 3、如何使用数据增强。 4、如何使用mixup。 5、如何切分数据集。 6、如何加载预训练模型。
1421 0
ResNet实战:tensorflow2.X版本,ResNet50图像分类任务(小数据集)
|
PyTorch Go 算法框架/工具
YOLOv8来啦 | 详细解读YOLOv8的改进模块!YOLOv5官方出品YOLOv8,必卷!
YOLOv8来啦 | 详细解读YOLOv8的改进模块!YOLOv5官方出品YOLOv8,必卷!
2691 0
|
机器学习/深度学习 人工智能 数据挖掘
【Deep Learning B图像分类实战】2023 Pytorch搭建AlexNet、VGG16、GoogleNet等共5个模型实现COIL20数据集图像20分类完整项目(项目已开源)
亮点:代码开源+结构清晰规范+准确率高+保姆级解析+易适配自己数据集+附原始论文+适合新手
378 0
|
机器学习/深度学习 人工智能 自动驾驶
深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)(下)
今天自动驾驶之心很荣幸邀请到逻辑牛分享深度学习部署的入门介绍,带大家盘一盘ONNX、NCNN、OpenVINO等框架的使用场景、框架特点及代码示例。
深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)(下)
|
机器学习/深度学习 存储 人工智能
深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)(上)
今天自动驾驶之心很荣幸邀请到逻辑牛分享深度学习部署的入门介绍,带大家盘一盘ONNX、NCNN、OpenVINO等框架的使用场景、框架特点及代码示例。
深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)(上)

热门文章

最新文章