TF flags的简介
1、flags可以帮助我们通过命令行来动态的更改代码中的参数。Tensorflow 使用flags定义命令行参数的方法。ML的模型中有大量需要tuning的超参数,所以此方法,迎合了需要一种灵活的方式对代码某些参数进行调整的需求
(1)、比如,在这个py文件中,首先定义了一些参数,然后将参数统一保存到变量FLAGS中,相当于赋值,后边调用这些参数的时候直接使用FLAGS参数即可
(2)、基本参数类型有三种flags.DEFINE_integer、flags.DEFINE_float、flags.DEFINE_boolean。
(3)、第一个是参数名称,第二个参数是默认值,第三个是参数描述
2、使用过程
#第一步,调用flags = tf.app.flags,进行定义参数名称,并可给定初值、参数说明
#第二步,flags参数直接赋值
#第三步,运行tf.app.run()
import tensorflow as tf
#第一个是参数名称,第二个参数是默认值,第三个是参数描述
#第一步,调用flags = tf.app.flags,进行定义参数名称,并可给定初值、参数说明
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS
def main(_):
#第二步,flags参数直接赋值
pp.pprint(flags.FLAGS.__flags)
if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
……
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
if FLAGS.train:
dcgan.train(FLAGS)
else:
if not dcgan.load(FLAGS.checkpoint_dir)[0]:
raise Exception("[!] Train a model first, then run test mode")
if __name__ == '__main__':
#第三步,运行tf.app.run()
tf.app.run()
TF flags的安装
直接从TF中调用,导入即可使用
import tensorflow as tf
flags = tf.app.flags
TF flags的使用方法
1、第一步,py文件的内部函数的定义
T1、tf定义了tf.app.flags,用于支持接受命令行传递参数,相当于接受argv。
import tensorflow as tf
#1、第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.app.flags.DEFINE_integer('int_name', 10,"descript2")
tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")
FLAGS = tf.app.flags.FLAGS
#必须带参数,否则:'TypeError: main() takes no arguments (1 given)'; main的参数名随意定义,无要求
def main(_):
print(FLAGS.str_name)
print(FLAGS.int_name)
print(FLAGS.bool_name)
if __name__ == '__main__':
tf.app.run() #2、执行main函数
T2、一个简单的示例程序来展示如何使用 command line flags,除了使用 absl 外,还可以使用 argparser。比如定义下边文件名称为test_flags.py
from absl import flags
from absl import app
FLAGS = flags.FLAGS
#1、第一个是参数名称,第二个参数是默认值,第三个是参数描述
flags.DEFINE_string('model', None, 'model to run')
def main(argv):
print('Hello World')
print('selected model', FLAGS.model)
if __name__ == '__main__':
app.run(main) #2、执行main函数
2、第二步,在命令行中运行上边的示例程序
# 1、运行示例程序
python test_flags.py
# 2、更改相应参数
python test_flags.py --model "My model"
# 3、获得帮助信息
python test_flags.py -help
python test_flags.py -helpfull