tensorflow object detection API使用之GPU训练实现宠物识别

简介: 之前写过几篇关于tensorflow object detection API使用的相关文章分享,收到不少关注与鼓励,所以决定再写一篇感谢大家肯定与支持。在具体介绍与解释之前,首先简单说一下本人测试与运行的系统与软件环境与版本

猫狗识别概述

之前写过几篇关于tensorflow object detection API使用的相关文章分享,收到不少关注与鼓励,所以决定再写一篇感谢大家肯定与支持。在具体介绍与解释之前,首先简单说一下本人测试与运行的系统与软件环境与版本

  • Windows 10 64位

  • Python3.6

  • Tensorflow 1.10

  • Object detection api

  • CUDA9.0+cuDNN7.0

下面就说说我是一步一步怎么做的,这个其中CPU训练与GPU训练速度相差很大,另外就是GPU训练时候经常遇到OOM问题,导致训练会停下来。

第一步

下载与安装tensorflow与object detection API模块tensorflow安装与配置执行下面的命令即可

Python –m pip install –upgrade tensorflow-gpu

Object Detection API下载首先执行

git clone https://github.com/tensorflow/models.git D:/tensorflow/models

然后安装protoc-3.4.0-win32执行一个命令行如下:

544b1a05dbeb878d5bbfc0cbc6df0055fdfe8e87

第二步:

下载Oxford-IIIT Pet数据制作tfrecord数据,首先从这里下载数据

http://www.robots.ox.ac.uk/~vgg/data/pets/

记得Dataset与Groundtruth data都需要下载。

然后执行下面的命令即可生成tfrecord

ef565cc76b7f36a7fb21e7a0d695d98c6cd58545

第三步:

使用预训练迁移学习进行训练,这里我使用的是SSD mobilenet的预训练模型,需要修改pipeline config文件与提供的分类描述文件分别为

- ssd_mobilenet_v1_pets.config
- pet_label_map.pbtxt

需要注意的是

ssd_mobilenet_v1_pets.config

文件中PATH_TO_BE_CONFIGURED修改为实际文件所在路径即可。

第四步

执行训练,这个是只需要执行下面命令就可以训练

python object_detection/model_main.py --model_dir=D:\tensorflow\my_train\models\train --pipeline_config_path=D:\tensorflow\my_train\models\ssd_mobilenet_v1_pets.config --num_train_steps=1000 --num_eval_steps=200 --logalsotostderr

但是这个只会在CPU上正常工作,当使用GPU执行训练此数据集的时候,你就会得到一个很让你崩溃的错误

ERROR:tensorflow:Model diverged with loss = NaN
…..
tensorflow.python.training.basic_session_run_hooks.NanLossDuringTrainingError: NaN loss during training

刚开始的我是在CPU上训练的执行这个命令一切正常,但是训练速度很慢,然后有人向我反馈说GPU上无法训练有这个问题,我尝试以后遇到上面的这个错误,于是我就开始了我漫长的查错,最终在github上发现了这个帖子:

https://github.com/tensorflow/models/issues/4881

官方open的issue,暂时大家还没有好办法解决,使用pet的数据集在GPU训练时候发生。帖子里面给出解决方案是使用legacy的train解决,于是我尝试了下面的命令:

73b9b4356f39167b101cf4b6f01d1af5ba8caa41

python object_detection/legacy/train.py --pipeline_config_path=D:/tensorflow/my_train/models/ssd_mobilenet_v1_pets.config --train_dir=D:/tensorflow/my_train/models/train –alsologtostderr

发现GPU上的训练可以正常跑啦,有图为证:

e4f4fcb0704000a196499aa08bc398a903fd8017

但是千万别高兴的太早,以为GPU训练对显存与内存使用是基于贪心算法,它会一直尝试获取更多内存,大概训练了100左右step就会爆出如下的错误:

tensorflow.python.framework.errors_impl.InternalErrorDst tensor is not initialized.

网络使用GPU训练时,一般当GPU显存被占满的时候会出现这个错误
解决的方法,就是在训练命令执行之前,首先执行下面的命令行:

Windows SET CUDA_VISIBLE_DEVICES=0
Linux export CUDA_VISIBLE_DEVICES=0

然后训练就会很顺利的执行下去

这个时候你就可以启动tensorboard查看训练过程啦,我的训练时候损失如下:

cd046b98d9ddb4e37d9ced72cdba4062280dc688

差不多啦,Ctrl+C停止训练,使用下面的命令行导出模型:

1da86c4a4b525b0e8a35e26e486aca972c08092a

导出之后,就可以使用测试图像进行测试啦!

第五步

模型使用,网络上随便找一张猫狗在一起的图像作为测试图像,通过下面的代码实现加载模型,调用tensorflow与opencv相关API函数读取模型与图像,运行代码测试结果如下:

c297a6a2954eb3a0a80623d65854b14a0f52da65

完整测试程序代码如下:

import os
import sys
import tarfile

import cv2
import numpy as np
import tensorflow as tf

sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util

##################################################
# 作者:贾志刚
# 微信:gloomy_fish
# tensorflow object detection tutorial
##################################################

# Path to frozen detection graph
PATH_TO_CKPT = 'D:/tensorflow/pet_model/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('D:/tensorflow/my_train/data''pet_label_map.pbtxt')

NUM_CLASSES = 37
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb'as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)


with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_np = cv2.imread("D:/images/test.jpg")
        cv2.imshow("input=QQ+57558865", image_np)
        print(image_np.shape)
        # image_np == [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
              image_np,
              np.squeeze(boxes),
              np.squeeze(classes).astype(np.int32),
              np.squeeze(scores),
              category_index,
              use_normalized_coordinates=True,
              min_score_thresh=0.2,
              line_thickness=8)
        cv2.imshow('object detection', image_np)
        cv2.imwrite("D:/run_result.png", image_np)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

sess.close()



原文发布时间为:2018-09-17
本文作者:gloomyfish
本文来自云栖社区合作伙伴“OpenCV学堂”,了解相关信息可以关注“OpenCV学堂”。
相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
2月前
|
编解码 人工智能 缓存
自学记录鸿蒙API 13:实现多目标识别Object Detection
多目标识别技术广泛应用于动物识别、智能相册分类和工业检测等领域。本文通过学习HarmonyOS的Object Detection API(API 13),详细介绍了如何实现一个多目标识别应用,涵盖从项目初始化、核心功能实现到用户界面设计的全过程。重点探讨了目标类别识别、边界框生成、高精度置信度等关键功能,并分享了性能优化与功能扩展的经验。最后,作者总结了学习心得,并展望了未来结合语音助手等创新应用的可能性。如果你对多目标识别感兴趣,不妨从基础功能开始,逐步实现自己的创意。
212 60
|
6月前
|
UED 开发工具 iOS开发
Uno Platform大揭秘:如何在你的跨平台应用中,巧妙融入第三方库与服务,一键解锁无限可能,让应用功能飙升,用户体验爆棚!
【8月更文挑战第31天】Uno Platform 让开发者能用同一代码库打造 Windows、iOS、Android、macOS 甚至 Web 的多彩应用。本文介绍如何在 Uno Platform 中集成第三方库和服务,如 Mapbox 或 Google Maps 的 .NET SDK,以增强应用功能并提升用户体验。通过 NuGet 安装所需库,并在 XAML 页面中添加相应控件,即可实现地图等功能。尽管 Uno 平台减少了平台差异,但仍需关注版本兼容性和性能问题,确保应用在多平台上表现一致。掌握正确方法,让跨平台应用更出色。
83 0
|
6月前
|
XML 缓存 API
【Azure API 管理】使用APIM进行XML内容读取时遇见的诡异错误 Expression evaluation failed. Object reference not set to an instance of an object.
【Azure API 管理】使用APIM进行XML内容读取时遇见的诡异错误 Expression evaluation failed. Object reference not set to an instance of an object.
|
6月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
52 0
|
6月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API两种训练GAN网络的方式
使用Keras API以两种不同方式训练条件生成对抗网络(CGAN)的示例代码:一种是使用train_on_batch方法,另一种是使用tf.GradientTape进行自定义训练循环。
66 5
|
7月前
|
机器学习/深度学习 TensorFlow API
Keras是一个高层神经网络API,由Python编写,并能够在TensorFlow、Theano或CNTK之上运行。Keras的设计初衷是支持快速实验,能够用最少的代码实现想法,并且能够方便地在CPU和GPU上运行。
Keras是一个高层神经网络API,由Python编写,并能够在TensorFlow、Theano或CNTK之上运行。Keras的设计初衷是支持快速实验,能够用最少的代码实现想法,并且能够方便地在CPU和GPU上运行。
|
7月前
|
JSON JavaScript API
JS【详解】Map (含Map 和 Object 的区别,Map 的常用 API,Map与Object 的性能对比,Map 的应用场景和不适合的使用场景)
JS【详解】Map (含Map 和 Object 的区别,Map 的常用 API,Map与Object 的性能对比,Map 的应用场景和不适合的使用场景)
206 0
|
8月前
|
Java API
API:object当中的各种方法刨析(今日份:equals toString)
API:object当中的各种方法刨析(今日份:equals toString)
|
8月前
|
Java API
JavaSE——常用API进阶一(1/3)-Object类(Object类的作用、Object类的常见方法-toString方法、equal方法、clone方法)
JavaSE——常用API进阶一(1/3)-Object类(Object类的作用、Object类的常见方法-toString方法、equal方法、clone方法)
57 0
|
9月前
|
机器学习/深度学习 人工智能 API
人工智能应用工程师技能提升系列2、——TensorFlow2——keras高级API训练神经网络模型
人工智能应用工程师技能提升系列2、——TensorFlow2——keras高级API训练神经网络模型
104 0

热门文章

最新文章