如何在Windows系统上使用Object Detection API训练自己的数据?

简介: 之前写了一篇如何在windows系统上安装Tensorflow Object Detection API?(点击跳转)然后就想着把数据集换成自己的数据集进行训练得到自己的目标检测模型。动手之前先学习了一波别人是如何实现的,看了大多数教程都有一个小问题:用VOC2012数据集进行训练当做用自己的数据集。然而,初心想看的是自己的数据集啊!于是就自己来撸一篇教程,方便自己也给别人一些参考吧~

前言

之前写了一篇如何在windows系统上安装Tensorflow Object Detection API?(点击跳转)

然后就想着把数据集换成自己的数据集进行训练得到自己的目标检测模型。动手之前先学习了一波别人是如何实现的,看了大多数教程都有一个小问题:用VOC2012数据集进行训练当做用自己的数据集。

然而,初心想看的是自己的数据集啊!于是就自己来撸一篇教程,方便自己也给别人一些参考吧~

目录

基于自己数据集进行目标检测训练的整体步骤如下:

  • 数据标注,制作VOC格式的数据集
  • 将数据集制作成tfrecord格式
  • 下载预使用的目标检测模型
  • 配置文件和模型
  • 模型训练

这里放一下小詹这个项目的整体截图,方便后边文件的对号入座。


81.jpg


数据标注,制作VOC格式的数据集

数据集当然是第一步,在收集好数据后需要进行数据的标注,考虑到VOC风格,这里推荐使用LabelImg工具进行标注。

92.jpg


至于工具具体怎么用,自己摸索下就好,小詹已经把关键点地方框选出来啦。(Tip: Ctrl+R选择标注文件存放路径)

将数据集制作成tfrecord格式

这一部需要将手动标注的xml文件进行处理,得到标注信息csv文件,之后和图像数据一起制作成tfrecord格式的数据,用于网络训练。

xml转换为csv文件

这一步需要对xml文件进行解析,提取出标注信息存入csv文件,这里直接把小詹的脚步文件(Xml2Csv.py)分享如下,当然文件路径你得换成自己的!

# 将xml文件读取关键信息转化为csv文件
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
def xml_to_csv(path):
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df
def main():
    image_path = r'E:\Jan_Project\Experiment_1\dataset\test_xml'
    xml_df = xml_to_csv(image_path)
    xml_df.to_csv(r'E:\Jan_Project\Experiment_1\dataset\cancer_test_labels.csv', index=None)
    print('Successfully converted xml to csv.')
if __name__ == '__main__':
    main()


生成tfrecord数据文件

之后在对应文件路径处就有了csv文件,再利用如下脚步自动生成tfrecord。(这是github上生成文件的修改版)

# 将CSV文件和图像数据整合为TFRecords
"""
name: generate_tfrecord.py
Usage:
  # From tensorflow/models/
  # Create train data:
  python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
  # Create test data:
  python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS
# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'yichang':
        return 1
    else:
        None
def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size
    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []
    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(csv_input, output_path, image_path):
    writer = tf.python_io.TFRecordWriter(output_path)
    path = image_path
    examples = pd.read_csv(csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    print('Successfully created the TFRecords: {}'.format(output_path))
if __name__ == '__main__':
#     csv_input = r'E:\Jan_Project\Experiment_1\dataset\cancer_train_labels.csv'
#     output_path = r'E:\Jan_Project\Experiment_1\dataset\train.tfrecord'
#     image_path = r'E:\Jan_Project\Experiment_1\dataset\train_img'
#     main(csv_input, output_path, image_path)
    csv_input = r'E:\Jan_Project\Experiment_1\dataset\cancer_test_labels.csv'
    output_path = r'E:\Jan_Project\Experiment_1\dataset\test.tfrecord'
    image_path = r'E:\Jan_Project\Experiment_1\dataset\test_img'
    main(csv_input, output_path, image_path)


利用上述脚步后便得到了想要的数据格式,小詹这里如图所示:

93.jpg


下载预使用的目标检测模型

准备好训练数据后,选择模型进行训练,下载官方预训练模型【Github】

对于目标检测,可以考虑选择几种最常用的模型:

  • ssd_mobilenet_v1_coco
  • ssd_mobilenet_v2_coco
  • faster_rcnn_resnet50_coco
  • faster_rcnn_resnet101_coco

小詹选择的是上方链接中对应下图的那个,自己视情况而定即可。

94.jpg


下载后解压到对应文件夹中(见小詹放的第一张项目整体图)

配置文件和模型

建立label_map.pbtxt

这里需要针对自己数据集进行修改,格式如下:

item{

  id: 1

  name: 'object'

}

修改

进入tensorflow/models/research/object_detection/samples/config文件夹找到对应自己模型的config文件,针对自己的情况进行修改:

num_classes: 修改为你自己任务的类别数
batch size:2(GPU显存较小的,尽量设置成小数值)
fine_tune_checkpoint: "路径/model.ckpt" #指定“训练模型的检查点文件”
train_input_reader: {
  tf_record_input_reader {
    input_path: "路径/train.tfrecord"
  }
  label_map_path: "路径/label_map.pbtxt"
}
eval_input_reader: {
  tf_record_input_reader {
    input_path: "路径/test.tfrecord"
  }
  label_map_path: "路径/label_map.pbtxt"
  shuffle: false
  num_readers: 1

模型训练

关于训练,要注意batch size大小和网络模型复杂程度,注意显存是否够大?显存不够就OOM(out of memory)了。

训练模型只需要运行object_detection/legacy路径下的train.py程序即可。(当然object_detection API安装是大前提,具体看上一篇文章!)

本地电脑:

python object_detection//legacy//train.py --logtostderr --train_dir=E://Jan_Project//Experiment_1//model --pipeline_config_path=E://Jan_Project//Experiment_1//training//faster_rcnn_inception_v2_coco.config

如果配置不够,可以云服务器上跑。以下是训练过程截图。

95.jpg


训练后还可以导出模型,用于检测测试。

#From tensorflow/modles/research/object_detection/
python export_inference_graph.py 
--input_type image_tensor 
--pipeline_config_path 路径/***.config
--trained_checkpoint_prefix 路径/model.ckpt-numbers #选择最近的一个或确认收敛到最优的一个
--output_directory 路径/my_model/ #模型的输出路径

以上整理于2019-5-20,节日快乐!

相关文章
|
3月前
|
计算机视觉 Windows Python
windows下使用python + opencv读取含有中文路径的图片 和 把图片数据保存到含有中文的路径下
在Windows系统中,直接使用`cv2.imread()`和`cv2.imwrite()`处理含中文路径的图像文件时会遇到问题。读取时会返回空数据,保存时则无法正确保存至目标目录。为解决这些问题,可以使用`cv2.imdecode()`结合`np.fromfile()`来读取图像,并使用`cv2.imencode()`结合`tofile()`方法来保存图像至含中文的路径。这种方法有效避免了路径编码问题,确保图像处理流程顺畅进行。
343 1
|
10天前
|
存储 前端开发 搜索推荐
淘宝 1688 API 接口助力构建高效淘宝代购集运系统
在全球化商业背景下,淘宝代购集运业务蓬勃发展,满足了海外消费者对中国商品的需求。掌握淘宝1688 API接口是构建成功代购系统的關鍵。本文详细介绍如何利用API接口进行系统架构设计、商品数据同步、订单处理与物流集成,以及用户管理和客户服务,帮助你打造一个高效便捷的代购集运系统,实现商业价值与用户满意度的双赢。
|
16天前
|
数据库 数据安全/隐私保护 Windows
Windows远程桌面出现CredSSP加密数据修正问题解决方案
【10月更文挑战第30天】本文介绍了两种解决Windows系统凭据分配问题的方法。方案一是通过组策略编辑器(gpedit.msc)启用“加密数据库修正”并将其保护级别设为“易受攻击”。方案二是通过注册表编辑器(regedit)在指定路径下创建或修改名为“AllowEncryptionOracle”的DWORD值,并将其数值设为2。
44 3
|
24天前
|
监控 安全 测试技术
我们为什么要API管理系统呢?
API 管理系统通过接口标准化与复用、简化开发流程、版本管理、监控与预警、访问控制、数据加密、安全审计、集中管理与共享、协作开发、快速对接外部系统和数据驱动的决策等多方面优势,显著提高开发效率、增强系统可维护性、提升系统安全性、促进团队协作与沟通,并支持业务创新与扩展。
|
3月前
|
缓存 NoSQL Linux
【Azure Redis 缓存】Windows和Linux系统本地安装Redis, 加载dump.rdb中数据以及通过AOF日志文件追加数据
【Azure Redis 缓存】Windows和Linux系统本地安装Redis, 加载dump.rdb中数据以及通过AOF日志文件追加数据
129 1
【Azure Redis 缓存】Windows和Linux系统本地安装Redis, 加载dump.rdb中数据以及通过AOF日志文件追加数据
|
1月前
|
供应链 搜索推荐 数据挖掘
电商ERP系统中电商API接口的应用
电商API接口在电子商务中扮演着至关重要的角色,它们允许开发者将电商功能集成到自己的应用程序中,实现商品检索、订单处理、支付、物流跟踪等功能。以下是关于电商API接口的应用:
|
3月前
|
Web App开发 存储 安全
微软警告数百万Windows用户:切勿冒险丢失所有数据
微软警告数百万Windows用户:切勿冒险丢失所有数据
微软警告数百万Windows用户:切勿冒险丢失所有数据
|
3月前
|
监控 Cloud Native 容灾
核心系统转型问题之API网关在云原生分布式核心系统中的功能如何解决
核心系统转型问题之API网关在云原生分布式核心系统中的功能如何解决
|
3月前
|
UED 开发工具 iOS开发
Uno Platform大揭秘:如何在你的跨平台应用中,巧妙融入第三方库与服务,一键解锁无限可能,让应用功能飙升,用户体验爆棚!
【8月更文挑战第31天】Uno Platform 让开发者能用同一代码库打造 Windows、iOS、Android、macOS 甚至 Web 的多彩应用。本文介绍如何在 Uno Platform 中集成第三方库和服务,如 Mapbox 或 Google Maps 的 .NET SDK,以增强应用功能并提升用户体验。通过 NuGet 安装所需库,并在 XAML 页面中添加相应控件,即可实现地图等功能。尽管 Uno 平台减少了平台差异,但仍需关注版本兼容性和性能问题,确保应用在多平台上表现一致。掌握正确方法,让跨平台应用更出色。
52 0
|
3月前
|
XML 缓存 API
【Azure API 管理】使用APIM进行XML内容读取时遇见的诡异错误 Expression evaluation failed. Object reference not set to an instance of an object.
【Azure API 管理】使用APIM进行XML内容读取时遇见的诡异错误 Expression evaluation failed. Object reference not set to an instance of an object.

热门文章

最新文章