上一篇博客为 PNet 网络生成了 TFRecord 文件,本篇开始对 PNet 进行训练。
进入 `train_models` 文件夹,打开 `train_PNet.py`,代码如下:
#coding:utf-8
from train_models.mtcnn_model import P_Net
from train_models.train import train
def train_PNet(base_dir, prefix, end_epoch, display, lr):
    """Train PNet by passing the ``P_Net`` network builder to ``train``.

    :param base_dir: directory holding the PNet TFRecord data and the
        ``train_PNet_landmark.txt`` label list
        (e.g. ``'../../DATA/imglists/PNet'``).
    :param prefix: checkpoint path prefix
        (e.g. ``'../data/MTCNN_model/PNet_landmark/PNet'``). ``train`` picks
        the network type from the last path component, so it should end in
        ``'PNet'``.
    :param end_epoch: maximum number of training epochs.
    :param display: print training statistics every ``display`` steps.
    :param lr: base learning rate.
    :return: None
    """
    net_factory = P_Net
    train(net_factory, prefix, end_epoch, base_dir, display=display, base_lr=lr)
if __name__ == '__main__':
    # Training hyper-parameters.
    end_epoch = 30
    display = 100
    lr = 0.001

    # Input: directory with the PNet TFRecord data (relative to the CWD).
    base_dir = '../../DATA/imglists/PNet'

    # Output: checkpoint prefix for the landmark-enabled variant.
    # (The variant without landmarks would use '../data/%s_model/PNet/PNet'.)
    model_name = 'MTCNN'
    model_path = '../data/%s_model/PNet_landmark/PNet' % model_name
    prefix = model_path

    train_PNet(base_dir=base_dir, prefix=prefix, end_epoch=end_epoch,
               display=display, lr=lr)
由上可以看出,`train_PNet.py` 调用了 `P_Net` 和 `train` 这两个函数。下面把这两个函数的代码贴出来,其中 `train` 函数的代码如下:
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """Train PNet, RNet or ONet from TFRecord data and save checkpoints.

    The network type is the last path component of ``prefix``
    (e.g. ``'../data/MTCNN_model/PNet_landmark/PNet'`` -> ``'PNet'``).
    Anything other than ``'PNet'`` or ``'RNet'`` is trained as ONet
    (48x48 input).

    :param net_factory: network builder (e.g. ``P_Net``). It is called as
        ``net_factory(input_image, label, bbox_target, landmark_target,
        training=True)`` and must return the ops
        ``(cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy)``.
    :param prefix: checkpoint path prefix passed to ``saver.save``.
    :param end_epoch: number of epochs. The total step count is
        ``int(num / config.BATCH_SIZE + 1) * end_epoch``.
    :param base_dir: directory holding ``train_<net>_landmark.txt`` (read
        only to count examples) and the shuffled TFRecord file(s).
    :param display: print losses and accuracy every ``display`` steps.
    :param base_lr: base learning rate handed to ``train_model``.
    :return: None
    """
    net = prefix.split('/')[-1]  # e.g. 'PNet'

    # The label list is only used to count the training examples. Count its
    # lines lazily (the file can have >1M lines) and let `with` close it.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)
    with open(label_file, 'r') as f:
        num = sum(1 for _ in f)
    print("Total size of the dataset is: ", num)
    print(prefix)

    if net == 'PNet':
        # PNet reads a single shuffled TFRecord file.
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)
    else:
        # RNet/ONet read four TFRecord files and mix them in each batch in a
        # 1:1:3:1 ratio (pos : part : neg : landmark).
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        # NOTE(review): landmark records always come from
        # '../../DATA/imglists/RNet', even when training ONet, rather than
        # os.path.join(base_dir, 'landmark_landmark.tfrecord_shuffle').
        # Confirm this is intentional.
        landmark_dir = os.path.join('../../DATA/imglists/RNet', 'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        # NOTE(review): presumably read_multi_tfrecords concatenates these
        # sub-batches. If so, their sum must equal config.BATCH_SIZE to match
        # the placeholders below. With ceil rounding that only holds when
        # BATCH_SIZE is a multiple of 6.
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [pos_batch_size, part_batch_size, neg_batch_size, landmark_batch_size]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # Input resolution and per-task loss weights for each network.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    # Placeholders fed with the numpy batches produced by the input pipeline.
    # images: [BATCH_SIZE, image_size, image_size, 3]
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    # labels: [BATCH_SIZE]
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    # bbox regression targets: [BATCH_SIZE, 4]
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    # landmark targets: [BATCH_SIZE, 10]
    landmark_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 10], name='landmark_target')

    # image_color_distort is applied to the input inside the graph
    # (presumably random colour augmentation -- defined elsewhere).
    input_image = image_color_distort(input_image)
    # Classification / bbox / landmark / L2-regularisation losses and
    # classification accuracy.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted sum of the three task losses plus the regularisation loss.
    total_loss_op = (radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op
                     + radio_landmark_loss * landmark_loss_op + L2_loss_op)
    train_op, lr_op = train_model(base_lr, total_loss_op, num)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0: never delete old checkpoints.
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # TensorBoard scalars.
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("landmark_loss", landmark_loss_op)
    tf.summary.scalar("cls_accuracy", accuracy_op)
    # cls, bbox, landmark and L2 losses added together
    tf.summary.scalar("total_loss", total_loss_op)
    summary_op = tf.summary.merge_all()

    logs_dir = "../logs/%s" % (net)
    # makedirs also creates a missing '../logs' parent (os.mkdir would fail).
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # Writes a projector config for TensorBoard; no embeddings are
    # registered on it here.
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)

    # Start the queue-runner threads that feed the TFRecord input pipeline.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # steps since the last checkpoint
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Make the graph read-only: any op accidentally created inside the loop
    # raises instead of silently growing the graph (and memory) every step.
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            # Stop early if a queue-runner thread requested it. Exhausted
            # input (OutOfRangeError) is handled by the except clause below.
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Randomly flip images (and their landmarks).
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            feed_dict = {input_image: image_batch_array,
                         label: label_batch_array,
                         bbox_target: bbox_batch_array,
                         landmark_target: landmark_batch_array}
            _, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict=feed_dict)
            # Every `display` steps, re-evaluate the metrics on the same batch
            # (i.e. after this step's update) and print them.
            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, lr_op, accuracy_op],
                    feed_dict=feed_dict)
                total_loss = (radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss
                              + radio_landmark_loss * landmark_loss + L2_loss)
                print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step + 1, MAX_STEP, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, total_loss, lr))
            # Checkpoint once roughly two epochs' worth of examples have been
            # consumed since the last save. Checkpoints are numbered epoch*2.
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
`train` 函数中用到了 `read_single_tfrecord`、`P_Net`(在该函数中以参数 `net_factory` 的形式出现)、`train_model()` 和 `random_flip_images()` 等函数。