使用用测试时数据增强(TTA)提高预测结果(下)

简介: 使用用测试时数据增强(TTA)提高预测结果(下)

然后,evaluate_model()函数可以被更新为调用tta_evaluate_model()以获得模型精度分数。

#fitandevaluateadefinedmodeldefevaluate_model(model, trainX, trainY, testX, testY):
#fitmodelmodel.fit(trainX, trainY, epochs=3, batch_size=128, verbose=0)
#evaluatemodelusingttaacc=tta_evaluate_model(model, testX, testY)
returnacc

将所有这些结合在一起,下面列出了使用TTA的CNN for CIFAR-10重复评估的完整示例

#cnnmodelforthecifar10problemwithtest-timeaugmentationimportnumpyfromnumpyimportargmaxfromnumpyimportmeanfromnumpyimportstdfromnumpyimportexpand_dimsfromsklearn.metricsimportaccuracy_scorefromkeras.datasets.cifar10importload_datafromkeras.utilsimportto_categoricalfromkeras.preprocessing.imageimportImageDataGeneratorfromkeras.modelsimportSequentialfromkeras.layersimportConv2Dfromkeras.layersimportMaxPooling2Dfromkeras.layersimportDensefromkeras.layersimportFlattenfromkeras.layersimportBatchNormalization#loadandreturnthecifar10datasetreadyformodelingdefload_dataset():
#loaddataset (trainX, trainY), (testX, testY) =load_data()
#normalizepixelvaluestrainX=trainX.astype('float32') /255testX=testX.astype('float32') /255#onehotencodetargetvaluestrainY=to_categorical(trainY)
testY=to_categorical(testY)
returntrainX, trainY, testX, testY#definethecnnmodelforthecifar10datasetdefdefine_model():
#definemodelmodel=Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', padding='same', kernel_initializer='he_uniform', input_shape=(32, 32, 3)))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_uniform'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
model.add(BatchNormalization())
model.add(Dense(10, activation='softmax'))
#compilemodelmodel.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
returnmodel#makeapredictionusingtest-timeaugmentationdeftta_prediction(datagen, model, image, n_examples):
#convertimageintodatasetsamples=expand_dims(image, 0)
#prepareiteratorit=datagen.flow(samples, batch_size=n_examples)
#makepredictionsforeachaugmentedimageyhats=model.predict_generator(it, steps=n_examples, verbose=0)
#sumacrosspredictionssummed=numpy.sum(yhats, axis=0)
#argmaxacrossclassesreturnargmax(summed)
#evaluateamodelonadatasetusingtest-timeaugmentationdeftta_evaluate_model(model, testX, testY):
#configureimagedataaugmentationdatagen=ImageDataGenerator(horizontal_flip=True)
#definethenumberofaugmentedimagestogeneratepertestsetimagen_examples_per_image=7yhats=list()
foriinrange(len(testX)):
#makeaugmentedpredictionyhat=tta_prediction(datagen, model, testX[i], n_examples_per_image)
#storeforevaluationyhats.append(yhat)
#calculateaccuracytestY_labels=argmax(testY, axis=1)
acc=accuracy_score(testY_labels, yhats)
returnacc#fitandevaluateadefinedmodeldefevaluate_model(model, trainX, trainY, testX, testY):
#fitmodelmodel.fit(trainX, trainY, epochs=3, batch_size=128, verbose=0)
#evaluatemodelusingttaacc=tta_evaluate_model(model, testX, testY)
returnacc#repeatedlyevaluatemodel, returndistributionofscoresdefrepeated_evaluation(trainX, trainY, testX, testY, repeats=10):
scores=list()
for_inrange(repeats):
#definemodelmodel=define_model()
#fitandevaluatemodelaccuracy=evaluate_model(model, trainX, trainY, testX, testY)
#storescorescores.append(accuracy)
print('> %.3f'%accuracy)
returnscores#loaddatasettrainX, trainY, testX, testY=load_dataset()
#evaluatemodelscores=repeated_evaluation(trainX, trainY, testX, testY)
#summarizeresultprint('Accuracy: %.3f (%.3f)'% (mean(scores), std(scores)))

考虑到重复评估和用于评估每个模型的较慢的手动测试时间增加,运行这个示例可能需要一些时间。

在这种情况下,我们可以看到性能从没有增加测试时间的测试集上的68.6%提高到增加测试时间的测试集上的69.8%。

>0.719>0.716>0.709>0.694>0.690>0.694>0.680>0.676>0.702>0.704Accuracy: 0.698 (0.013)

TTA如何调优

选择能给模型性能带来最大提升的扩展配置可能是一项挑战。

不仅有许多可选择的扩展方法和每种方法的配置选项,而且在一组配置选项上适合和评估模型的时间可能会花费很长时间,即使适合快速的GPU。

相反,我建议对模型进行一次调整并将其保存到文件中。例如:

#savemodelmodel.save('model.h5')

然后从单独的文件加载模型,并在一个小的验证数据集或测试集的一个小子集上评估不同的测试时间增强方案。

例如:

...
#loadmodelmodel=load_model('model.h5')
#evaluatemodeldatagen=ImageDataGenerator(...)
...

一旦找到了一组能够带来最大提升的扩展选项,您就可以在整个测试集中评估模型,或者像上面那样进行重复的评估实验。

测试时间扩展配置不仅包括ImageDataGenerator的选项,还包括为测试集中每个示例生成平均预测的图像数量。

在上一节中,我使用这种方法来选择测试时间的增加,发现7个示例比3个或5个更好,而且随机缩放和随机移动似乎会降低模型的精度。

记住,如果你也为训练数据集使用图像数据增强,并且这种增强使用一种涉及计算数据集统计数据的像素缩放(例如,你调用datagen.fit()),那么这些相同的统计数据和像素缩放技术也必须在测试时间增强中使用。

总结

在本文章中,您将发现测试时增强可以提高用于图像分类任务的模型的性能。

具体来说,你学会了:

测试时间增广是数据增广技术的应用,通常用于在训练中进行预测。

如何在Keras中从头开始实现测试时间增强。

如何使用测试时间增强来提高卷积神经网络模型在标准图像分类任务中的性能。

原文地址:https://machinelearningmastery.com/how-to-use-test-time-augmentation-to-improve-model-performance-for-image-classification/

目录
相关文章
|
21天前
|
机器学习/深度学习 算法 UED
在数据驱动时代,A/B 测试成为评估机器学习项目不同方案效果的重要方法
在数据驱动时代,A/B 测试成为评估机器学习项目不同方案效果的重要方法。本文介绍 A/B 测试的基本概念、步骤及其在模型评估、算法改进、特征选择和用户体验优化中的应用,同时提供 Python 实现示例,强调其在确保项目性能和用户体验方面的关键作用。
27 6
|
23天前
|
机器学习/深度学习 算法 UED
在数据驱动时代,A/B 测试成为评估机器学习项目效果的重要手段
在数据驱动时代,A/B 测试成为评估机器学习项目效果的重要手段。本文介绍了 A/B 测试的基本概念、步骤及其在模型评估、算法改进、特征选择和用户体验优化中的应用,强调了样本量、随机性和时间因素的重要性,并展示了 Python 在 A/B 测试中的具体应用实例。
26 1
|
2月前
|
存储 测试技术 数据库
数据驱动测试和关键词驱动测试的区别
数据驱动测试 数据驱动测试或 DDT 也被称为参数化测试。
34 1
|
2月前
|
机器学习/深度学习 监控 计算机视觉
目标检测实战(八): 使用YOLOv7完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
本文介绍了如何使用YOLOv7进行目标检测,包括环境搭建、数据集准备、模型训练、验证、测试以及常见错误的解决方法。YOLOv7以其高效性能和准确率在目标检测领域受到关注,适用于自动驾驶、安防监控等场景。文中提供了源码和论文链接,以及详细的步骤说明,适合深度学习实践者参考。
508 0
目标检测实战(八): 使用YOLOv7完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
|
2月前
|
机器学习/深度学习 并行计算 数据可视化
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用PaddleClas框架完成多标签分类任务,包括数据准备、环境搭建、模型训练、预测、评估等完整流程。
134 0
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
|
2月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
62 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
2月前
|
机器学习/深度学习 XML 并行计算
目标检测实战(七): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用YOLOX完成图像目标检测任务的完整流程,包括数据准备、模型训练、验证和测试。
221 0
目标检测实战(七): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
|
2月前
|
SQL 消息中间件 大数据
大数据-159 Apache Kylin 构建Cube 准备和测试数据(一)
大数据-159 Apache Kylin 构建Cube 准备和测试数据(一)
72 1
|
2月前
|
SQL 大数据 Apache
大数据-159 Apache Kylin 构建Cube 准备和测试数据(二)
大数据-159 Apache Kylin 构建Cube 准备和测试数据(二)
85 1
|
2月前
|
存储 SQL 分布式计算
大数据-135 - ClickHouse 集群 - 数据类型 实际测试
大数据-135 - ClickHouse 集群 - 数据类型 实际测试
41 0