阿里云平台cifar10代码解析

本文涉及的产品
对象存储 OSS,20GB 3个月
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
简介:

阿里提供的代码,但是没有解释,花了好长时间才把里面看了个大概,但还未完全掌握。共享我的见解,也请看的同志帮忙修正,共勉!

Tflearn  https://github.com/tflearn/tflearn

 

from __future__ import division, print_function, absolute_import __future__ 

#模块是包含python未来特性的模块,如果你用的是python2,那你就可以通过导入这个模块使用python3的特性

import tensorflow as tf

from six.moves import urllib

import tarfile

import tflearn

from tflearn.data_utils import shuffle, to_categorical

from tflearn.layers.core import input_data, dropout, fully_connected

from tflearn.layers.conv import conv_2d, max_pool_2d

from tflearn.layers.estimator import regression

from tflearn.data_preprocessing import ImagePreprocessing

from tflearn.data_augmentation import ImageAugmentation

#data_augmentation方法与data_preprocessing方法在训练阶段相似,详见#data_augmentation。对input_data方法处理

 

from tensorflow.python.lib.io import file_io

import os

import sys

import numpy as np

import pickle

import argparse

import scipy

FLAGS = None

 

def load_data(dirname, one_hot=False):

    X_train = []

    Y_train = []

 

for i in range(1, 6):

#1,2,3,4,5

        fpath = os.path.join(dirname, 'data_batch_' + str(i))

#连接:将dirname和后面的'data_batch_' + str(i)进行拼接,得到文件夹中文件的路径

        data, labels = load_batch(fpath)

#录入文件并得到data, labels

#经过解压得知 'data_batch_' + str(i)得到的是训练文件

        if i == 1:

            X_train = data

            Y_train = labels

        else:

            X_train = np.concatenate([X_train, data], axis=0)

#沿着某个轴拼接矩阵

            Y_train = np.concatenate([Y_train, labels], axis=0)

#将所有片段拼接在一起,返回的是ndarray

    fpath = os.path.join(dirname, 'test_batch')

    X_test, Y_test = load_batch(fpath)

#3通道分离shape为(10000,1024,3),为reshape做准备

#np.dstack(tup)等价于np.concatenate(tup,axis=2)即在第三维进行拼接

#(X_train[:, :1024], X_train[:, 1024:2048],X_train[:, 2048:])tup

    X_train = np.dstack((X_train[:, :1024], X_train[:, 1024:2048],

                         X_train[:, 2048:])) / 255.

    X_train = np.reshape(X_train, [-1, 32, 32, 3])

    X_test = np.dstack((X_test[:, :1024], X_test[:, 1024:2048],

                        X_test[:, 2048:])) / 255.

    #uint8  无符号整数,0 至 255,处理后,每个元素都小于等于1

    X_test = np.reshape(X_test, [-1, 32, 32, 3])

#stack(堆叠)

if one_hot:

#根据需要,看是否要转化成独热编码

        Y_train = to_categorical(Y_train, 10)

        Y_test = to_categorical(Y_test, 10)

 

    return (X_train, Y_train), (X_test, Y_test)

 

#reporthook from stackoverflow #13881092

def reporthook(blocknum, blocksize, totalsize):

    readsofar = blocknum * blocksize

    if totalsize > 0:

        percent = readsofar * 1e2 / totalsize

        s = "\r%5.1f%% %*d / %d" % (

            percent, len(str(totalsize)), readsofar, totalsize)

        sys.stderr.write(s)#重定向标准错误信息

        if readsofar >= totalsize: # near the end

            sys.stderr.write("\n")

    else: # total size is unknown

        sys.stderr.write("read %d\n" % (readsofar,))

 

def load_batch(fpath):

#录入文件路径,返回data,labels

object = file_io.read_file_to_string(fpath) 

#文件内容转化成字符串或者字节fpath需是文件路径

    #origin_bytes = bytes(object, encoding='latin1')

    # with open(fpath, 'rb') as f:

if sys.version_info > (3, 0):

#如果大于3.0版本  sys.version_info返回sys.version_info

#major=3,minor=6,micro=2,releaselevel=final,serial=0

        # Python3

        d = pickle.loads(object, encoding='latin1') 

#反序列化。。。尝试将object = file_io.read_file_to_string(fpath) 

#改成pickle.dumps()进行序列化  encoding="bytes"

 

    else:

        # Python2

        d = pickle.loads(object)

    data = d["data"]#data.shape (10000,3072)

    labels = d["labels"]

    return data, labels

 

def main(_):

dirname = os.path.join(FLAGS.buckets, "")

  #Namespace(buckets='oss://.../.../.../.../', 

  #checkpointDir='oss://.../.../.../check_point/model/')

#print('dirname:',dirname)

#dirname: oss://.../.../.../.../

    (X, Y), (X_test, Y_test) = load_data(dirname)

    print("load data done")

 

X, Y = shuffle(X, Y)

#tflearn.data_utils.shuffle*arrs)每个矩阵按第一维一致打乱

Y = to_categorical(Y, 10)

#tflearn.data_utils.to_categoricaly,nb_classes),y矩阵,nb_classes分类数

    Y_test = to_categorical(Y_test, 10)

 

    # Real-time data preprocessing

    img_prep = ImagePreprocessing()

    img_prep.add_featurewise_zero_center()#零中心分布

    img_prep.add_featurewise_stdnorm()#标准偏离 standard deviation

 

    # Real-time data augmentation

    img_aug = ImageAugmentation()

    img_aug.add_random_flip_leftright()#随机左右翻转

    img_aug.add_random_rotation(max_angle=25.)#按随机角度旋转,最大旋转角度25

 

    # Convolutional network building

    network = input_data(shape=[None, 32, 32, 3],

                         data_preprocessing=img_prep,

                         data_augmentation=img_aug)

    network = conv_2d(network, 32, 3, activation='relu')

    network = max_pool_2d(network, 2)

    network = conv_2d(network, 64, 3, activation='relu')

    network = conv_2d(network, 64, 3, activation='relu')

    network = max_pool_2d(network, 2)

    network = fully_connected(network, 512, activation='relu')

    network = dropout(network, 0.5)

    network = fully_connected(network, 10, activation='softmax')

    network = regression(network, optimizer='adam',

                         loss='categorical_crossentropy',

                         learning_rate=0.001)

 

    # Train using classifier

    model = tflearn.DNN(network, tensorboard_verbose=0)

    # model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),

    #           show_metric=True, batch_size=96, run_id='cifar10_cnn')

    model_path = os.path.join(FLAGS.checkpointDir, "model.tfl")

print(model_path)

#print('model_path:',model_path)

##model_path: #oss://.../.../.../check_point/model/model.tf2

    model.load(model_path)

 

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")

    # file_paths = tf.train.match_filenames_once(predict_pic)

    # input_file_queue = tf.train.string_input_producer(file_paths)

    # reader = tf.WholeFileReader()

    # file_path, raw_data = reader.read(input_file_queue)

    # img = tf.image.decode_jpeg(raw_data, 3)

    # img = tf.image.resize_images(img, [32, 32])

    # prediction = model.predict([img])

    # print (prediction[0])

    predict_pic = os.path.join(FLAGS.buckets, "bird_bullocks_oriole.jpg")

    img_obj = file_io.read_file_to_string(predict_pic)

    file_io.write_string_to_file("bird_bullocks_oriole.jpg", img_obj)

    #读取图片文件,转化成RGB模式,返回(0,255)的数组

    img = scipy.ndimage.imread("bird_bullocks_oriole.jpg", mode="RGB")

 

    # Scale it to 32x32

    img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

    #"bicubic"双三次插值

    # Predict

    prediction = model.predict([img])

    print (prediction[0])

    print (prediction[0])

    #print (prediction[0].index(max(prediction[0])))

    num=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

    print ("This is a %s"%(num[prediction[0].tolist().index(max(prediction[0]))]))

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")

    # img = scipy.ndimage.imread(predict_pic, mode="RGB")

    # img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

    # prediction = model.predict([img])

    #print (prediction[0])

 

 

if __name__ == '__main__':

#如果模块是被直接运行的,则代码块被运行,如果模块是被导入的,则代码块不被运行

    parser = argparse.ArgumentParser()

   

    parser.add_argument('--buckets', type=str, default='', help='input data path')

    

    parser.add_argument('--checkpointDir', type=str, default='',help='output model path')

FLAGS, _ = parser.parse_known_args()

#print('FLAGS1:',FLAGS)  当前存储地址

#Namespace(buckets='oss://.../.../.../.../', 

#checkpointDir='oss://.../.../.../check_point/model/')

 

tf.app.run(main=main)

#run(main=None,argv=None)tf的固定格式

#generic entry point script 通用入口点脚本

目录
相关文章
|
11天前
|
机器学习/深度学习 人工智能 自然语言处理
Hugging Face 论文平台 Daily Papers 功能全解析
【9月更文挑战第23天】Hugging Face 是一个专注于自然语言处理领域的开源机器学习平台。其推出的 Daily Papers 页面旨在帮助开发者和研究人员跟踪 AI 领域的最新进展,展示经精心挑选的高质量研究论文,并提供个性化推荐、互动交流、搜索、分类浏览及邮件提醒等功能,促进学术合作与知识共享。
|
22天前
|
机器学习/深度学习 Java API
阿里云文档智能解析——大模型版能力最佳实践与体验评测
阿里云文档智能解析(大模型版)在处理非结构化数据方面表现优异,尤其是在性能和可扩展性上具有明显优势。虽然存在一些待完善之处,但其强大的基础能力和广泛的适用场景使其成为企业数字转型过程中的有力助手。随着技术的不断进步和完善,相信它会在更多领域展现出更大的价值。
67 5
阿里云文档智能解析——大模型版能力最佳实践与体验评测
|
12天前
|
文字识别 算法 API
阿里云文档解析(大模型版)优化
阿里云文档解析(大模型版
|
9天前
|
数据挖掘 BI UED
B2B 领域 CRM 平台全景解析
在快节奏的商业环境中,移动CRM应用让企业随时随地管理客户关系,成为不可或缺的利器。本文深入探讨了七款优秀移动CRM应用:销售易Mobile、Salesforce Mobile、纷享销客、Zoho CRM Mobile、HubSpot Mobile、金蝶云·星辰移动端及用友U8+移动端,详细分析了各自的优势和适用场景。企业可根据具体需求、预算和行业特点,选择最适合的移动CRM解决方案,提升销售效率与管理水平,为企业发展注入新活力。
B2B 领域 CRM 平台全景解析
|
18天前
|
敏捷开发 安全 测试技术
软件测试的艺术:从代码到用户体验的全方位解析
本文将深入探讨软件测试的重要性和实施策略,通过分析不同类型的测试方法和工具,展示如何有效地提升软件质量和用户满意度。我们将从单元测试、集成测试到性能测试等多个角度出发,详细解释每种测试方法的实施步骤和最佳实践。此外,文章还将讨论如何通过持续集成和自动化测试来优化测试流程,以及如何建立有效的测试团队来应对快速变化的市场需求。通过实际案例的分析,本文旨在为读者提供一套系统而实用的软件测试策略,帮助读者在软件开发过程中做出更明智的决策。
|
27天前
|
API 云计算 开发者
使用宜搭平台带来的便利:技术解析与实践
【9月更文第8天】随着企业信息化建设的不断深入,业务流程自动化的需求日益增长。宜搭平台作为一种高效的应用构建工具,为企业提供了快速搭建各类业务系统的可能。本文将探讨使用宜搭平台给企业和开发者带来的便利,并通过具体的代码示例展示其优势。
60 11
|
8天前
|
SQL 人工智能 机器人
遇到的代码部份解析
/ 模拟后端返回的数据
13 0
|
9天前
|
设计模式 存储 算法
PHP中的设计模式:策略模式的深入解析与应用在软件开发的浩瀚海洋中,PHP以其独特的魅力和强大的功能吸引了无数开发者。作为一门历史悠久且广泛应用的编程语言,PHP不仅拥有丰富的内置函数和扩展库,还支持面向对象编程(OOP),为开发者提供了灵活而强大的工具集。在PHP的众多特性中,设计模式的应用尤为引人注目,它们如同精雕细琢的宝石,镶嵌在代码的肌理之中,让程序更加优雅、高效且易于维护。今天,我们就来深入探讨PHP中使用频率颇高的一种设计模式——策略模式。
本文旨在深入探讨PHP中的策略模式,从定义到实现,再到应用场景,全面剖析其在PHP编程中的应用价值。策略模式作为一种行为型设计模式,允许在运行时根据不同情况选择不同的算法或行为,极大地提高了代码的灵活性和可维护性。通过实例分析,本文将展示如何在PHP项目中有效利用策略模式来解决实际问题,并提升代码质量。
|
1月前
|
弹性计算 开发框架 数据可视化
阿里云虚拟主机和云服务器有什么区别?多角度全解析对比
阿里云虚拟主机与云服务器ECS的主要区别在于权限与灵活性。虚拟主机简化了网站搭建流程,预装常用环境,适合初级用户快速建站;而云服务器提供全面控制权,支持多样化的应用场景,如APP后端、大数据处理等,更适合具备技术能力的用户。尽管虚拟主机在价格上通常更优惠,但随着云服务器价格的下降,其性价比已超越虚拟主机,成为更具吸引力的选择。
|
21天前
|
监控 算法 数据可视化
深入解析Android应用开发中的高效内存管理策略在移动应用开发领域,Android平台因其开放性和灵活性备受开发者青睐。然而,随之而来的是内存管理的复杂性,这对开发者提出了更高的要求。高效的内存管理不仅能够提升应用的性能,还能有效避免因内存泄漏导致的应用崩溃。本文将探讨Android应用开发中的内存管理问题,并提供一系列实用的优化策略,帮助开发者打造更稳定、更高效的应用。
在Android开发中,内存管理是一个绕不开的话题。良好的内存管理机制不仅可以提高应用的运行效率,还能有效预防内存泄漏和过度消耗,从而延长电池寿命并提升用户体验。本文从Android内存管理的基本原理出发,详细讨论了几种常见的内存管理技巧,包括内存泄漏的检测与修复、内存分配与回收的优化方法,以及如何通过合理的编程习惯减少内存开销。通过对这些内容的阐述,旨在为Android开发者提供一套系统化的内存优化指南,助力开发出更加流畅稳定的应用。
43 0

热门文章

最新文章

推荐镜像

更多
下一篇
无影云桌面