1.构建float graph
按照tensorflow定义网络的方式定义网络结构,例如:
# Section 1 example: define the float (fp32) network with TF-Slim.
# NOTE(review): is_training, act_type, data_split1, filters, group,
# kernel_size and stride are assumed to be defined elsewhere in the real code.
batch_norm_params = {'is_training': is_training, 'center': True, 'scale': True, 'epsilon': 2e-5}
# arg_scope applies these defaults to every slim.conv2d inside the block.
with slim.arg_scope([slim.conv2d],
                    padding='SAME',
                    activation_fn=act_type,
                    normalizer_fn=slim.batch_norm,
                    normalizer_params=batch_norm_params,
                    weights_regularizer=slim.l2_regularizer(0.00001),  # 0.0005
                    ):
    # body must be indented under the with-block (was flat in the pasted notes)
    conv_split1 = slim.conv2d(data_split1, num_outputs=int(filters / group), kernel_size=kernel_size, stride=stride)
2.添加量化节点
# Section 2 example: insert fake-quantization nodes into the default graph.
# create_training_graph() rewrites the graph in place; quant_delay=10 means
# the fake-quant ops start taking effect after 10 training steps.
for_quant = args.for_quant
if for_quant == 1:
    print("quant train")
    g = tf.get_default_graph()
    tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=10)
3.构建训练流程
# Section 3: end-to-end quantization-aware training flow.

# step 1: build the float graph
m = args.margin_m
x = tf.placeholder(tf.float32, [None, height, width, 1])
y = tf.placeholder(tf.int32, [None, 1])
loss, fc = get_graph(x, y, numclasses, m, True)

# step 2: rewrite the graph with fake-quant nodes.
# Must run AFTER the float graph is built and BEFORE the optimizer is
# created, so the optimizer also trains the quant min/max variables.
for_quant = args.for_quant
if for_quant == 1:
    g = tf.get_default_graph()
    tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=10)

# step 3: build the tf.data input pipeline from TFRecord files
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser_tfrec, num_parallel_calls=4)
dataset = dataset.shuffle(buffer_size=batch_size * 1000)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(epochs)
dataset = dataset.prefetch(buffer_size=batch_size * 10)

# step 4: iterator, optimizer and variable init
iterator = dataset.make_initializable_iterator()
images, labels = iterator.get_next()
global_step = tf.Variable(0, trainable=False)
# batch_norm moving averages (and fake-quant range updates) live in the
# UPDATE_OPS collection; the original snippet used update_ops without
# defining it, so collect them here before gating the train op on them.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss, global_step=global_step)
init_op = tf.global_variables_initializer()
sess.run(init_op)  # NOTE(review): sess is assumed to be created elsewhere — confirm

# step 5: train loop — re-initialize the iterator each epoch
for epoch in range(epochs):
    avg_loss = 0
    sess.run(iterator.initializer, feed_dict={filenames: [rec_file]})
    for i in range(total_batch):
        image, label = sess.run([images, labels])
        _, c = sess.run([optimizer, loss], feed_dict={x: image, y: label})  # , m: mValue

# step 6: save a checkpoint
# NOTE(review): saver is assumed to be a tf.train.Saver() created elsewhere — confirm
saver.save(sess, saveFileName, global_step=epoch)
4.注意事项
tensorflow的量化感知训练,显存占用会比浮点训练大2-3倍,训练速度约降至浮点训练的1/4(即慢3-4倍),显卡利用率也较低。
量化感知训练(QAT)的结果通常比训练后量化(PTQ)效果好一些,但需要调一些参数,比如quant_delay的设置、lr的设置等等。这个后续会继续做实验研究。
后续会研究一下原理。