import os
import json
import torch
from torchvision import transforms
from PIL import Image
import torchvision
from tqdm import tqdm
# Select the compute device: the first CUDA GPU when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Preprocessing: resize to 256, center-crop to 224, convert to a tensor, then
# normalize each channel with the mean / std values given below.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Number of samples per batch.
batch_size = 100

# Load the data to run predictions on; train must be False to get the test split.
# download=False means the dataset is expected to already exist under ./cifar10.
test_dataset = torchvision.datasets.CIFAR10(
    root='./cifar10',
    train=False,
    download=False,
    transform=data_transform,
)
# shuffle=False keeps the evaluation order deterministic.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Load the mapping between predicted indices and the real class names.
# The context manager guarantees the file handle is closed after parsing.
json_path = './class_indices.json'
with open(json_path, 'r', encoding='utf-8') as json_file:
    class_indict = json.load(json_file)

# Build the network.
# NOTE(review): the original called a bare `resnet34`, which is never imported
# or defined in this file (NameError at runtime). torchvision's implementation
# is used here; if the checkpoint was saved from a project-local ResNet whose
# state_dict keys differ, import that class instead -- TODO confirm.
model = torchvision.models.resnet34(num_classes=10).to(device)

# Load the trained parameters; map_location places them on the selected device.
weights_path = './resNet34_cifar10.pth'
model.load_state_dict(torch.load(weights_path, map_location=device))

# Switch to evaluation mode and run the predictions.
acc = 0  # running count of correctly classified samples
model.eval()
test_bar = tqdm(test_loader)
# Gradients are not needed for inference, so disable them once for the whole loop.
with torch.no_grad():
    for images, labels in test_bar:
        output = model(images.to(device))
        # Index of the highest score along dim=1 is the predicted class.
        y_pred = torch.max(output, dim=1)[1]
        # Predictions and labels must live on the same device for torch.eq;
        # the original moved the output to CPU but the labels to `device`,
        # which fails whenever `device` is a GPU.
        acc += torch.eq(y_pred, labels.to(device)).sum().item()

# Print the model's accuracy over the whole test set.
print(acc / len(test_dataset))