基于Paddlex超简单的猫狗分类

简介: 基于Paddlex超简单的猫狗分类

可爱的小猫咪谁不爱呢,我也不例外,但是要我说出什么种类,那就是为难我了,我什么种类都不知道,纯属于喜欢,喜欢看,从未养过,期待养只猫。。。。。。


项目简介:


猫的十二分类大家以前都做过,先在我想快速得到分类模型,只想要结果。。。

PaddleHub可以便捷地获取PaddlePaddle生态下的预训练模型,完成模型的管理和一键预测。配合使用Fine-tune API,可以基于大规模预训练模型快速完成迁移学习,让预训练模型能更好地服务于用户特定场景的应用。。。看到这里何不用PaddleHub呢???

原作采用ResNet101为骨架的深度神经网络的猫咪图像分类模型,对猫咪的图像进行分类和目标识别,图像分类识别准确度最高可以达到94%。 我想试试能不能再进一步。。。

image.pngimage.pngimage.pngimage.png


0.环境设置


更改环境为PaddlePaddle2.0 RC1,PaddleHub 2.0.0b2

#导入一些图像处理的包
%cd /home/aistudio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os, shutil, cv2, random
%matplotlib inline
/home/aistudio


1.数据处理


!mkdir dataset
!unzip data/data10954/cat_12_train.zip -d dataset/
!unzip data/data10954/cat_12_test.zip -d dataset/
!cp data/data10954/train_list.txt dataset/
%cd dataset/
/home/aistudio/dataset


图片数据按目录整理


import pandas as pd
import os
import shutil
def mkdir():
    for i in range(12):
        os.mkdir(str(i))
if __name__ == "__main__":
    data = pd.read_csv("train_list.txt", sep="  ", header=None)
    mkdir()
    for i, r in data.iterrows():
        print(os.path.split(r[0]))
        old_file = r[0]
        new_file = os.path.join(str(r[1]), os.path.split(r[0])[-1])
        shutil.move(old_file, new_file)
    print("File resort finished!")
#导入需要的包
import os
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import matplotlib.pyplot as plt
## 转换4通道为3通道
def proc_img(src):
    for root, dirs, files in os.walk(src):
        if '__MACOSX' in root:continue
        for file in files:            
            src=os.path.join(root,file)
            img=Image.open(src)
            ## 转换4通道为3通道
            if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                    img.save(src)            
if __name__=='__main__':
    proc_img("dataset")
!pip install paddlex
%cd ~
/home/aistudio
!rm dataset/train_list.txt
!rmdir dataset/cat_12_train/
!mv dataset/cat_12_test/ ~
!paddlex --split_dataset --format ImageNet --dataset_dir dataset/ --val_value 0.2 --test_value 0.1
!pip install imgaug
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.cls import transforms
import paddlex as pdx
train_transforms = transforms.Compose([
    transforms.ResizeByShort(short_size=224),
    transforms.RandomCrop(crop_size=64), transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])
eval_transforms = transforms.Compose([
    transforms.ResizeByShort(short_size=224),
    transforms.CenterCrop(crop_size=64), transforms.Normalize()
])
# 定义训练和验证所用的数据集
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-imagenet
train_dataset = pdx.datasets.ImageNet(
    data_dir='/home/aistudio/dataset/',
    file_list='/home/aistudio/train_list.txt',
    label_list='/home/aistudio/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.ImageNet(
    data_dir='/home/aistudio/dataset/',
    file_list='/home/aistudio/val_list.txt',
    label_list='/home/aistudio/labels.txt',
    transforms=eval_transforms)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
# model = pdx.cls.MobileNetV3_small_ssld(num_classes=len(train_dataset.labels))
model = pdx.cls.ResNet101_vd_ssld(num_classes=len(train_dataset.labels))
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/classification.html#train
# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=200,
    train_dataset=train_dataset,
    train_batch_size=256,
    eval_dataset=eval_dataset,
    lr_decay_epochs=[4, 6, 8],
    learning_rate=0.05,
    save_dir='output/ResNet101_vd_ssld')


目录
相关文章
|
3月前
|
消息中间件 网络安全 数据安全/隐私保护
我用 Docker 部署 RabbitMQ 踩了 3 个大坑,10 分钟搞定的记录
部署RabbitMQ踩坑记:从Docker安装、镜像拉取到容器启动,亲历5大常见陷阱。分享一键脚本、推荐镜像标签、关键参数配置及Cookie权限修复方案,结合文档避坑指南,助你10分钟快速稳定部署。
616 5
|
人工智能 IDE API
AI驱动的开发者工具:打造沉浸式API集成体验
本文介绍了阿里云在过去十年中为开发者提供的API服务演变。内容分为两大部分:一是从零开始使用API的用户旅程,涵盖API的发现、调试与集成;二是回顾阿里云过去十年为开发者提供的服务及发展历程。文中详细描述了API从最初的手写SDK到自动化生成SDK的变化,以及通过API Explorer、IDE插件和AI助手等工具提升开发者体验的过程。这些工具和服务旨在帮助开发者更高效地使用API,减少配置和调试的复杂性,提供一站式的解决方案。
|
9月前
|
安全 API 开发者
HarmonyOS NEXT《ArkTS渲染控制完全指南:条件与循环渲染深度解析》
本文深入解析ArkTS条件渲染与循环渲染核心技术,涵盖`if/else`和`ForEach`的使用方法、动态更新机制及性能优化策略。通过20+实战案例,如数据增删、拖拽排序、点赞交互等,结合骨架屏加载、动画修复等企业级解决方案,助你突破渲染瓶颈,打造流畅UI体验。无论初学者还是进阶开发者,都能全面掌握ArkTS渲染控制精髓!适配HarmonyOS开发,助力教育科普与实践应用。
|
监控 数据可视化 安全
探究架构之 - 45张图玩转Kong Gateway,建议收藏系列 (一)
探究架构之 - 45张图玩转Kong Gateway,建议收藏系列 (一)
1589 1
探究架构之 - 45张图玩转Kong Gateway,建议收藏系列 (一)
|
存储 小程序 前端开发
微信小程序与Java后端实现微信授权登录功能
微信小程序极大地简化了登录注册流程。对于用户而言,仅仅需要点击授权按钮,便能够完成登录操作,无需经历繁琐的注册步骤以及输入账号密码等一系列复杂操作,这种便捷的登录方式极大地提升了用户的使用体验
3611 12
|
监控 DataX
DataX教程(09)- DataX是如何做到限速的?
DataX教程(09)- DataX是如何做到限速的?
1094 0
DataX教程(09)- DataX是如何做到限速的?
自动生成IE浏览器的xpath工具IEXPath
自动生成IE浏览器的xpath工具IEXPath
241 0
|
网络协议 Linux 网络架构
【网络技术】什么是CIDR
【网络技术】什么是CIDR
1438 0
|
Ubuntu Linux 编译器
嵌入式linux系统应用开发
嵌入式linux系统应用开发
370 1
|
机器学习/深度学习 数据采集 人工智能
机器学习之sklearn基础教程
【5月更文挑战第9天】Sklearn是Python热门机器学习库,提供丰富算法和预处理工具。本文深入讲解基础概念、核心理论、常见问题及解决策略。内容涵盖模型选择与训练、预处理、交叉验证、分类回归、模型评估、数据集划分、正则化、编码分类变量、特征选择与降维、集成学习、超参数调优、模型评估、保存加载及模型解释。学习Sklearn是迈入机器学习领域的关键。
426 3

热门文章

最新文章