首先需要获取图像数据,因为图像没有经过0,1处理,而且只有三张照片循环输入,效果不是很好,但是文章主要把重点放在使用slim实现vgg16复现和tensorboard的使用。
slim相关学习资料:
TensorFlow-Slim使用方法说明
Tensorboard相关学习资料:
TensorBoard:可视化学习
代码如下:
'''
created on January 5 16:36 2018
@author:lhy
'''
import tensorflow as tf
import tensorflow.contrib.slim as slim
LEARNING_RATE_BASE=0.01
LEARNING_RATE_DECAY=0.99
path_list=['A.png','B.png','C.png']
#加入了标签,在使用的时候可以直接对应标签取出数据
label=[0,1,2]
#转换成张量tensor类型
img_path=tf.convert_to_tensor(path_list,dtype=tf.string)
label=tf.convert_to_tensor(label,dtype=tf.int32)
#返回了一个包含路径和标签的列表,并将文件名和对应的标签放入文件名对列中,等待系统调用
image=tf.train.slice_input_producer([img_path,label],shuffle=True,num_epochs=None)#shuffle=Flase表示不打乱,当为True的时候打乱顺序放入文件名队列
labels=image[1]
def load_image():
file_contents=tf.read_file(image[0])
img=tf.image.convert_image_dtype(tf.image.decode_png(file_contents,channels=3),tf.float32)
#img=tf.image.decode_png(file_contents,channels=3)
img=tf.image.resize_images(img,size=(228,228))
return img
img=load_image()
labels = tf.one_hot(labels, 3)#设置one_hot编码
img_batch,label_batch=tf.train.batch([img,labels],batch_size=2)
def build_graph(inputs,label):
#使用统一的参数进行搭建前向传播,使用slim十分方便快捷
with slim.arg_scope([slim.conv2d,slim.fully_connected],activation_fn=tf.nn.relu,weights_initializer=tf.truncated_normal_initializer(0.0,0.01),weights_regularizer=slim.l2_regularizer(0.0005)):
net=slim.repeat(inputs,2,slim.conv2d,64,[3,3],scope='conv1')#两次卷积,卷积核3*3*64,步长1
net=slim.max_pool2d(net,[2,2],scope='pool1')#maxpool
net=slim.repeat(net,2,slim.conv2d,128,[3,3],scope='conv2')#两次卷积,卷积核3*3*128,步长1
net=slim.max_pool2d(net,[2,2],scope='pool2')#maxpool
net=slim.repeat(net,3,slim.conv2d,256,[3,3],scope='conv3')#三次卷积,卷积核3*3*256,步长1
net=slim.max_pool2d(net,[2,2],scope='pool3')#maxpool
net=slim.repeat(net,3,slim.conv2d,512,[3,3],scope='conv4')#三次卷积,卷积核3*3*512
net=slim.max_pool2d(net,[2,2],scope='pool4')#maxpool
net=slim.repeat(net,3,slim.conv2d,512,[3,3],scope='conv5')#三层卷积,卷积核3*3*512
net=slim.max_pool2d(net,[2,2],scope='pool5')
net=slim.fully_connected(net,4096,scope='fc6')
net=slim.dropout(net,0.99,scope='dropout6')
net=slim.fully_connected(net,4096,scope='fc7')
net=slim.dropout(net,0.99,scope='dropout7')
net=slim.flatten(net)
net=slim.fully_connected(net,3,activation_fn=None,scope='fc8')#这里重载了activation_fn=None,这个函数默认使用relu激活函数
global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
loss = tf.reduce_mean(tf.square(net-label))#损失函数
accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(net,1),tf.argmax(label,1)),tf.float32))#准确率
learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,2,LEARNING_RATE_DECAY,staircase=True)#指数衰减学习率
train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss,global_step=global_step)#进行反向传播
probabilities=tf.nn.softmax(net)#softmax将所有结果归一
tf.summary.scalar('loss',loss)#使用tensorboard可视化参数,将loss放如summary中
tf.summary.scalar('accuracy',accuracy)
merged_summary_op=tf.summary.merge_all()#将所有scalar到的参数放如merged_summary_op
return net,loss,train_op,probabilities,accuracy,global_step,merged_summary_op
with tf.Session() as sess:
out,loss_val,train_op,prob,accuracy_rate,step,summary = build_graph(inputs=img_batch, label=label_batch)
#initializer for num_epochs
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
train_writer=tf.summary.FileWriter('./log/train',sess.graph)#initializer一个writer,将sess图放如summary中
coord=tf.train.Coordinator()
thread=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
while not coord.should_stop():
out_val,loss,backward_train,probabilities,accuracy_val,step_val,summary_val=sess.run([out,loss_val,train_op,prob,accuracy_rate,step,summary])
#print(out_val)
train_writer.add_summary(summary_val,step_val)#以step_val为横坐标,记录下summary_val中的loss和accuracy的值
print('after '+str(step_val)+'steps'+'loss is '+str(loss)+'accuracy is '+str(accuracy_val))
except tf.errors.OutOfRangeError:
print('Done')
finally:
coord.request_stop()
coord.join(thread)
可以看到log中的文件:
类型是LAPTOP-NOEBVTUG文件,这个文件类型每个人都会不同,因为这个名字是自己电脑的名字。
打开cmd,在命令行输入tensorboard --logdir=log/to/path --host=127.0.0.1
可以看到如下的显示:
注意:路径名一定不要有中文!!!否则会找不到文件!
使用浏览器(在这里用的chrome)打开http://127.0.0.1:6006
在这里就可以看到损失值(loss)和准确率(accuracy)的显示(这里输入数据太差了,导致损失值和准确率的图像都很难看)
也可以看搭建网络的图:
可以下载下来完整的图:
通过Tendorboard就可以得知图的结构和参数训练过程的情况。