图像分类模型嵌入flask中开发PythonWeb项目
图像分类是一种常见的计算机视觉任务,它的目的是将输入的图像分配到预定义的类别中,如猫、狗、花等。图像分类模型是一种基于深度学习的模型,它可以利用大量的图像数据来学习图像的特征和类别之间的关系,并且能够对新的图像进行分类。在本博客中,我将介绍如何将深度学习模型嵌入到PythonWeb中,实现图像分类模型的web端实现。具体以西红柿病害识别交流平台的开发为例进行描述。
视频演示地址:https://www.bilibili.com/video/BV11X4y1d7Cj/?spm_id_from=333.999.0.0
一、系统架构
本系统使用flask+mysql+layui+ajax开发,深度学习环境为pytorch+anaconda,数据库存储使用mysql,开发工具为pycharm,前后端不分离项目,但可以实现前后端的动态交互。
前端:使用layui框架搭建一个简洁美观的界面,使用ajax技术实现与后端的同步通信,实现上传图片、显示结果等功能。
后端:使用flask框架搭建一个轻量级的web服务器,使用flask_mysqldb模块连接mysql数据库,使用pytorch框架加载预训练好的图像分类模型,实现接收图片、处理图片、返回结果等功能。
数据库:使用mysql数据库存储用户信息、图片信息、分类结果等数据,方便后期的数据分析和管理。
二、系统功能
本系统主要实现了以下功能:
(1)账号注册
用户首次进入平台需要进行注册。
(2)账号登录
当用户进入平台后,需进入登录界面,用户登录后,方可正常使用其它功能。
(3)信息修改
用户进入个人中心后,可以修改包括头像,用户名,用户密码等个人信息。
(4)账号退出
用户点击退出按钮之后,退出当前登录状态。
(5)图片上传
用户在点击检测之后,可以上传西红柿叶片的图片。
(6)病害识别
用户成功上传西红柿叶片的图片后,可在进行病害的识别及分析,稍作等待可查看结果。
(7)记录查看
用户可随时在平台看到历史上传的西红柿叶片,以及识别结果。
(8)发布论坛
用户点击论坛进入论坛界面,可点击发布论坛。
(9)评论帖子
用户点击论坛进入帖子的详情页面后,可点击发表帖子评论。
三 软件概述
3.1需求分析
西红柿病害识别交流平台是一个专注于西红柿病害识别和交流的网站,用户可以通过该平台上传西红柿照片,进行病害识别和分类,并且可以在论坛中与其他用户交流和分享经验。下面是该平台的需求分析:
用户端需求:
(1)用户可以注册和登录账号。
(2)用户可以上传西红柿照片进行病害识别和分类。
(3)系统可以对上传的照片进行自动识别和分类。
(4)用户可以在识别结果页面查看病害的名称和描述信息。
(5)用户可以在社区中发布病害相关的问题或心得,也可以对其他用户发布的内容进行回复。
技术需求:
(1)需要使用深度学习框架(如PyTorch)对西红柿照片进行训练和分类。
(2)需要使用Web框架(如Flask)搭建用户端和管理端。
(3)需要使用数据库(如MySQL)存储用户信息、照片和社区内容等数据。
(4)需要使用HTML、CSS和JavaScript等技术实现网站的前端交互和设计。
3.2 系统开发环境
(1)硬件环境主要支撑软件的正常运行。下表为该系统的硬件开发环境。
硬件配置 | |
CPU | Intel Core i5-7300HQ |
显卡 | 4G NVIDIA GTX 1050 |
内存 | 16 GB DDR4 2400Mhz |
(2) 软件环境主要进行系统源代码的开发以及调试。下表为该系统的软件开发环境。
软件环境 | |
操作系统 | Window10 64bit |
后端开发语言 | Python3.7 |
开发工具 | PyCharm 2022.3.3 |
后端开发框架 | Flask 2.2.3 |
数据库 | MySQL 5.7 |
数据库可视化 | Navicat Premium 16 |
四 系统操作说明
4.1账号注册
用户进入西红柿病害识别交流平台后需要进行注册才能得到登录的许可,用户需要根据提示信息首先输入注册的邮箱,然后点击发送验证码按钮并进入邮箱查看验证码信息,填写正确的邮箱验证码。接下来进行用户名、密码以及性别的输入,密码需要二次确认,成功注册后会自动跳转进入登录界面。图4-1为注册界面。
图4-1 注册界面
用户成功注册后会将用户信息添加进MySQL数据库的user表中,如图4-2,用户名为“一个人走”,性别为“男”,邮箱为“gzy_personal@163.com”的用户成功被添加进user表中。
图4-2 数据库表
4.2账号登录
用户可在右上角进行登录操作,用户通过输入正确的邮箱、密码以及验证码才可以成功登入系统,其中邮箱必须是正确格式,密码必须大于等于6位数。成功登录后页面自动跳转到首页。图4-3为登录界面。
图4-3 登录界面
4.3个人信息修改
用户登录系统后,右上角由“登录”,“注册”按钮切换为头像以及用户名,系统会根据用户性别生成默认头像,用户名也是一个二级导航栏,可点击用户个人中心、密码修改以及退出登录。图4-4为登录成功后右上角出现的用户信息。
图4-4 用户信息导航栏
通过点击个人中心可进行头像的上传以及用户名的修改,其中邮箱和性别是不可修改的,用户点击上传头像可进入资源管理器进行个人头像的上传,也可以修改用户名。图4-5为用户个人中心界面。
图4-5 用户个人中心
如图4-6所示,用户的默认头像被更改,用户名由“一个人走”修改为“小小志”。
图4-6 成功修改界面
通过点击修改密码可进行用户密码的修改,用户需要输入原始的登录密码、新密码以及图片验证码,其中原始密码必须是正确的登录密码,新密码需要二次确认,图片验证码必须与图片中的验证码一致才可以成功修改密码。图4-7为用户修改密码页面。
图4-7 修改密码
4.4 主页面
用户进入页面或者成功登录即可跳转到严重度分级平台的主页面,主页显示了自动轮播的4个轮播图和平台提供的一些西红柿病害资料。图4-8为主页面。
图4-8 主页面
4.5病害识别
用户可以在上传页面进行西红柿叶片的图片上传,点击选择文件选择本地图片资源,然后点击识别按钮即可进行西红柿病害的识别,图4-9为识别界面。
图4-9 识别界面
用户只需稍等即可看到西红柿叶片病害识结果。图4-10为结果界面,其中根据上传的图片可以看到上传的原图,也可以得到识别出的病害分类、识别准确率、识别时间和对应的建议。
图4-10 结果界面
4.6记录查看
用户点击导航栏处的记录即可查看该用户曾经识别的病害图像记录。其中,西红柿病害识别交流平台能够根据用户上传的图片信息进行筛选,如图像一致可避免记录中的重复显示。图4-11为记录查看界面。
图4-11 记录查看界面
4.7论坛功能
用户点击导航栏上的论坛可以进入论坛界面,图4-12为论坛界面。
图4-12 论坛界面
在论坛界面中可以看到其他用户发表的论坛记录及其内容,点击“+”按钮按钮即可进入论坛发布的界面,图4-13为发布论坛界面,用户需要输入标题以及内容,然后点击发布按钮即可成功发布论坛。其中,西红柿病害识别交流平台会检测用户输入的情况,标题必须大于5个字符,内容必须大于10个字符,否则将不能成功发布论坛。
图4-13 发布论坛界面
用户在论坛界面中可以看到其他用户发表的论坛记录及其内容,点击帖子的标题即可进入该帖子的详情界面。用户可在详情页面看到该帖子的主要内容,发布者信息、发布时间以及其他用户的评论,用户自己也可以填写并发表自己的评论,输入内容后点击评论按钮即可成功发布评论。图4-14为帖子详情界面。
图4-14 帖子详情界面
五 部分代码展示:
代码目录结构:
model.py数据库表模型文件代码
from exts import db from datetime import datetime # 用户表,记录用户账号信息 class UserModel(db.Model): __tablename__ = 'user' id = db.Column(db.Integer, primary_key=True, autoincrement=True) username = db.Column(db.String(100), nullable=False) password = db.Column(db.String(200), nullable=False) email = db.Column(db.String(100), nullable=False, unique=True) # 邮箱唯一 join_time = db.Column(db.DateTime, default=datetime.now()) # 用户注册表,记录用户注册时的信息 class EmailCaptchaModel(db.Model): __tablename__ = 'email_captcha' id = db.Column(db.Integer, primary_key=True, autoincrement=True) email = db.Column(db.String(100), nullable=False) captcha = db.Column(db.String(100), nullable=False) # 用户上传表,记录用户在网页中上传的图片 class ImageUploadModel(db.Model): __tablename__ = 'image_upload' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) img_name = db.Column(db.String(100), nullable=False) img_format = db.Column(db.String(100), nullable=False) upload_time = db.Column(db.DateTime, default=datetime.now()) # 结果记录表,记录用户在网页中上传图片的结果 class ImageRecordModel(db.Model): __tablename__ = 'image_record' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) img_id = db.Column(db.Integer, db.ForeignKey('image_upload.id'), nullable=False) img_acc = db.Column(db.String(50), nullable=False) img_class = db.Column(db.String(50), nullable=False) ident_time = db.Column(db.DateTime, default=datetime.now()) record=db.relationship('UserModel', backref='user_record', uselist=False)
predict_pth.py预测单张图片的分类并返回结果
import json import os import numpy as np import torch from matplotlib import pyplot as plt from torchvision import transforms from PIL import Image from model.mobilenetv2 import MobileNetV2 as create_model from model.utils import GradCAM, show_cam_on_image, center_crop_img def pred(img_path): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"using {device} device.") num_classes = 4 img_size = 224 data_transform = transforms.Compose( [transforms.Resize(int(img_size * 1.14)), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) img = Image.open(img_path) # [N, C, H, W] img = data_transform(img) # expand batch dimension img = torch.unsqueeze(img, dim=0) # read class_indict json_path = 'model/class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) json_file = open(json_path, "r", encoding='utf-8') class_indict = json.load(json_file) # create model model = create_model(num_classes=num_classes).to(device) # load model weights model_weight_path = "model/bestmodel.pth" model.load_state_dict(torch.load(model_weight_path, map_location=device)) model.eval() with torch.no_grad(): # predict class output = torch.squeeze(model(img.to(device))).cpu() predict = torch.softmax(output, dim=0) predict_cla = torch.argmax(predict).numpy() # print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], # predict[predict_cla].numpy()) class_prob=predict[predict_cla].numpy() class_name=class_indict[str(predict_cla)] return class_prob,class_name if __name__=="__main__": img_path = "static/img/1.jpg" class_prob,class_name= pred(img_path) print(class_name,class_prob)
参考资料
《Python编程从入门到实践》
《PyTorch 官方文档》
《PyTorch 图像分类教程》
《Flask官方文档》
《Flask Web开发实战》
《图像处理与计算机视觉》
具体代码及实现教程可私信。