PyTorch版EfficientDet比官方TF实现快25倍?这个GitHub项目数天狂揽千星

简介: EfficientDet 难复现,复现即趟坑。在此 Github 项目中,开发者 zylo117 开源了 PyTorch 版本的 EfficientDet,速度比原版高 20 余倍。如今,该项目已经登上 Github Trending 热榜。
+关注继续查看

云栖号资讯:【点击查看更多行业资讯
在这里您可以找到不同行业的第一手的上云资讯,还在等什么,快来!

EfficientDet 难复现,复现即趟坑。在此 Github 项目中,开发者 zylo117 开源了 PyTorch 版本的 EfficientDet,速度比原版高 20 余倍。如今,该项目已经登上 Github Trending 热榜。

image

去年 11 月份,谷歌大脑提出兼顾准确率和模型效率的新型目标检测器 EfficientDet,实现了新的 SOTA 结果。前不久,该团队开源了 EfficientDet 的 TensorFlow 实现代码。

如此高效的 EfficientDet 还能更高效吗?最近,有开发者在 GitHub 上开源了「PyTorch 版本的 EfficientDet」。该版本的性能接近原版,但速度是官方 TensorFlow 实现的近 26 倍!

image

目前,该项目在 GitHub 上获得了 957 颗星,最近一天的收藏量接近 300。

GitHub 地址:https://github.com/zylo117

EfficientDet 简介

近年来,在面对广泛的资源约束时(如 3B 到 300B FLOPS),构建兼具准确率和效率的可扩展检测架构成为优化目标检测器的重要问题。基于单阶段检测器范式,谷歌大脑团队的研究者查看了主干网络、特征融合和边界框/类别预测网络的设计选择,发现了两大主要挑战并提出了相应的解决方法:

挑战 1:高效的多尺度特征融合。研究者提出一种简单高效的加权双向特征金字塔网络(BiFPN),该模型引入了可学习的权重来学习不同输入特征的重要性,同时重复应用自上而下和自下而上的多尺度特征融合。

挑战 2:模型缩放。受近期研究的启发,研究者提出一种目标检测器复合缩放方法,即统一扩大所有主干网络、特征网络、边界框/类别预测网络的分辨率/深度/宽度。

谷歌大脑团队的研究者发现,EfficientNets 的效率超过之前常用的主干网络。于是研究者将 EfficientNet 主干网络和 BiFPN、复合缩放结合起来,开发出新型目标检测器 EfficientDet,其准确率优于之前的目标检测器,同时参数量和 FLOPS 比它们少了一个数量级。

下图展示了 EfficientDet 的整体架构,大致遵循单阶段检测器范式。谷歌大脑团队的研究者将在 ImageNet 数据集上预训练的 EfficientNet 作为主干网络,将 BiFPN 作为特征网络,接受来自主干网络的 level 3-7 特征 {P3, P4, P5, P6, P7},并重复应用自上而下和自下而上的双向特征融合。然后将融合后的特征输入边界框/类别预测网络,分别输出目标类别和边界框预测结果。

image

下图展示了多个模型在 COCO 数据集上的性能对比情况。在类似的准确率限制下,EfficientDet 的 FLOPS 仅为 YOLOv3 的 1/28、RetinaNet 的 1/30、NASFPN 的 1/19,所有数字均为单个模型在单一尺度下所得。可以看到,EfficientDet 的计算量较其他检测器少,但准确率优于后者,其中 EfficientDet-D7 获得了当前最优性能。

image

「宅」是第一生产力

项目作者今年 1 月宅家为国出力时,开始陆续尝试各类 EfficientDet PyTorch 版实现,期间趟过了不少坑,也流过几把辛酸泪。但最终得出了非常不错的效果,也是全网第一个跑出接近论文成绩的 PyTorch 版。

我们先来看一下项目作者与 EfficientDet 官方提供代码的测试效果对比。第一张图为官方代码的检测效果,第二张为项目作者的检测效果。项目作者的实现竟然透过汽车的前挡风玻璃检测出了车辆里面的人?!!这样惊艳的检测效果不愧是目前 EfficientDet 的霸榜存在。

image
image

接下来我们来看一下 coco 数据集上目标检测算法的排名,多个屠榜的目标检测网络基于 EfficientDet 构建。一图以言之:

image
来自 paperswithcode

前五里包揽前四,屠榜之势不言而喻,也难怪各类炼金术士们跃跃欲试。但是,EfficientDet 的实现难度貌似与其知名度「成正比」,众炼金师纷纷表示「难训练」「至今未训练好」「谁复现谁被坑」。项目作者也表示「由于谷歌一直不发官方 repository,所以只能民间发力,那些靠 paper 的内容实现出来的真的不容易」。

假期三天,拿下 PyTorch 版 EfficientDet D0 到 D7

项目作者复现结果与论文中并没完全一致,但相较于其他同类复现项目来说,称的上是非常接近了(详细信息可参考项目链接)。

image

值得注意的是,此次项目处理速度比原版快了 20 余倍。

那么为什么之前都没有人复现 EfficientDet 的成绩?具体哪些细节需要注意?

「民间」EfficientDet 的取舍

作者前后试用了两个 GitHub 项目进行实现,但效果并不理想。首先采用的 star 量最高的一个,同时可能也说明了一点,不是 star 越高就越适合。

针对第一个项目,作者表示:「因为 EfficientDet 的特性之一是 BiFPN,它会融合 backbone 输出的任意相邻两层的 feature,但是由于有两层尺寸的宽高是不同的,所以会进行 upsample 或者 pooling 来保证它们宽高一致。而这个作者没有意识到,他不知道从 backbone 抽哪些 feature 出来,他觉得是 backbone 有问题,改了人家的 stride,随便挑了几层,去强迫 backbone 输出他想要的尺寸」

「改了网络结构,pretrained 权值基本就废了,所以作者也发现了,发现训练不下去了」。至此第一个项目画上句号,同时作者提供了官方参数与试用项目作者改后的参数对比链接,有兴趣的朋友可浏览参考链接。

而面向第二个项目,虽然 star 不及前者一半,但显然可靠度更甚前者。作者表示,第二个项目起码在 D0 上有论文成绩的支撑,同时 repo 也提供了 coco 的 pretrained 权值 31.4mAP。然而实操后作者得到 24mAP,同时社区普遍也在 20-22 范围中。

那么此次结果的原因是什么?作者经过反复的思考检测,得到以下 7 点总结,并就此 7 点复盘进行适当得调整,得到了当前项目不错的效果。

一波三折后的答案

针对第二个测试项目的复盘,作者表示一共有 7 个关键点需要额外注意:

  1. 第二个项目的 BN 实现有问题:BatchNorm 是有一个参数,叫做 momentum,用来调整新旧均值的比例,从而调整移动平均值的计算方式的。
  2. Depthwise-Separatable Conv2D 的错误实现。
  3. 误解了 maxpool2d 的参数,kernel_size 和 stride。
  4. 减少通道的卷积后面,没有进行 BN
  5. backbone feature 抽头抽错了
  6. Conv 和 pooling,没有用到 same padding
  7. 没有能正确的理解 BiFPN 的流程

image
来源于项目作者知乎账号,详情请见参考链接

作者还表示,其中有个非常关键点,「鸡贼的官方并没有表示这里是两个独立的 P4_0」。

简而言之,这篇知乎博客非常详细的介绍了各种复现注意事项,细节在此不再一一赘述。笔者认为对各炼金术师有一定参考价值,感兴趣的可以直接查看原文博客。

同时,机器之心对此项目也进行了实测。

项目实测

测试

我们在 P100 GPU,Ubuntu 18.04 系统下对本项目进行了测试。

首先将项目克隆到本地,并切换到相关目录下:

!git clone https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
import os
os.chdir('Yet-Another-EfficientDet-Pytorch')

安装如下依赖环境:

!pip install pycocotools numpy opencv-python tqdm tensorboard tensorboardX pyyaml
!pip install torch==1.4.0
!pip install torchvision==0.5.0

项目作者为我们提供了用于推断测试的 Python 脚本 efficientdet_test.py,该脚本会读取 weights 文件夹下保存的网络权重,并对 test 文件夹中的图片进行推断,之后将检测结果保存到同一文件夹下。首先,我们使用如下命令下载预训练模型:

!mkdir weights
os.chdir('weights')
!wget https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d0.pth

之后把需要检测的图片放在 test 文件夹下,这里别忘了还要把 efficientdet_test.py 中对应的图像名称修改为我们想要检测图片的名称,运行 efficientdet_test.py 脚本即可检测图片中的物体,输出结果如下:

image

我们先用曾经爆火的共享单车,现如今倒了一大片沦为「共享单车坟场」测试一下效果如何。下图分别为原图与使用本项目的检测结果。

image
image

效果很不错,图片中的人与密密麻麻、横七竖八摆放的共享单车大多都检测了出来。接下来我们用一张国内常见的堵车场景来测试一下,车辆、非机动车、行人交错出现在画面中,可以说是非常复杂的场景了。从检测结果可以看出,基本上所有的行人、车辆、背包、袋子等物体都较好地检测了出来。

image
image

最后当然要在「开挂民族」坐火车的场景下测试一番,密集恐惧症慎入。虽然把旗子检测成了风筝(很多目标检测算法都容易出现这样的问题),但总体来说检测效果可以说是非常惊艳的。它检测出了图片中大部分的人物,和机器之心此前报道过的高精度人脸检测方法-DBFace 的准确率有得一拼。需要注意的是,DBFace 是专用于人脸检测的方法,而本项目实现的是通用物体检测。

image
image

训练

项目作者同时也提供了训练 EfficientDet 相关的代码。我们只需要准备好训练数据集,设置好类似于如下代码所示的训练参数,运行 train.py 即可进行训练。

# create a yml file {your_project_name}.yml under 'projects'folder 
# modify it following 'coco.yml'

# for example
project_name: coco
train_set: train2017
val_set: val2017
num_gpus: 4  # 0 means using cpu, 1-N means using gpus 

# mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

# this is coco anchors, change it if necessary
anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'

# objects from all labels from your dataset with the order from your annotations.
# its index must match your dataset's category_id.
# category_id is one_indexed,
# for example, index of 'car' here is 2, while category_id of is 3
obj_list: ['person', 'bicycle', 'car', ...]

在 coco 数据集上训练代码如下:

 train efficientdet-d0 on coco from scratch 
# with batchsize 12
# This takes time and requires change 
# of hyperparameters every few hours.
# If you have months to kill, do it. 
# It's not like someone going to achieve
# better score than the one in the paper.
# The first few epoches will be rather unstable,
# it's quite normal when you train from scratch.

python train.py -c 0 --batch_size 12

在自定义数据集上训练:

train efficientdet-d1 on a custom dataset 
# with batchsize 8 and learning rate 1e-5

python train.py -c 1 --batch_size 8 --lr 1e-5

项目作者强烈推荐在预训练的权重上对网络进行训练:

train efficientdet-d2 on a custom dataset with pretrained weights
# with batchsize 8 and learning rate 1e-5 for 10 epoches

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
 --load_weights /path/to/your/weights/efficientdet-d2.pth

# with a coco-pretrained, you can even freeze the backbone and train heads only
# to speed up training and help convergence.

python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
 --load_weights /path/to/your/weights/efficientdet-d2.pth \
 --head_only True

项目作者知乎文章:https://zhuanlan.zhihu.com/p/129016081

【云栖号在线课堂】每天都有产品技术专家分享!
课程地址:https://yqh.aliyun.com/live

立即加入社群,与专家面对面,及时了解课程最新动态!
【云栖号在线课堂 社群】https://c.tb.cn/F3.Z8gvnK

原文发布时间:2020-04-14
本文来自:“OpenCV学堂”,了解相关信息可以关注“OpenCV学堂

相关文章
【教程】加速访问和下载github项目,原来替换一个域名就可以加速了
【教程】加速访问和下载github项目,原来替换一个域名就可以加速了
|
3天前
|
Dubbo NoSQL Java
GitHub置顶46k星的互联网大厂Java工程师进阶题,腾讯官方首秀!
大家好,最近有不少小伙伴在后台留言,又得准备面试了,不知道从何下手! 可以肯定的是,不管我们要学习提升还是涨薪跳槽!都要先给自己设立一个小目标,然后想着自己的目标努力奋斗就是了。
|
3天前
|
开发工具 git
github下载项目
github下载项目
28 0
|
6天前
|
消息中间件 前端开发 Java
GitHub标星30k!基于Spring MVC Mybatis分布式开发系统-zheng项目
zheng项目不仅仅是一个开发架构,而是努力打造一套从 前端模板 - 基础框架 - 分布式架构 - 开源项目 - 持续集成 - 自动化部署 - 系统监测 - 无缝升级 的全方位J2EE企业级开发解决方案。
|
6天前
|
开发框架 前端开发 Java
GitHub首次开源标星20k+项目:Guns-现代化主流Java应用开发框架
Guns是一个现代化的Java应用开发框架,基于主流技术Spring Boot2 + Vue3,Guns的核心理念是提高开发人员开发效率,降低企业信息化系统的开发成本
|
18天前
|
SpringCloudAlibaba Java Nacos
「Spring Cloud Alibaba官方手册」首发爆火,Github上标星243k
几年前 Dubbo被 SpringCloud所取代,相同的剧本,可惜阿里巴巴和 Spring社区都是巨头,巨头之间战斗要考虑很多,于是它们想到了合作, SpringCloud与alibaba相结合,技术上有人负责更新新的组件,也还可以继续使用 Spring社区的技术。于是 SpringCloudAlibaba诞生了。
|
21天前
|
前端开发 NoSQL Java
GitHub官方出手,一针见血!Spring Boot趣味实战手册来袭(彩版)
本书内容极其丰富,不仅涵盖了Spring MVC、MyBatis Plus、SpringData JPA、Spring Security、Quartz等主流框架,整合了MySQL、Druid、Redis、RabbitMQ、Elasticsearch等互联网常用技术与中间件,还涉及单元测试、异常处理、日志、Swagger等技术细节,以及AOP、IOC、自动配置、数据库事务、分布式锁等硬核知识。本书从初始化到部署、监控,实现了软件全生命周期一站式打包解决。
|
4月前
|
人工智能 JavaScript 程序员
GitHub原生AI代码生成工具Copilot,官方支持Visual Studio 2022
GitHub原生AI代码生成工具Copilot,官方支持Visual Studio 2022
171 0
|
4月前
|
JavaScript 前端开发 数据可视化
再不用怕Markdown中的绘图了,GitHub官方支持Mermaid图表绘制工具
再不用怕Markdown中的绘图了,GitHub官方支持Mermaid图表绘制工具
|
7月前
|
机器学习/深度学习 Web App开发 搜索推荐
GitHub标星23k+,从零开始的深度学习实用教程 | PyTorch官方推荐
背景基础部分,除了有Python指南、笔记本用法,以及Numpy、Pandas这些库的用法,还有线性规划、逻辑规划、随机森林、k-means聚类这些机器学习的基本技术。
GitHub标星23k+,从零开始的深度学习实用教程 | PyTorch官方推荐
相关产品
机器翻译
推荐文章
更多