Python不支持读取oss数据,因此所有调用python的 Open()、 os.path.exist() 等文件和文件夹操作的函数的代码都无法执行。如Scipy.misc.imread()、numpy.load()等。
通常采用以下两种办法在机器学习平台读取数据。
使用tf.gfile下的函数,适用于简单地读取一张图片,或者一个文本等,成员函数如下。
tf.gfile.Copy(oldpath, newpath, overwrite=False) # 拷贝文件
tf.gfile.DeleteRecursively(dirname) # 递归删除目录下所有文件
tf.gfile.Exists(filename) # 文件是否存在
tf.gfile.FastGFile(name, mode='r') # 无阻塞读取文件
tf.gfile.GFile(name, mode='r') # 读取文件
tf.gfile.Glob(filename) # 列出文件夹下所有文件, 支持pattern
tf.gfile.IsDirectory(dirname) # 返回dirname是否为一个目录
tf.gfile.ListDirectory(dirname) # 列出dirname下所有文件
tf.gfile.MakeDirs(dirname) # 在dirname下创建一个文件夹, 如果父目录不存在, 会自动创建父目录. 如果
文件夹已经存在, 且文件夹可写, 会返回成功
tf.gfile.MkDir(dirname) # 在dirname处创建一个文件夹
tf.gfile.Remove(filename) # 删除filename
tf.gfile.Rename(oldname, newname, overwrite=False) # 重命名
tf.gfile.Stat(dirname) # 返回目录的统计数据
tf.gfile.Walk(top, inOrder=True) # 返回目录的文件树
具体请参考tf.gfile模块。
使用tf.gfile.Glob、tf.gfile.FastGFile、 tf.WhoFileReader() 、tf.train.shuffer_batch(),适用于批量读取文件(读取文件之前需要获取文件列表,如果是批量读取,还需要创建batch)。
使用机器学习搭建深度学习实验时,通常需要在界面右侧设置读取目录、代码文件等参数。这些参数通过“—XXX”(XXX代表字符串)的形式传入,tf.flags提供了这个功能。
import tensorflow as tf
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('buckets', 'oss://{OSS Bucket}/', '训练图片所在文件夹')
tf.flags.DEFINE_string('batch_size', '15', 'batch大小')
files = tf.gfile.Glob(os.path.join(FLAGS.buckets,'*.jpg')) # 如我想列出buckets下所有jpg文件路径
小规模读取文件时建议使用tf.gfile.FastGfile()。
for path in files:
file_content = tf.gfile.FastGFile(path, 'rb').read() # 一定记得使用rb读取, 不然很多情况下都会报错
image = tf.image.decode_jpeg(file_content, channels=3) # 本教程以JPG图片为例
大批量读取文件时建议使用tf.WhoFileReader()。
reader = tf.WholeFileReader() # 实例化一个reader
fileQueue = tf.train.string_input_producer(files) # 创建一个供reader读取的队列
file_name, file_content = reader.read(fileQueue) # 使reader从队列中读取一个文件
image_content = tf.image.decode_jpeg(file_content, channels=3) # 讲读取结果解码为图片
label = XXX # 这里省略处理label的过程
batch = tf.train.shuffle_batch([label, image_content], batch_size=FLAGS.batch_size, num_threads=4,
capacity=1000 + 3 * FLAGS.batch_size, min_after_dequeue=1000)
sess = tf.Session() # 创建Session
tf.train.start_queue_runners(sess=sess) # 重要!!! 这个函数是启动队列, 不加这句线程会一直阻塞
labels, images = sess.run(batch) # 获取结果
部分代码解释如下:
tf.train.string_input_producer:把files转换成一个队列,并且需要 tf.train.start_queue_runners 来启动队列。
tf.train.shuffle_batch参数解释如下:
batch_size:批处理大小。即每次运行这个batch,返回的数据个数。
num_threads:运行线程数,一般设置为4。
capacity:随机取文件范围。比如数据集有10000个数据,需要从5000个数据中随机抽取,那么capacity就设置成5000。
min_after_dequeue:维持队列的最小长度,不能大于capacity。
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。