# 这是阿里提供的代码,但没有附带解释。我花了很长时间才大致看懂,但还未完全掌握。
# 在此分享我的理解,也请看到的朋友帮忙指正,共勉!
# TFLearn: https://github.com/tflearn/tflearn
from __future__ import division, print_function, absolute_import
# The __future__ module exposes Python 3 behaviour to Python 2 code: true
# division, the print() function and absolute imports.
import argparse
import io
import os
import pickle
import sys
import tarfile

import numpy as np
import scipy
import tensorflow as tf
import tflearn
from six.moves import urllib
from tensorflow.python.lib.io import file_io
# ImageAugmentation is configured like ImagePreprocessing; both objects are
# handed to input_data() when the network is built.
from tflearn.data_augmentation import ImageAugmentation
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
# Parsed command-line flags (argparse Namespace with .buckets and
# .checkpointDir). Assigned in the `if __name__ == '__main__'` block and read
# by main(); remains None when this module is merely imported.
FLAGS = None
def _channels_last(flat):
    """Turn rows of 3072 values into float images of shape (N, 32, 32, 3) in [0, 1].

    Each row holds three consecutive 1024-value colour planes; dstack
    interleaves them into (N, 1024, 3) (equivalent to concatenating on axis 2)
    and the reshape folds the 1024 pixels into a 32x32 grid.
    """
    planes = (flat[:, :1024], flat[:, 1024:2048], flat[:, 2048:])
    # Dividing by 255. maps the 0..255 byte values into [0, 1].
    return np.reshape(np.dstack(planes) / 255., [-1, 32, 32, 3])


def load_data(dirname, one_hot=False):
    """Load the CIFAR-10 python batches stored under *dirname*.

    Args:
        dirname: directory prefix (local path or any URL file_io can read)
            containing data_batch_1 .. data_batch_5 and test_batch.
        one_hot: when True, labels are converted to 10-class one-hot vectors.

    Returns:
        ((X_train, Y_train), (X_test, Y_test)); images have shape
        (N, 32, 32, 3) with values in [0, 1].
    """
    # data_batch_1 .. data_batch_5 are the training files.
    batches = [load_batch(os.path.join(dirname, 'data_batch_' + str(i)))
               for i in range(1, 6)]
    # Concatenate once instead of re-copying the growing array per batch.
    X_train = np.concatenate([data for data, _ in batches], axis=0)
    Y_train = np.concatenate([labels for _, labels in batches], axis=0)

    X_test, Y_test = load_batch(os.path.join(dirname, 'test_batch'))

    X_train = _channels_last(X_train)
    X_test = _channels_last(X_test)

    if one_hot:
        Y_train = to_categorical(Y_train, 10)
        Y_test = to_categorical(Y_test, 10)
    return (X_train, Y_train), (X_test, Y_test)
# Adapted from Stack Overflow question #13881092.
def reporthook(blocknum, blocksize, totalsize):
    """Write a download progress indicator to stderr.

    The signature matches the reporthook callback of urllib's urlretrieve.

    Args:
        blocknum: number of blocks transferred so far.
        blocksize: size of one block in bytes.
        totalsize: total size in bytes; a value <= 0 means "unknown".
    """
    err = sys.stderr
    done = blocknum * blocksize
    if not totalsize > 0:
        # Total size unknown: only report the running byte count.
        err.write("read %d\n" % (done,))
        return
    width = len(str(totalsize))
    err.write("\r%5.1f%% %*d / %d" % (done * 1e2 / totalsize, width, done, totalsize))
    if done >= totalsize:  # finished (or overshot): terminate the progress line
        err.write("\n")
def load_batch(fpath):
    """Read one pickled CIFAR-10 batch file and return (data, labels).

    Args:
        fpath: path or URL of the batch file; read through file_io so remote
            storage (e.g. oss://) works as well as local paths.

    Returns:
        The batch dict's "data" entry (used elsewhere as an (N, 3072) array)
        and its "labels" entry.
    """
    # Read raw bytes. The text-mode read_file_to_string() returns a decoded
    # str under Python 3, which pickle.loads() cannot accept.
    with file_io.FileIO(fpath, mode="rb") as f:
        raw = f.read()
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    if sys.version_info >= (3,):
        # encoding='latin1' lets Python 3 unpickle Python 2 byte strings
        # (including the buffers of numpy arrays) without decoding errors.
        batch = pickle.loads(raw, encoding='latin1')
    else:
        batch = pickle.loads(raw)
    return batch["data"], batch["labels"]
def _build_network():
    """Build the CIFAR-10 CNN and return its tflearn regression layer.

    NOTE(review): keep the layer creation order unchanged. tflearn presumably
    derives default layer names from creation order, and those names must
    match the variables stored in the checkpoint.
    """
    # Real-time data preprocessing: per-feature zero centring and
    # standard-deviation normalisation.
    img_prep = ImagePreprocessing()
    img_prep.add_featurewise_zero_center()
    img_prep.add_featurewise_stdnorm()
    # Real-time data augmentation: random left/right flips and random
    # rotations of at most 25 degrees.
    img_aug = ImageAugmentation()
    img_aug.add_random_flip_leftright()
    img_aug.add_random_rotation(max_angle=25.)

    network = input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
    network = conv_2d(network, 32, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = conv_2d(network, 64, 3, activation='relu')
    network = conv_2d(network, 64, 3, activation='relu')
    network = max_pool_2d(network, 2)
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, 10, activation='softmax')
    return regression(network, optimizer='adam',
                      loss='categorical_crossentropy',
                      learning_rate=0.001)


def _load_image(path):
    """Read the image at *path* as a (32, 32, 3) float32 RGB array in [0, 1].

    *path* may be any location file_io can read (e.g. an oss:// URL).
    """
    # `import scipy` alone does not load these submodules. They are imported
    # here because imread/imresize only exist in older SciPy releases.
    import scipy.misc
    import scipy.ndimage

    with file_io.FileIO(path, mode="rb") as f:
        raw = f.read()
    # imread accepts a file object, so the bytes are decoded in memory instead
    # of being written to a temporary file in the working directory.
    img = scipy.ndimage.imread(io.BytesIO(raw), mode="RGB")
    # Bicubic resize to the 32x32 network input.
    img = scipy.misc.imresize(img, (32, 32), interp="bicubic")
    # load_data() divides training pixels by 255, so the prediction input must
    # be scaled the same way before the featurewise normalisation is applied.
    return img.astype(np.float32) / 255.


def main(_):
    """Restore the CIFAR-10 CNN from a checkpoint and classify one image.

    Reads FLAGS.buckets (directory holding the CIFAR-10 batches and
    bird_bullocks_oriole.jpg) and FLAGS.checkpointDir (directory holding
    model.tfl). Example values:
        buckets='oss://.../.../.../.../'
        checkpointDir='oss://.../.../.../check_point/model/'

    Args:
        _: argv list passed in by tf.app.run(); unused.
    """
    class_names = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

    # Joining with "" appends a trailing separator to the directory prefix.
    dirname = os.path.join(FLAGS.buckets, "")
    (X, Y), (X_test, Y_test) = load_data(dirname)
    print("load data done")
    # The shuffled / one-hot data is only consumed by the (disabled) model.fit
    # call below; it is kept so training can be re-enabled easily.
    X, Y = shuffle(X, Y)  # shuffle samples and labels with one permutation
    Y = to_categorical(Y, 10)
    Y_test = to_categorical(Y_test, 10)

    model = tflearn.DNN(_build_network(), tensorboard_verbose=0)
    # To train instead of restoring a checkpoint:
    # model.fit(X, Y, n_epoch=100, shuffle=True, validation_set=(X_test, Y_test),
    #           show_metric=True, batch_size=96, run_id='cifar10_cnn')
    model_path = os.path.join(FLAGS.checkpointDir, "model.tfl")
    print(model_path)
    model.load(model_path)

    predict_pic = os.path.join(FLAGS.buckets, "bird_bullocks_oriole.jpg")
    prediction = model.predict([_load_image(predict_pic)])
    print(prediction[0])
    # argmax picks the (first) class with the highest predicted probability.
    print("This is a %s" % class_names[int(np.argmax(prediction[0]))])
if __name__ == '__main__':
    # Runs only when the file is executed directly, not when it is imported.
    arg_parser = argparse.ArgumentParser()
    for flag_name, flag_help in (('--buckets', 'input data path'),
                                 ('--checkpointDir', 'output model path')):
        arg_parser.add_argument(flag_name, type=str, default='', help=flag_help)
    # parse_known_args() tolerates extra arguments this parser does not define.
    # Example result:
    #   Namespace(buckets='oss://.../.../.../.../',
    #             checkpointDir='oss://.../.../.../check_point/model/')
    FLAGS = arg_parser.parse_known_args()[0]
    # tf.app.run() is TensorFlow 1.x's generic script entry point; it calls
    # main with the argv list.
    tf.app.run(main=main)