Caffe:使用 classify.py 批量对图片分类

简介: 一般使用 Caffe 训练完网络后,会用 `test.bin` 来测试一下网络的精度,然后还能用 `classification.bin` 来用网络对图片进行单张的分类,但是一张一张的分,效率很低,所以我改写了 `classify.py` 文件,使其读取 test.txt 文件批量分类,输出具体哪一张图片分错了。

一般使用 Caffe 训练完网络后,会用 test.bin 来测试一下网络的精度,然后还能用 classification.bin 来用网络对图片进行单张的分类,但是一张一张的分,效率很低,所以我改写了 classify.py 文件,使其读取 test.txt 文件批量分类,输出具体哪一张图片分错了。

代码如下:

# copyright (c) strongnine

import caffe
import sys
import os
import numpy as np
 
caffe_root = '/path/to/your/caffe/' # 指定 caffe 的路径
sys.path.insert(0,caffe_root+'python')
 
caffe.set_mode_gpu()
 
deploy = caffe_root+'models/bvlc_alexnet/deploy.prototxt' ##
caffe_model = caffe_root+'model/outputs/caffe_alexnet_train_iter_450000.caffemodel' ## 

labels_name = caffe_root+'data/alexnet/synset_words.txt'
labels = np.loadtxt(labels_name, str, delimiter='\t')
for i in range(len(labels)):
    exec(labels[i] + "=0")
right = 0
false = 0
mean_file = caffe_root+'data/alexnet/train_mean.npy' # 由 imagenet_mean.binaryproto 转换来
net = caffe.Net(deploy, caffe_model, caffe.TEST)

transformer=caffe.io.Transformer({'data':net.blobs['data'].data.shape})
transformer.set_transpose('data',(2,0,1))
transformer.set_mean('data',np.load(mean_file).mean(1).mean(1))
transformer.set_raw_scale('data',255)
transformer.set_channel_swap('data',(2,1,0))


test_file = open(caffe_root+'data/alexnet/test.txt', 'r')
test_data = test_file.readlines()

log = open(caffe_root+'data/alexnet/log/classify_log.log', 'w')
image_road = '/your/image/path/'

for line in test_data:
    split = line.split(' ')
    image = caffe.io.load_image(image_road + split[0])
    net.blobs['data'].data[...]=transformer.preprocess('data',image)
 
    out = net.forward()
 
    prob = net.blobs['prob'].data[0].flatten()
    top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
    
    log.write(split[0] + ' ' + split[1][0] + ' ' + str(top_k[0]))

    if str(top_k[0]) == split[1][0]:
        right += 1
        log.write(' right\n')
    else:
        false += 1
        log.write(' false\n')

print(right)
print(false)
print(right/float(right + false))

运行完成后会输出分类正确的图片数量,和分类错误的图片数量,以及所有的正确率。

生成完查看 log 文件:

...
cat_01.jpg 0 0 right
cat_02.jpg 0 0 right
person_04.jpg 1 1 right
person_05.jpg 1 0 false
...

第一个数字为标签类别,第二个数字为分类类别。

目录
相关文章
|
18天前
|
XML 计算机视觉 数据格式
数据集学习笔记(四):VOC转COCO数据集并据txt中图片的名字批量提取对应的图片并保存到另一个文件夹
这篇文章介绍了如何将VOC数据集转换为COCO数据集的格式,并通过Python脚本根据txt文件中列出的图片名称批量提取对应的图片并保存到另一个文件夹。
15 3
|
6月前
|
PyTorch 算法框架/工具 异构计算
pytorch 模型保存与加载
pytorch 模型保存与加载
42 0
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow的保存与加载模型
【4月更文挑战第17天】本文介绍了TensorFlow中模型的保存与加载。保存模型能节省训练时间,便于部署和复用。在TensorFlow中,可使用`save_model_to_hdf5`保存模型结构,`save_weights`保存权重,或转换为SavedModel格式。加载时,通过`load_model`恢复结构,`load_weights`加载权重。注意模型结构一致性、环境依赖及自定义层的兼容性问题。正确保存和加载能有效利用模型资源,提升效率和准确性。
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
|
6月前
|
机器学习/深度学习 计算机视觉 Python
批量demo推理图片脚本
批量demo推理图片脚本
|
人工智能 数据可视化 TensorFlow
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
|
6月前
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
269 0
|
机器学习/深度学习 数据可视化 Java
TensorFlow 高级技巧:自定义模型保存、加载和分布式训练
本篇文章将涵盖 TensorFlow 的高级应用,包括如何自定义模型的保存和加载过程,以及如何进行分布式训练。
|
机器学习/深度学习 存储 人工智能
从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)
从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)
从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)
|
PyTorch 算法框架/工具 计算机视觉