在计算机视觉领域里,有3个最受欢迎且影响非常大的学术竞赛:ImageNet ILSVRC(大规模视觉识别挑战赛),PASCAL VOC(关于模式分析,统计建模和计算学习的研究)和微软COCO图像识别大赛。这些比赛大大地推动了在计算机视觉研究中的多项发明和创新,其中很多都是免费开源的。
博客Deep Learning Sandbox作者Greg Chu打算通过一篇文章,教你用Keras和TensorFlow,实现对ImageNet数据集中日常物体的识别。
量子位翻译了这篇文章:
你想识别什么?
看看ILSVRC竞赛中包含的物体对象。如果你要研究的物体对象是该列表1001个对象中的一个,运气真好,可以获得大量该类别图像数据!以下是这个数据集包含的部分类别:
狗 | 熊 | 椅子 |
---|---|---|
汽车 | 键盘 | 箱子 |
婴儿床 | 旗杆 | iPod播放器 |
轮船 | 面包车 | 项链 |
降落伞 | 枕头 | 桌子 |
钱包 | 球拍 | 步枪 |
校车 | 萨克斯管 | 足球 |
袜子 | 舞台 | 火炉 |
火把 | 吸尘器 | 自动售货机 |
眼镜 | 红绿灯 | 菜肴 |
盘子 | 西兰花 | 红酒 |
△ 表1 ImageNet ILSVRC的类别摘录
完整类别列表见:https://gist.github.com/gregchu/134677e041cd78639fea84e3e619415b
如果你研究的物体对象不在该列表中,或者像医学图像分析中具有多种差异较大的背景,遇到这些情况该怎么办?可以借助迁移学习(transfer learning)和微调(fine-tuning),我们以后再另外写文章讲。
图像识别
图像识别,或者说物体识别是什么?它回答了一个问题:“这张图像中描绘了哪几个物体对象?”如果你研究的是基于图像内容进行标记,确定盘子上的食物类型,对癌症患者或非癌症患者的医学图像进行分类,以及更多的实际应用,那么就能用到图像识别。
Keras和TensorFlow
Keras是一个高级神经网络库,能够作为一种简单好用的抽象层,接入到数值计算库TensorFlow中。另外,它可以通过其keras.applications
模块获取在ILSVRC竞赛中获胜的多个卷积网络模型,如由Microsoft Research开发的ResNet50网络和由Google Research开发的InceptionV3网络,这一切都是免费和开源的。具体安装参照以下说明进行操作:
Keras安装:https://keras.io/#installation
TensorFlow安装:https://www.tensorflow.org/install/
实现过程
我们的最终目标是编写一个简单的python程序,只需要输入本地图像文件的路径或是图像的URL链接就能实现物体识别。
以下是输入非洲大象照片的示例:
1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg
输入:
输出将如下所示:
△ 该图像最可能的前3种预测类别及其相应概率
预测功能
我们接下来要载入ResNet50网络模型。首先,要加载keras.preprocessing
和keras.applications.resnet50
模块,并使用在ImageNet ILSVRC比赛中已经训练好的权重。
想了解ResNet50的原理,可以阅读论文《基于深度残差网络的图像识别》。地址:https://arxiv.org/pdf/1512.03385.pdf
import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50
import ResNet50, preprocess_input, decode_predictions
model = ResNet50(weights='imagenet')
接下来定义一个预测函数:
def predict(model, img, target_size, top_n=3):
"""Run model prediction on image
Args:
model: keras model
img: PIL format image
target_size: (width, height) tuple
top_n: # of top predictions to return
Returns:
list of predicted labels and their probabilities
"""
if img.size != target_size:
img = img.resize(target_size)
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
preds = model.predict(x)
return decode_predictions(preds, top=top_n)[0]
在使用ResNet50网络结构时需要注意,输入大小target_size
必须等于(224,224)
。许多CNN网络结构具有固定的输入大小,ResNet50正是其中之一,作者将输入大小定为(224,224)
。
image.img_to_array
:将PIL格式的图像转换为numpy数组。
np.expand_dims
:将我们的(3,224,224)
大小的图像转换为(1,3,224,224)
。因为model.predict
函数需要4维数组作为输入,其中第4维为每批预测图像的数量。这也就是说,我们可以一次性分类多个图像。
preprocess_input
:使用训练数据集中的平均通道值对图像数据进行零值处理,即使得图像所有点的和为0。这是非常重要的步骤,如果跳过,将大大影响实际预测效果。这个步骤称为数据归一化。
model.predict
:对我们的数据分批处理并返回预测值。
decode_predictions
:采用与model.predict
函数相同的编码标签,并从ImageNet ILSVRC集返回可读的标签。
keras.applications
模块还提供4种结构:ResNet50、InceptionV3、VGG16、VGG19和XCeption,你可以用其中任何一种替换ResNet50。更多信息可以参考https://keras.io/applications/。
绘图
我们可以使用matplotlib
函数库将预测结果做成柱状图,如下所示:
def plot_preds(image, preds):
"""Displays image and the top-n predicted probabilities
in a bar graph
Args:
image: PIL image
preds: list of predicted labels and their probabilities
"""
#image
plt.imshow(image)
plt.axis('off') #bar graph
plt.figure()
order = list(reversed(range(len(preds))))
bar_preds = [pr[2] for pr in preds]
labels = (pr[1] for pr in preds)
plt.barh(order, bar_preds, alpha=0.5)
plt.yticks(order, labels)
plt.xlabel('Probability')
plt.xlim(0, 1.01)
plt.tight_layout()
plt.show()
主体部分
为了实现以下从网络中加载图片的功能:
1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg
我们将定义主函数如下:
if __name__=="__main__":
a = argparse.ArgumentParser()
a.add_argument("--image",
help="path to image")
a.add_argument("--image_url",
help="url to image")
args = a.parse_args()
if args.image is None and args.image_url is None:
a.print_help()
sys.exit(1)
if args.image is not None:
img = Image.open(args.image)
print_preds(predict(model, img, target_size))
if args.image_url is not None:
response = requests.get(args.image_url)
img = Image.open(BytesIO(response.content))
print_preds(predict(model, img, target_size))
其中在写入image_url
功能后,用python中的Requests库就能很容易地从URL链接中下载图像。
完工
将上述代码组合起来,你就创建了一个图像识别系统。项目的完整程序和示例图像请查看GitHub链接:
https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/image_recognition