阿里云平台cifar10代码解析

本文涉及的产品
对象存储 OSS,20GB 3个月
对象存储 OSS,恶意文件检测 1000次 1年
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介:

阿里提供的代码,但是没有解释,花了好长时间才把里面看了个大概,但还未完全掌握。共享我的见解,也请看的同志帮忙修正,共勉!

Tflearn  https://github.com/tflearn/tflearn

 

from __future__ import division, print_function, absolute_import __future__ 

#模块是包含python未来特性的模块,如果你用的是python2,那你就可以通过导入这个模块使用python3的特性

import tensorflow as tf

from six.moves import urllib

import tarfile

import tflearn

from tflearn.data_utils import shuffle, to_categorical

from tflearn.layers.core import input_data, dropout, fully_connected

from tflearn.layers.conv import conv_2d, max_pool_2d

from tflearn.layers.estimator import regression

from tflearn.data_preprocessing import ImagePreprocessing

from tflearn.data_augmentation import ImageAugmentation

#data_augmentation方法与data_preprocessing方法在训练阶段相似,详见#data_augmentation。对input_data方法处理

 

from tensorflow.python.lib.io import file_io

import os

import sys

import numpy as np

import pickle

import argparse

import scipy

FLAGS = None

 

def load_data(dirname, one_hot=False):

    X_train = []

    Y_train = []

 

for i in range(1, 6):

#1,2,3,4,5

        fpath = os.path.join(dirname, 'data_batch_' + str(i))

#连接:将dirname和后面的'data_batch_' + str(i)进行拼接,得到文件夹中文件的路径

        data, labels = load_batch(fpath)

#录入文件并得到data, labels

#经过解压得知 'data_batch_' + str(i)得到的是训练文件

        if i == 1:

            X_train = data

            Y_train = labels

        else:

            X_train = np.concatenate([X_train, data], axis=0)

#沿着某个轴拼接矩阵

            Y_train = np.concatenate([Y_train, labels], axis=0)

#将所有片段拼接在一起,返回的是ndarray

    fpath = os.path.join(dirname, 'test_batch')

    X_test, Y_test = load_batch(fpath)

#3通道分离shape为(10000,1024,3),为reshape做准备

#np.dstack(tup)等价于np.concatenate(tup,axis=2)即在第三维进行拼接

#(X_train[:, :1024], X_train[:, 1024:2048],X_train[:, 2048:])tup

    X_train = np.dstack((X_train[:, :1024], X_train[:, 1024:2048],

                         X_train[:, 2048:])) / 255.

    X_train = np.reshape(X_train, [-1, 32, 32, 3])

    X_test = np.dstack((X_test[:, :1024], X_test[:, 1024:2048],

                        X_test[:, 2048:])) / 255.

    #uint8  无符号整数,0 至 255,处理后,每个元素都小于等于1

    X_test = np.reshape(X_test, [-1, 32, 32, 3])

#stack(堆叠)

if one_hot:

#根据需要,看是否要转化成独热编码

        Y_train = to_categorical(Y_train, 10)

        Y_test = to_categorical(Y_test, 10)

 

    return (X_train, Y_train), (X_test, Y_test)

 

#reporthook from stackoverflow #13881092

def reporthook(blocknum, blocksize, totalsize):

    readsofar = blocknum * blocksize

    if totalsize > 0:

        percent = readsofar * 1e2 / totalsize

        s = "\r%5.1f%% %*d / %d" % (

            percent, len(str(totalsize)), readsofar, totalsize)

        sys.stderr.write(s)#重定向标准错误信息

        if readsofar >= totalsize: # near the end

            sys.stderr.write("\n")

    else: # total size is unknown

        sys.stderr.write("read %d\n" % (readsofar,))

 

def load_batch(fpath):

#录入文件路径,返回data,labels

object = file_io.read_file_to_string(fpath) 

#文件内容转化成字符串或者字节fpath需是文件路径

    #origin_bytes = bytes(object, encoding='latin1')

    # with open(fpath, 'rb') as f:

if sys.version_info > (3, 0):

#如果大于3.0版本  sys.version_info返回sys.version_info

#major=3,minor=6,micro=2,releaselevel=final,serial=0

        # Python3

        d = pickle.loads(object, encoding='latin1') 

#反序列化。。。尝试将object = file_io.read_file_to_string(fpath) 

#改成pickle.dumps()进行序列化  encoding="bytes"

 

    else:

        # Python2

        d = pickle.loads(object)

    data = d["data"]#data.shape (10000,3072)

    labels = d["labels"]

    return data, labels

 

def main(_):

dirname = os.path.join(FLAGS.buckets, "")

  #Namespace(buckets='oss://.../.../.../.../', 

  #checkpointDir='oss://.../.../.../check_point/model/')

#print('dirname:',dirname)

#dirname: oss://.../.../.../.../

    (X, Y), (X_test, Y_test) = load_data(dirname)

    print("load data done")

 

X, Y = shuffle(X, Y)

#tflearn.data_utils.shuffle*arrs)每个矩阵按第一维一致打乱

Y = to_categorical(Y, 10)

#tflearn.data_utils.to_categoricaly,nb_classes),y矩阵,nb_classes分类数

    Y_test = to_categorical(Y_test, 10)

 

    # Real-time data preprocessing

    img_prep = ImagePreprocessing()

    img_prep.add_featurewise_zero_center()#零中心分布

    img_prep.add_featurewise_stdnorm()#标准偏离 standard deviation

 

    # Real-time data augmentation

    img_aug = ImageAugmentation()

    img_aug.add_random_flip_leftright()#随机左右翻转

    img_aug.add_random_rotation(max_angle=25.)#按随机角度旋转,最大旋转角度25

 

    # Convolutional network building

    network = input_data(shape=[None, 32, 32, 3],

                         data_preprocessing=img_prep,

                         data_augmentation=img_aug)

    network = conv_2d(network, 32, 3, activation='relu')

    network = max_pool_2d(network, 2)

    network = conv_2d(network, 64, 3, activation='relu')

    network = conv_2d(network, 64, 3, activation='relu')

    network = max_pool_2d(network, 2)

    network = fully_connected(network, 512, activation='relu')

    network = dropout(network, 0.5)

    network = fully_connected(network, 10, activation='softmax')

    network = regression(network, optimizer='adam',

                         loss='categorical_crossentropy',

                         learning_rate=0.001)

 

    # Train using classifier

    model = tflearn.DNN(network, tensorboard_verbose=0)

    # model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),

    #           show_metric=True, batch_size=96, run_id='cifar10_cnn')

    model_path = os.path.join(FLAGS.checkpointDir, "model.tfl")

print(model_path)

#print('model_path:',model_path)

##model_path: #oss://.../.../.../check_point/model/model.tf2

    model.load(model_path)

 

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")

    # file_paths = tf.train.match_filenames_once(predict_pic)

    # input_file_queue = tf.train.string_input_producer(file_paths)

    # reader = tf.WholeFileReader()

    # file_path, raw_data = reader.read(input_file_queue)

    # img = tf.image.decode_jpeg(raw_data, 3)

    # img = tf.image.resize_images(img, [32, 32])

    # prediction = model.predict([img])

    # print (prediction[0])

    predict_pic = os.path.join(FLAGS.buckets, "bird_bullocks_oriole.jpg")

    img_obj = file_io.read_file_to_string(predict_pic)

    file_io.write_string_to_file("bird_bullocks_oriole.jpg", img_obj)

    #读取图片文件,转化成RGB模式,返回(0,255)的数组

    img = scipy.ndimage.imread("bird_bullocks_oriole.jpg", mode="RGB")

 

    # Scale it to 32x32

    img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

    #"bicubic"双三次插值

    # Predict

    prediction = model.predict([img])

    print (prediction[0])

    print (prediction[0])

    #print (prediction[0].index(max(prediction[0])))

    num=['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

    print ("This is a %s"%(num[prediction[0].tolist().index(max(prediction[0]))]))

    # predict_pic = os.path.join(FLAGS.buckets, "bird_mount_bluebird.jpg")

    # img = scipy.ndimage.imread(predict_pic, mode="RGB")

    # img = scipy.misc.imresize(img, (32, 32), interp="bicubic").astype(np.float32, casting='unsafe')

    # prediction = model.predict([img])

    #print (prediction[0])

 

 

if __name__ == '__main__':

#如果模块是被直接运行的,则代码块被运行,如果模块是被导入的,则代码块不被运行

    parser = argparse.ArgumentParser()

   

    parser.add_argument('--buckets', type=str, default='', help='input data path')

    

    parser.add_argument('--checkpointDir', type=str, default='',help='output model path')

FLAGS, _ = parser.parse_known_args()

#print('FLAGS1:',FLAGS)  当前存储地址

#Namespace(buckets='oss://.../.../.../.../', 

#checkpointDir='oss://.../.../.../check_point/model/')

 

tf.app.run(main=main)

#run(main=None,argv=None)tf的固定格式

#generic entry point script 通用入口点脚本

目录
相关文章
|
16天前
|
机器学习/深度学习 人工智能 弹性计算
阿里云GPU服务器全解析_GPU价格收费标准_GPU优势和使用说明
阿里云GPU云服务器提供强大的GPU算力,适用于深度学习、科学计算、图形可视化和视频处理等场景。作为亚太领先的云服务商,阿里云GPU云服务器具备高灵活性、易用性、容灾备份、安全性和成本效益,支持多种实例规格,满足不同业务需求。
|
23天前
|
机器学习/深度学习 人工智能 自然语言处理
医疗行业的语音识别技术解析:AI多模态能力平台的应用与架构
AI多模态能力平台通过语音识别技术,实现实时转录医患对话,自动生成结构化数据,提高医疗效率。平台具备强大的环境降噪、语音分离及自然语言处理能力,支持与医院系统无缝集成,广泛应用于门诊记录、多学科会诊和急诊场景,显著提升工作效率和数据准确性。
|
1月前
|
存储 安全 Java
系统安全架构的深度解析与实践:Java代码实现
【11月更文挑战第1天】系统安全架构是保护信息系统免受各种威胁和攻击的关键。作为系统架构师,设计一套完善的系统安全架构不仅需要对各种安全威胁有深入理解,还需要熟练掌握各种安全技术和工具。
84 10
|
29天前
|
存储 弹性计算 NoSQL
"从入门到实践,全方位解析云服务器ECS的秘密——手把手教你轻松驾驭阿里云的强大计算力!"
【10月更文挑战第23天】云服务器ECS(Elastic Compute Service)是阿里云提供的基础云计算服务,允许用户在云端租用和管理虚拟服务器。ECS具有弹性伸缩、按需付费、简单易用等特点,适用于网站托管、数据库部署、大数据分析等多种场景。本文介绍ECS的基本概念、使用场景及快速上手指南。
71 3
|
1月前
|
前端开发 JavaScript 开发者
揭秘前端高手的秘密武器:深度解析递归组件与动态组件的奥妙,让你代码效率翻倍!
【10月更文挑战第23天】在Web开发中,组件化已成为主流。本文深入探讨了递归组件与动态组件的概念、应用及实现方式。递归组件通过在组件内部调用自身,适用于处理层级结构数据,如菜单和树形控件。动态组件则根据数据变化动态切换组件显示,适用于不同业务逻辑下的组件展示。通过示例,展示了这两种组件的实现方法及其在实际开发中的应用价值。
34 1
|
2月前
|
域名解析 网络协议
非阿里云注册域名如何在云解析DNS设置解析?
非阿里云注册域名如何在云解析DNS设置解析?
|
2月前
|
机器学习/深度学习 人工智能 算法
揭开深度学习与传统机器学习的神秘面纱:从理论差异到实战代码详解两者间的选择与应用策略全面解析
【10月更文挑战第10天】本文探讨了深度学习与传统机器学习的区别,通过图像识别和语音处理等领域的应用案例,展示了深度学习在自动特征学习和处理大规模数据方面的优势。文中还提供了一个Python代码示例,使用TensorFlow构建多层感知器(MLP)并与Scikit-learn中的逻辑回归模型进行对比,进一步说明了两者的不同特点。
73 2
|
2月前
|
存储 搜索推荐 数据库
运用LangChain赋能企业规章制度制定:深入解析Retrieval-Augmented Generation(RAG)技术如何革新内部管理文件起草流程,实现高效合规与个性化定制的完美结合——实战指南与代码示例全面呈现
【10月更文挑战第3天】构建公司规章制度时,需融合业务实际与管理理论,制定合规且促发展的规则体系。尤其在数字化转型背景下,利用LangChain框架中的RAG技术,可提升规章制定效率与质量。通过Chroma向量数据库存储规章制度文本,并使用OpenAI Embeddings处理文本向量化,将现有文档转换后插入数据库。基于此,构建RAG生成器,根据输入问题检索信息并生成规章制度草案,加快更新速度并确保内容准确,灵活应对法律与业务变化,提高管理效率。此方法结合了先进的人工智能技术,展现了未来规章制度制定的新方向。
36 3
|
25天前
|
供应链 安全 BI
CRM系统功能深度解析:为何这些平台排名靠前
本文深入解析了市场上排名靠前的CRM系统,如纷享销客、用友CRM、金蝶CRM、红圈CRM和销帮帮CRM,探讨了它们在功能性、用户体验、集成能力、数据安全和客户支持等方面的优势,以及如何满足企业的关键需求,助力企业实现数字化转型和业务增长。
|
2月前
|
运维 Cloud Native 持续交付
云原生技术解析:从IO出发,以阿里云原生为例
【10月更文挑战第24天】随着互联网技术的不断发展,传统的单体应用架构逐渐暴露出扩展性差、迭代速度慢等问题。为了应对这些挑战,云原生技术应运而生。云原生是一种利用云计算的优势,以更灵活、可扩展和可靠的方式构建和部署应用程序的方法。它强调以容器、微服务、自动化和持续交付为核心,旨在提高开发效率、增强系统的灵活性和可维护性。阿里云作为国内领先的云服务商,在云原生领域有着深厚的积累和实践。
57 0

推荐镜像

更多