MTCNN(Tensorflow)学习记录(为PNet生成tfrecord文件)

上一篇博客是对于两个数据集的合并,这部分内容是通过合并的数据集生成tfrecord文件。

1 为PNet生成tfrecord文件

进入prepare_data文件夹打开gen_PNet_tfrecords.py,代码如下:

#coding:utf-8
#首先导入各种库
import os
import random
import sys
import time

import tensorflow as tf

from prepare_data.tfrecord_utils import _process_image_withoutcoder, _convert_to_example_simple


def _add_to_tfrecord(filename, image_example, tfrecord_writer):
    """Load one image, serialize it as a tf.train.Example, write to the TFRecord.

    Args:
      filename: path of the image file to read.
      image_example: dict describing the sample; expected keys are
        'filename', 'label' and 'bbox' (see get_dataset).
      tfrecord_writer: an open tf.python_io.TFRecordWriter.
    """
    # Read the image as a raw byte string; the returned original height/width
    # are not needed by the simple example builder below.
    encoded_image, _height, _width = _process_image_withoutcoder(filename)
    # Pack the annotation dict together with the raw bytes into an Example
    # proto and append its serialized form to the record file.
    example_proto = _convert_to_example_simple(image_example, encoded_image)
    tfrecord_writer.write(example_proto.SerializeToString())


def _get_output_filename(output_dir, name, net):             #获得一个输出的文件名
    #st = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    #return '%s/%s_%s_%s.tfrecord' % (output_dir, name, net, st)
    return '%s/train_PNet_landmark.tfrecord' % (output_dir) #返回的是'../../DATA/imglists/PNet/train_PNet_landmark.tfrecord'
    

def run(dataset_dir, net, output_dir, name='MTCNN', shuffling=False):
    """Convert the annotation list under dataset_dir into one TFRecord file.

    Args:
      dataset_dir: root directory containing imglists/PNet/train_<net>_landmark.txt.
      net: network name, e.g. 'PNet' (selects the annotation file).
      output_dir: directory the tfrecord file is written into.
      name: dataset name (currently unused by _get_output_filename).
      shuffling: if True, shuffle the samples and append '_shuffle' to the
        output filename.
    """
    tf_filename = _get_output_filename(output_dir, name, net)
    if shuffling:
        # BUG FIX: the '_shuffle' suffix must be attached BEFORE the existence
        # check below; the original appended it afterwards, so the check
        # inspected the un-suffixed name and never detected an existing
        # '..._shuffle' file.
        tf_filename = tf_filename + '_shuffle'
    if tf.gfile.Exists(tf_filename):
        # Skip regeneration when the output already exists.
        print('Dataset files already exist. Exiting without re-creating them.')
        return
    # Load the annotation list; each element is a dict with keys
    # 'filename', 'label' and 'bbox'.
    dataset = get_dataset(dataset_dir, net=net)
    if shuffling:
        random.shuffle(dataset)
    # Leftover debug print kept so console output matches the original script.
    print('lala')
    with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
        for i, image_example in enumerate(dataset):
            if (i + 1) % 100 == 0:
                # Overwrite the same console line with a progress counter.
                sys.stdout.write('\r>> %d/%d images has been converted' % (i + 1, len(dataset)))
            sys.stdout.flush()
            filename = image_example['filename']
            _add_to_tfrecord(filename, image_example, tfrecord_writer)
    print('\nFinished converting the MTCNN dataset!')


def get_dataset(dir, net='PNet'):
    """Parse the annotation list and return one dict per training sample.

    Each line of imglists/PNet/train_<net>_landmark.txt has the form
    '<image-path> <label> [4 bbox floats | 10 landmark floats]', where label
    is 1 / 0 / -1 / -2 for positive / negative / part / landmark samples.

    Args:
      dir: dataset root directory, e.g. '../../DATA/'.
      net: network name used to select the annotation file.

    Returns:
      A list of dicts, each with keys 'filename' (str), 'label' (int) and
      'bbox' (dict of 4 box coords plus 10 landmark coords; fields not
      present on the line stay 0).
    """
    item = 'imglists/PNet/train_%s_landmark.txt' % net
    dataset_dir = os.path.join(dir, item)

    # Field names in the exact order they appear on an annotation line.
    box_keys = ('xmin', 'ymin', 'xmax', 'ymax')
    landmark_keys = ('xlefteye', 'ylefteye', 'xrighteye', 'yrighteye',
                     'xnose', 'ynose', 'xleftmouth', 'yleftmouth',
                     'xrightmouth', 'yrightmouth')

    dataset = []
    # FIX: use a context manager so the file handle is always closed;
    # the original opened the file and never closed it.
    with open(dataset_dir, 'r') as imagelist:
        for line in imagelist:
            info = line.strip().split(' ')
            # All bbox/landmark fields default to int 0, as in the original.
            bbox = dict.fromkeys(box_keys + landmark_keys, 0)
            if len(info) == 6:
                # Positive or part sample: 4 bounding-box coordinates.
                for key, value in zip(box_keys, info[2:6]):
                    bbox[key] = float(value)
            elif len(info) == 12:
                # Landmark sample: 10 facial-landmark coordinates.
                for key, value in zip(landmark_keys, info[2:12]):
                    bbox[key] = float(value)
            dataset.append({
                'filename': info[0],
                'label': int(info[1]),
                'bbox': bbox,
            })
    return dataset

if __name__ == '__main__':
    # Script entry point: build the PNet training tfrecord from the merged
    # annotation list, shuffling the samples first.
    data_root = '../../DATA/'
    net_name = 'PNet'
    out_dir = '../../DATA/imglists/PNet'
    run(data_root, net_name, out_dir, shuffling=True)

用到的函数有_process_image_withoutcoder、 _convert_to_example_simple。
到此TFRecord数据的制作就结束了,这个脚本在../../DATA/imglists/PNet/目录输出了一个train_PNet_landmark.tfrecord文件。

你可能感兴趣的:(MTCNN(Tensorflow)学习记录(为PNet生成tfrecord文件))