导读
图像识别任务是人工智能计算机视觉领域一个重要的子任务,本篇文章将通过使用一个预训练模型来帮助读者快速上手图像识别任务,对应的文件可通过关注文章末尾的公众号领取
本篇文章需要一定人工智能基础,不了解的可从博主其他人工智能专栏进行学习
本次介绍的模型是resnet模型
可以将本篇博文当作notebook来阅读,也方便读者进行运行
模型配置
库的导入
首先导入本次项目所需要的库,torchvision是一个计算机视觉库,里面有很多相关模型
from torchvision import models from torchvision import transforms
模型初始化
接着创建模型对象
resnet = models.resnet101(pretrained=True)
图片处理器
创建一个预处理器,用于将不同格式的图片都转化为模型需要的输入格式
preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )])
图片处理
导入图片处理库
from PIL import Image import matplotlib.pyplot as plt
导入图片
将your_path替换为自己的图片路径
img = Image.open("/kaggle/input/orange/orange.jpg")
查看图片通道数
运行以下代码查看图片通道数
print(len(img.getbands()))
如果图片通道数是4,就将它转化为RGB图片(通道数为3),因为图片预处理器处理的是3通道图片
img= img.convert("RGB")
处理图片
将图片放入图片处理器
img_t = preprocess(img)
导入torch库
import torch
扩充维度
拓展一个维度用作训练数据
batch_t = torch.unsqueeze(img_t, 0)
模型训练与评估
模型初始化
这部分代码将模型切换为评估模式,表示本次运行仅使用模型,不进行训练,再将图片输入模型,获取结果
resnet.eval() out = resnet(batch_t)
导入标签
这段代码是模型对应输出的标签,即判别结果,标签文件可关注文章末尾公众号领取,下载后记得将路径替换为自己的路径
with open("/kaggle/input/pytorch/dlwpt-code-master/data/p1ch2/imagenet_classes.txt") as f: labels = [line.strip() for line in f.readlines()]
因为模型输出的是一些浮点数,我们获取最大值的索引,并在标签中搜索则可以得到最终结果
此段代码输出对应的结果和模型判断结果正确的概率
_, index = torch.max(out, 1) percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100 print(labels[index[0]], percentage[index[0]].item())
测试
以下是博主运行的图片和结果
输入了一个橘子图片
模型表示有97%的概率是橘子
感谢阅读,觉得有用的话就订阅下《AI模型分享》专栏吧,有错误也欢迎指出