Introduction to Neural Networks: Gemstone Classification

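Task description:


This exercise is a multi-class classification task: given a photo of a gemstone, the model must identify which type of gemstone it shows.


Platform: Baidu AI Studio (AI training platform), PaddlePaddle 1.8.0 in dynamic graph (dygraph) mode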




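Deep neural network (DNN)


A deep neural network (DNN) is the foundation of deep learning. It consists of an input layer, one or more hidden layers, and an output layer, and every layer is fully connected to the next.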
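To make the input/hidden/output structure concrete, here is a minimal NumPy sketch of a forward pass through fully connected layers. The layer sizes and random weights are toy values chosen for illustration, not the ones used later in this article.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(x):
    # Subtract the row-wise max for numerical stability
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def dense_forward(x, layers):
    """Forward pass: ReLU after every layer except the last, softmax on the last."""
    for W, b in layers[:-1]:
        x = relu(x @ W + b)
    W, b = layers[-1]
    return softmax(x @ W + b)

rng = np.random.default_rng(0)
sizes = [8, 16, 4]  # hypothetical toy sizes: input, hidden, output
layers = [(rng.normal(scale=0.1, size=(m, n)), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]
probs = dense_forward(rng.normal(size=(2, 8)), layers)
print(probs.shape)  # one row of class probabilities per input sample
```

Each row of `probs` sums to 1, so it can be read as a probability distribution over the output classes.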




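Dataset


  • The dataset files are archive_train.zip and archive_test.zip.
  • The dataset contains images of 25 different classes of gemstones.
  • The images are already split into training and test data.
  • Image sizes vary; the files are in JPEG format.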


# View the mounted dataset directory. Changes under this directory are reverted automatically when the environment is restarted.
!ls /home/aistudio/data
data55032  dataset
# Import the required packages
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import matplotlib.pyplot as plt


1. Data preparation


'''
Parameter configuration
'''
train_parameters = {
    "input_size": [3, 64, 64],                           # shape of the input image (C, H, W)
    "class_dim": -1,                                     # number of classes (set later by get_data_list)
    'augment_path' : '/home/aistudio/augment',           # directory for augmented images
    "src_path":"data/data55032/archive_train.zip",       # path to the original dataset archive
    "target_path":"/home/aistudio/data/dataset",        # directory to extract into
    "train_list_path": "./train_data.txt",              # path to train_data.txt
    "eval_list_path": "./val_data.txt",                  # path to val_data.txt
    "label_dict":{},                                    # label dictionary (label index -> class name)
    "readme_path": "/home/aistudio/data/readme.json",   # path to readme.json
    "num_epochs": 20,                                    # number of training epochs
    "train_batch_size": 64,                             # batch size
    "learning_strategy": {                              # optimizer settings
        "lr": 0.001                                     # learning rate
    } 
}
def unzip_data(src_path,target_path):
    '''
    Extract the original dataset: unzip the archive at src_path into target_path
    '''
    if(not os.path.isdir(target_path)):    
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()
    else:
        print("Files already extracted")
def get_data_list(target_path,train_list_path,eval_list_path, augment_path):
    '''
    Generate the training and validation data lists
    '''
    # Per-class information
    class_detail = []
    # Names of the class folders
    data_list_path=target_path
    class_dirs = os.listdir(data_list_path)
    if '__MACOSX' in class_dirs:
        class_dirs.remove('__MACOSX')
    # Total number of images
    all_class_images = 0
    # Current class label
    class_label=0
    # Number of classes
    class_dim = 0
    # Lines to be written to val_data.txt and train_data.txt
    trainer_list=[]
    eval_list=[]
    # Process each class
    for class_dir in class_dirs:
        if class_dir != ".DS_Store":
            class_dim += 1
            # Information about this class
            class_detail_list = {}
            eval_sum = 0
            trainer_sum = 0
            # Number of original images in this class
            class_sum = 0
            # Path to the class folder
            path = os.path.join(data_list_path,class_dir)
            # print(path)
            # List all images
            img_paths = os.listdir(path)
            for img_path in img_paths:                                  # iterate over the images in the folder
                if img_path =='.DS_Store':
                    continue
                name_path = os.path.join(path,img_path)                       # path to this image
                if class_sum % 15 == 0:                                 # every 15th image goes to the validation set
                    eval_sum += 1                                       # eval_sum counts validation images
                    eval_list.append(name_path + "\t%d" % class_label + "\n")
                else:
                    trainer_sum += 1 
                    trainer_list.append(name_path + "\t%d" % class_label + "\n") # trainer_sum counts training images
                class_sum += 1                                          # images in this class
                all_class_images += 1                                   # images across all classes
            # ------------------------------ augmented data ------------------------------
            aug_path = os.path.join(augment_path, class_dir)
            for img_path in os.listdir(aug_path):                                  # iterate over the augmented images
                name_path = os.path.join(aug_path,img_path)                       # path to this image
                trainer_sum += 1 
                trainer_list.append(name_path + "\t%d" % class_label + "\n") # augmented images are used for training only
                all_class_images += 1                                   # images across all classes
            # ----------------------------------------------------------------------------
            # class_detail entry for the descriptive JSON file
            class_detail_list['class_name'] = class_dir             # class name
            class_detail_list['class_label'] = class_label          # class label
            class_detail_list['class_eval_images'] = eval_sum       # number of validation images in this class
            class_detail_list['class_trainer_images'] = trainer_sum # number of training images in this class
            class_detail.append(class_detail_list)  
            # Record the label in the label dictionary
            train_parameters['label_dict'][str(class_label)] = class_dir
            class_label += 1
    # Set the number of classes
    train_parameters['class_dim'] = class_dim
    print(train_parameters)
    # Shuffle
    random.shuffle(eval_list)
    with open(eval_list_path, 'a') as f:
        for eval_image in eval_list:
            f.write(eval_image) 
    # Shuffle
    random.shuffle(trainer_list) 
    with open(train_list_path, 'a') as f2:
        for train_image in trainer_list:
            f2.write(train_image) 
    # Contents of the descriptive JSON file
    readjson = {}
    readjson['all_class_name'] = data_list_path                  # parent directory of the class folders
    readjson['all_class_images'] = all_class_images
    readjson['class_detail'] = class_detail
    jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
    with open(train_parameters['readme_path'],'w') as f:
        f.write(jsons)
    print ('Data lists generated!')
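The split rule inside get_data_list can be isolated as a small pure function. The sketch below uses hypothetical file names and only the standard library; it sends every 15th image of a class (indices 0, 15, 30, ...) to the validation list and the rest to the training list.

```python
def split_every_nth(paths, step=15):
    """Send items at indices 0, step, 2*step, ... to validation, the rest to training."""
    train, val = [], []
    for i, p in enumerate(paths):
        (val if i % step == 0 else train).append(p)
    return train, val

# Hypothetical file names for one class with 32 images
paths = ["kunzite_%d.jpg" % i for i in range(32)]
train, val = split_every_nth(paths)
print(len(train), len(val))  # 29 3
```

With roughly 30 to 40 original images per class, this rule keeps only two or three validation images per class, so validation accuracy is estimated from a small sample.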
def data_reader(file_list):
    '''
    Custom data reader: yields (image, label) pairs from a list file
    '''
    def reader():
        with open(file_list, 'r') as f:
            lines = [line.strip() for line in f]
            for line in lines:
                img_path, lab = line.strip().split('\t')
                img = Image.open(img_path) 
                if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                img = img.resize((64, 64), Image.BILINEAR)
                img = np.array(img).astype('float32') 
                img = img.transpose((2, 0, 1))  # HWC to CHW 
                img = img/255                   # scale pixel values to [0, 1] 
                yield img, int(lab) 
    return reader
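The array-level part of this preprocessing (channel reordering and scaling) can be checked on its own. The following sketch uses a synthetic all-white array in place of a resized photo, so it needs only NumPy; the helper name `to_chw_float` is made up for illustration.

```python
import numpy as np

def to_chw_float(img_hwc):
    """Convert an HWC uint8 image array to a CHW float32 array scaled to [0, 1]."""
    arr = np.asarray(img_hwc).astype('float32')
    arr = arr.transpose((2, 0, 1))  # HWC -> CHW
    return arr / 255

# Synthetic all-white 64x64 RGB image standing in for a resized photo
fake = np.full((64, 64, 3), 255, dtype='uint8')
out = to_chw_float(fake)
print(out.shape, out.dtype)  # (3, 64, 64) float32
```

The resulting shape matches the "input_size" entry of [3, 64, 64] in the parameter configuration.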
!pip install Augmentor
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already satisfied: Augmentor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (0.2.8)
Requirement already satisfied: tqdm>=4.9.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (4.36.1)
Requirement already satisfied: future>=0.16.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (0.18.0)
Requirement already satisfied: numpy>=1.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (1.16.4)
Requirement already satisfied: Pillow>=5.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Augmentor) (7.1.2)
'''
Parameter initialization
'''
src_path=train_parameters['src_path']
target_path=train_parameters['target_path']
train_list_path=train_parameters['train_list_path']
eval_list_path=train_parameters['eval_list_path']
batch_size=train_parameters['train_batch_size']
augment_path = train_parameters['augment_path']
'''
Extract the original data to the target path
'''
unzip_data(src_path,target_path)
Files already extracted
def proc_img(src):
    for root, dirs, files in os.walk(src):
        if '__MACOSX' in root:continue
        for file in files:            
            src=os.path.join(root,file)
            img=Image.open(src)
            if img.mode != 'RGB': 
                    img = img.convert('RGB') 
                    img.save(src)            
if __name__=='__main__':
    proc_img(r"data/dataset")
import os, Augmentor
import shutil, glob
if not os.path.exists(augment_path): # skip if the data has already been augmented
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            path_ = os.path.join(root, name)
            if '__MACOSX' in path_:continue
            print('Augmenting:',os.path.join(root, name))
            print('image:',os.path.join(root, name))
            p = Augmentor.Pipeline(os.path.join(root, name),output_directory='output')
            p.rotate(probability=0.6, max_left_rotation=2, max_right_rotation=2)
            p.zoom(probability=0.6, min_factor=0.9, max_factor=1.1)
            p.random_distortion(probability=0.4, grid_height=2, grid_width=2, magnitude=1)
            count = 1000 - len(glob.glob(pathname=path_+'/*.jpg')) # bring each class up to 1000 images
            p.sample(count, multi_threaded=False)
            p.process()
    print('Copying generated images to the correct directory')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in files:
            path_ = os.path.join(root, name)
            if path_.rsplit('/',3)[2] == 'output':
                type_ = path_.rsplit('/',3)[1]
                dest_dir = os.path.join(augment_path ,type_) 
                if not os.path.exists(dest_dir):os.makedirs(dest_dir) 
                dest_path_ = os.path.join(augment_path ,type_, name) 
                shutil.move(path_, dest_path_)
    print('Deleting all output directories')
    for root, dirs, files in os.walk("data/dataset", topdown=False):
        for name in dirs:
            if name == 'output':
                path_ = os.path.join(root, name)
                shutil.rmtree(path_)
    print('Data augmentation complete')
Augmentation log (progress bars omitted, one line per class):
Kunzite: initialised with 32 image(s), 968 samples generated
Almandine: initialised with 31 image(s), 969 samples generated
Emerald: initialised with 36 image(s), 964 samples generated
Sapphire Blue: initialised with 34 image(s), 966 samples generated
Malachite: initialised with 28 image(s), 972 samples generated
Alexandrite: initialised with 34 image(s), 966 samples generated
Zircon: initialised with 33 image(s), 967 samples generated
Onyx Black: initialised with 28 image(s), 972 samples generated
Rhodochrosite: initialised with 29 image(s), 971 samples generated
Diamond: initialised with 31 image(s), 969 samples generated
Benitoite: initialised with 31 image(s), 969 samples generated
Pearl: initialised with 33 image(s), 967 samples generated
Beryl Golden: initialised with 36 image(s), 964 samples generated
Labradorite: initialised with 40 image(s), 960 samples generated
Fluorite: initialised with 32 image(s), 968 samples generated
Iolite: initialised with 32 image(s), 968 samples generated
Quartz Beer: initialised with 35 image(s), 965 samples generated
Garnet Red: initialised with 36 image(s), 964 samples generated
Danburite: initialised with 32 image(s), 968 samples generated
Cats Eye: initialised with 31 image(s), 969 samples generated
Hessonite: initialised with 30 image(s), 970 samples generated
Carnelian: initialised with 33 image(s), 967 samples generated
Jade: initialised with 28 image(s), 972 samples generated
Variscite: initialised with 30 image(s), 970 samples generated
Tanzanite: initialised with 36 image(s), 964 samples generated
Copying generated images to the correct directory
Deleting all output directories
Data augmentation complete
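The sample count passed to the augmentation pipeline is simply the gap between the existing image count and the 1000-image target. A small sketch of that arithmetic follows; the clamp to zero is an addition for illustration and is not in the original code, which would pass a negative count if a class already had more than 1000 images.

```python
def samples_needed(existing, target=1000):
    """Augmented samples to generate so a class reaches `target` images (never negative)."""
    return max(target - existing, 0)

print(samples_needed(32))    # 968, as for a class with 32 original images
print(samples_needed(1200))  # 0, the class already exceeds the target
```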
# Clear train_data.txt and val_data.txt before regenerating the data lists
with open(train_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
with open(eval_list_path, 'w') as f: 
    f.seek(0)
    f.truncate() 
# Generate the data lists
get_data_list(target_path,train_list_path,eval_list_path,augment_path)
'''
Build the batched data readers
'''
train_reader = paddle.batch(data_reader(train_list_path),
                            batch_size=batch_size,
                            drop_last=True)
eval_reader = paddle.batch(data_reader(eval_list_path),
                            batch_size=batch_size,
                            drop_last=True)
{'input_size': [3, 64, 64], 'class_dim': 25, 'augment_path': '/home/aistudio/augment', 'src_path': 'data/data55032/archive_train.zip', 'target_path': '/home/aistudio/data/dataset', 'train_list_path': './train_data.txt', 'eval_list_path': './val_data.txt', 'label_dict': {'0': 'Kunzite', '1': 'Almandine', '2': 'Emerald', '3': 'Sapphire Blue', '4': 'Malachite', '5': 'Alexandrite', '6': 'Zircon', '7': 'Onyx Black', '8': 'Rhodochrosite', '9': 'Diamond', '10': 'Benitoite', '11': 'Pearl', '12': 'Beryl Golden', '13': 'Labradorite', '14': 'Fluorite', '15': 'Iolite', '16': 'Quartz Beer', '17': 'Garnet Red', '18': 'Danburite', '19': 'Cats Eye', '20': 'Hessonite', '21': 'Carnelian', '22': 'Jade', '23': 'Variscite', '24': 'Tanzanite'}, 'readme_path': '/home/aistudio/data/readme.json', 'num_epochs': 20, 'train_batch_size': 64, 'learning_strategy': {'lr': 0.001}}
Data lists generated!
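To see what batching with drop_last=True does to the data, here is an illustrative pure-Python re-implementation of the batching idea. It is a sketch for intuition only, not PaddlePaddle's own code, and `toy_reader` is a made-up stand-in for data_reader.

```python
def batch(reader, batch_size, drop_last=False):
    """Group items from a reader function into lists of batch_size items."""
    def batch_reader():
        buf = []
        for item in reader():
            buf.append(item)
            if len(buf) == batch_size:
                yield buf
                buf = []
        # A final partial batch is kept only when drop_last is False
        if buf and not drop_last:
            yield buf
    return batch_reader

def toy_reader():
    for i in range(10):
        yield i

sizes_keep = [len(b) for b in batch(toy_reader, 4)()]
sizes_drop = [len(b) for b in batch(toy_reader, 4, drop_last=True)()]
print(sizes_keep, sizes_drop)  # [4, 4, 2] [4, 4]
```

With drop_last=True, any samples left over after the last full batch are skipped. That matters most for a small validation list, where the leftover can be a noticeable share of the data.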
Batch=0
Batchs=[]
all_train_accs=[]
def draw_train_acc(Batchs, train_accs):
    title="training accs"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("acc", fontsize=14)
    plt.plot(Batchs, train_accs, color='green', label='training accs')
    plt.legend()
    plt.grid()
    plt.show()
all_train_loss=[]
def draw_train_loss(Batchs, train_loss):
    title="training loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.legend()
    plt.grid()
    plt.show()


2. Define the model


### Define the DNN in the cell below ###


# Define the network
class MyDNN(fluid.dygraph.Layer):
    '''
    Fully connected deep neural network (DNN)
    '''
    def __init__(self):
        super(MyDNN,self).__init__()
        self.hidden1=fluid.dygraph.Linear(3*64*64,1000, act='relu')
        self.hidden2=fluid.dygraph.Linear(1000,500, act='relu')
        self.hidden3=fluid.dygraph.Linear(500,100, act='relu')
        self.out = fluid.dygraph.Linear(input_dim=100, output_dim=25, act='softmax')
    def forward(self,input):
        x = fluid.layers.reshape(input,shape=[-1,3*64*64])
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)
        return x
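The layer sizes above determine how many trainable parameters the network has. As a quick sanity check, this sketch counts weights and biases for a stack of fully connected layers, assuming each layer has one bias per output unit.

```python
def count_linear_params(layer_sizes):
    """Total weights plus biases for consecutive fully connected layers."""
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

sizes = [3 * 64 * 64, 1000, 500, 100, 25]
print(count_linear_params(sizes))  # 12842125
```

Almost all of the parameters sit in the first layer (12288 x 1000 weights), which is the usual cost of feeding flattened images directly into a fully connected network.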


3. Train the model


with fluid.dygraph.guard(place = fluid.CUDAPlace(0)):
    print(train_parameters['class_dim'])
    print(train_parameters['label_dict'])
    model=MyDNN() # instantiate the model
    model.train() # training mode
    opt=fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters()) # SGD (stochastic gradient descent) optimizer with a learning rate of 0.001
    epochs_num=train_parameters['num_epochs'] # number of epochs
    for pass_num in range(epochs_num):
        for batch_id,data in enumerate(train_reader()):
            images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
            labels = np.array([x[1] for x in data]).astype('int64')
            labels = labels[:, np.newaxis]
            image=fluid.dygraph.to_variable(images)
            label=fluid.dygraph.to_variable(labels)
            predict=model(image) # forward pass
            loss=fluid.layers.cross_entropy(predict,label)
            avg_loss=fluid.layers.mean(loss) # mean loss over the batch
            acc=fluid.layers.accuracy(predict,label) # batch accuracy
            if batch_id!=0 and batch_id%5==0:
                Batch = Batch+5 
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
                all_train_accs.append(acc.numpy()[0])
                print("train_pass:{},batch_id:{},train_loss:{},train_acc:{}".format(pass_num,batch_id,avg_loss.numpy(),acc.numpy()))
            avg_loss.backward()       
            opt.minimize(avg_loss)    # update the parameters with the optimizer's minimize method 
            model.clear_gradients()   # reset the gradients
    fluid.save_dygraph(model.state_dict(),'MyDNN') # save the model parameters
draw_train_acc(Batchs,all_train_accs)
draw_train_loss(Batchs,all_train_loss)
train_pass:19,batch_id:400,train_loss:[0.24890603],train_acc:[0.96875]
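The loss and accuracy reported during training can be reproduced by hand on a tiny example. The sketch below assumes, as the softmax output layer of the model suggests, that the cross-entropy is computed on class probabilities as -log(p[label]); the probabilities and labels are made-up values.

```python
import numpy as np

def cross_entropy(probs, labels):
    """Per-sample cross-entropy -log(p[label]); labels has shape (N, 1)."""
    idx = labels.reshape(-1)
    picked = probs[np.arange(len(idx)), idx]
    return -np.log(picked)

def accuracy(probs, labels):
    """Fraction of samples whose highest-probability class equals the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels.reshape(-1)))

# Made-up softmax outputs for 4 samples and 3 classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.5, 0.4, 0.1]])
labels = np.array([[0], [1], [2], [1]])
avg_loss = float(cross_entropy(probs, labels).mean())
acc = accuracy(probs, labels)
print(avg_loss, acc)
```

The last sample is misclassified (class 0 has the highest probability but the label is 1), so the accuracy is 0.75 and that sample contributes the largest share of the loss.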




4. Model evaluation


# Model evaluation
with fluid.dygraph.guard():
    accs = []
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) # load the model parameters
    model.eval() # evaluation mode
    for batch_id,data in enumerate(eval_reader()): # validation set
        images = np.array([x[0] for x in data]).astype('float32').reshape(-1, 3,64,64)
        labels = np.array([x[1] for x in data]).astype('int64')
        labels = labels[:, np.newaxis]
        image=fluid.dygraph.to_variable(images)
        label=fluid.dygraph.to_variable(labels)       
        predict=model(image)       
        acc=fluid.layers.accuracy(predict,label)
        accs.append(acc.numpy()[0])
    avg_acc = np.mean(accs) # mean accuracy over all evaluated batches
    print(avg_acc)
0.96875
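Averaging per-batch accuracies gives the per-sample accuracy only when all batches have the same size, which holds here because eval_reader is built with drop_last=True. If partial batches were kept, a size-weighted average would be the safer choice. A sketch with hypothetical batch sizes and accuracies:

```python
def weighted_accuracy(batch_accs, batch_sizes):
    """Average per-batch accuracies weighted by the number of samples in each batch."""
    total = sum(batch_sizes)
    return sum(a * n for a, n in zip(batch_accs, batch_sizes)) / total

# Hypothetical values: one full batch of 64 and a final partial batch of 5
result = weighted_accuracy([0.96875, 0.8], [64, 5])
print(result)
```

Here the weighted result corresponds to 66 correct predictions out of 69 samples, whereas the unweighted mean of the two batch accuracies would give the small batch the same influence as the large one.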


5. Model prediction


import os
import zipfile
def unzip_infer_data(src_path,target_path):
    '''
    Extract the prediction (test) dataset
    '''
    if(not os.path.isdir(target_path)):     
        z = zipfile.ZipFile(src_path, 'r')
        z.extractall(path=target_path)
        z.close()
def load_image(img_path):
    '''
    Preprocess an image for prediction
    '''
    img = Image.open(img_path) 
    if img.mode != 'RGB': 
        img = img.convert('RGB') 
    img = img.resize((64, 64), Image.BILINEAR)
    img = np.array(img).astype('float32') 
    img = img.transpose((2, 0, 1))  # HWC to CHW 
    img = img/255                # scale pixel values to [0, 1] 
    return img
infer_src_path = '/home/aistudio/data/data55032/archive_test.zip'
infer_dst_path = '/home/aistudio/data/archive_test'
unzip_infer_data(infer_src_path,infer_dst_path)
label_dic = train_parameters['label_dict']
'''
Model prediction
'''
with fluid.dygraph.guard():
    model_dict, _ = fluid.load_dygraph('MyDNN')
    model = MyDNN()
    model.load_dict(model_dict) # load the model parameters
    model.eval() # evaluation mode
    # Show the image to be predicted
    infer_path='data/archive_test/alexandrite_3.jpg'
    img = Image.open(infer_path)
    plt.imshow(img)          # draw the image
    plt.show()               # display the image
    # Preprocess the image
    infer_imgs = []
    infer_imgs.append(load_image(infer_path))
    infer_imgs = np.array(infer_imgs)
    for i in range(len(infer_imgs)):
        data = infer_imgs[i]
        dy_x_data = np.array(data).astype('float32')
        dy_x_data=dy_x_data[np.newaxis,:, : ,:]
        img = fluid.dygraph.to_variable(dy_x_data)
        out = model(img)
        lab = np.argmax(out.numpy())  # argmax() returns the index of the largest value
        print("Sample {}: predicted {}, true label {}".format(i+1,label_dic[str(lab)],infer_path.split('/')[-1].split("_")[0]))
print("Done")


Sample 1: predicted Malachite, true label alexandrite
Done
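In this run the prediction (Malachite) does not match the label taken from the file name (alexandrite). To check predictions automatically, the label can be parsed from the path and compared case-insensitively, since the class names in the label dictionary are capitalized while the test file names are lowercase. The helper names below are hypothetical, and the sketch assumes file names of the form `<class name>_<index>.jpg`.

```python
import os

def true_label_from_path(path):
    """Return the file-name part before the last underscore, e.g. 'alexandrite_3.jpg' -> 'alexandrite'."""
    name = os.path.basename(path)
    return name.rsplit('_', 1)[0]

def is_correct(predicted_name, path):
    """Compare the predicted class name with the label in the file name, ignoring case."""
    return predicted_name.lower() == true_label_from_path(path).lower()

print(is_correct('Malachite', 'data/archive_test/alexandrite_3.jpg'))    # False
print(is_correct('Alexandrite', 'data/archive_test/alexandrite_3.jpg'))  # True
```

Looping such a check over every file in the test directory would give a test accuracy, which is more informative than inspecting a single image.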

