图像识别是将图像内容作为一个对象来识别其类型。使用OpenCV中的深度学习预训练模型进行图像识别的基本步骤如下。
(1)从配置文件和预训练模型文件中加载模型。
(2)将图像文件处理为块数据(blob)。
(3)将图像文件的块数据设置为模型的输入。
(4)执行预测。
(5)处理预测结果。
1.基于AlexNet和Caffe模型的图像识别
AlexNet由2012年ImageNet竞赛冠军获得者辛顿(Hinton)和他的学生阿莱克斯·克里泽夫斯基(Alex Krizhevsky)设计,其网络结构包含了5层卷积神经网络(Convolutional Neural Network,CNN),3层全连接网络,采用GPU来加速计算。在处理图像时,AlexNet使用的图像块大小为224×224。
Caffe的全称为快速特征嵌入的卷积结构(Convolutional Architecture for Fast Feature Embedding),是一个兼具表达性、速度和思维模块化的深度学习框架。Caffe由伯克利人工智能研究小组和伯克利视觉和学习中心开发。Caffe内核用C++实现,提供了Python和Matlab等接口。
下面的代码使用基于AlexNet和Caffe的预训练模型进行图像识别。
使用基于AlexNet和Caffe模型的图像识别
import cv2
import numpy as np
from matplotlib import pyplot as plt
from PIL import ImageFont, ImageDraw, Image
读入文本文件中的类别名称,共1000种类别,每行为一个类别,第11个字符开始为名称
基本格式如下。
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
……
file=open('classes.txt')
names=[r.strip() for r in file.readlines()]
file.close()
classes = [r[10:] for r in names] #获取每个类别的名称
从文件中载入Caffe模型
net = cv2.dnn.readNetFromCaffe("AlexNet_deploy.txt", "AlexNet_CaffeModel.dat")
image = cv2.imread("building.jpg") #打开图像,用于识别分类
创建图像blob数据,大小(224,224),颜色通道的均值缩减比例因子(104, 117, 123)
blob = cv2.dnn.blobFromImage(image, 1, (224,224), (104, 117, 123))
net.setInput(blob) #将图像blob数据作为神经网络输入
执行预测,返回结果是一个1×1000的数组,按顺序对应1000种类别的可信度
result = net.forward()
ptime, x = net.getPerfProfile() #获得完成预测时间
print('完成预测时间: %.2f ms' % (ptime * 1000.0 / cv2.getTickFrequency()))
sorted_ret = np.argsort(result[0]) #将预测结果按可信度高低排序
top5 = sorted_ret[::-1][:5] #获得排名前5的预测结果
print(top5)
ctext = "类别: "+classes[top5[0]]
ptext = "可信度: {:.2%}".format(result[0][top5[0]])
输出排名前5的预测结果
for (index, idx) in enumerate(top5):
print("{}. 类别: {}, 可信度: {:.2%}".format(index + 1, classes[idx], result[0][idx]))
在图像中输出排名第1的预测结果
fontpath = "STSONG.TTF"
font = ImageFont.truetype(fontpath,80) #载入中文字体,设置字号
img_pil = Image.fromarray(image)
draw = ImageDraw.Draw(img_pil)
draw.text((10, 10), ctext, font = font,fill=(0,0,255)) #绘制文字
draw.text((10,100), ptext, font = font,fill=(0,0,255))
img = np.array(img_pil)
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.axis('off')
plt.show() #显示图像