前言:
{
这段时间一直在更新论文的阅读记录,文字太多,这次的形式就改成了实践记录。
本次的目标是在VOC2012数据集上实现多目标识别。
}
正文:
{
最近发现了一个目前目标识别效果比较好的新网络:PNASNet[1],就把它拿来改改。不求很好的效果,只是想做做实践换换口味。本次的计划是从PNASNet中的一些部位引出一些张量,配合着最后的瓶颈层一起进行多目标识别。
PNASNet的论文我没有细看,只知道其是自动生成的一种结构,包含了一种新的单元:PNASCell[1],PNASCell由分为normal和reduction两种(后置会将输入张量的长宽减半)。在我找到的其基于tentorflow的代码[2]中,网络在一层卷积层后一共有14个PNASCell,前2个PNASCell为reduction类,后面每3个normal类就会加一个reduction类。
VOC2012数据集中一共有20个大类,标签都是以xml文件储存,所以这里涉及到了Python下对xml的解析。代码1的打印结果为一个xml文件中所有object的使用name的文本。
#代码1
import xml.etree.ElementTree as ET
domtree = ET.parse(xml_path+'/'+xml_name)
domroot = domtree.getroot()
objects = domroot.iter('object')
for object in objects:
print(object.find('name').text) #打印object项里的所有name项中的文字
而且还涉及到了one_hot向量的生成。见代码2。
#代码2
def numpy_one_hot(number, max):
result = np.ones(max) * 0
result[number] = 1
return result
这次先写出了数据格式转换的代码,见代码3。
#代码3
# -- coding: utf-8 --
'''
此文件实现JPG和xml标签到TFRecord的转换(VOC2012版本),其中对每个JPG源文件都会生成一个TFRecord目标文件。
文件名“test_data_28.TFRecord”表示测试数据的第29个文件(第一个文件的文件号为0)。
'''
import glob
import os.path
import tensorflow as tf
import numpy as np
import gc
from tensorflow.python.platform import gfile
import xml.etree.ElementTree as ET
import dictionary
#下函数会由JPG和xml文件得到npy文件
def JPG_xml_to_TFRecord(input_JPG_path, input_xml_path, output_file_path = "data", validation_data_ratio = 0.1,
test_data_ratio = 0.1):
file_list = []
file_labels = []
#获取所有文件和其标签
extensions = ["jpg", "jpeg"]
current_label = 0
for extension in extensions:
file_glob = glob.glob(input_JPG_path+"/*."+extension) #不分大小写
file_list.extend(file_glob) #添加文件路径到file_list
for file_path in file_glob:
file_name = os.path.splitext(os.path.basename(file_path))[0]
domtree = ET.parse(input_xml_path+'/'+file_name+'.xml')
domroot = domtree.getroot()
objects = domroot.iter('object')
one_hot = np.ones(dictionary.num_classes) * 0
for object in objects:
current_one_hot = dictionary.class_dictionary[object.find('name').text]
one_hot = np.logical_or(one_hot,current_one_hot)
file_labels.append(one_hot) #添加标签向量到file_labels。
#shuffle the samples
state = np.random.get_state()
np.random.shuffle(file_list)
np.random.set_state(state)
np.random.shuffle(file_labels)
#save the samples to files
traning_count = 0
test_count = 0
validation_count = 0
iteration_times = 0
sess = tf.Session() #获取图片数据时会用到
for file_name in file_list:
print("label=" + str(file_labels[iteration_times]) + " file_path=" + file_name) #打印当前储存的文件和标签
image_values = tf.image.decode_jpeg(gfile.FastGFile(file_name, "rb").read())
image_values = sess.run(image_values)
chance = np.random.random_sample()
if chance < validation_data_ratio:
writer = tf.python_io.TFRecordWriter(output_file_path+"/validation_data_"+str(validation_count)+".TFRecord") #①
validation_count += 1
elif chance < (validation_data_ratio + test_data_ratio):
writer = tf.python_io.TFRecordWriter(output_file_path+"/test_data_"+str(test_count)+".TFRecord")
test_count += 1
else:
writer = tf.python_io.TFRecordWriter(output_file_path+"/training_data_"+str(traning_count)+".TFRecord")
traning_count += 1
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_values.tostring()])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=file_labels[iteration_times])),
"shape": tf.train.Feature(int64_list=tf.train.Int64List(value=image_values.shape)),
}))
writer.write(example.SerializeToString())
iteration_times += 1
gc.collect()
sess.close()
'''
① if you have not created the folder in the TFRecordWriter's input path yet, there will be an assertionError without any explanation.
'''
其中dictionary是字典,见代码4。
#代码4
# -- coding: utf-8 --
'''
This file is of a mapping dictionary for VOC2012。
'''
import numpy as np
num_classes = 20
#one_hot generation function
def numpy_one_hot(number, max):
result = np.ones(max) * 0
result[number] = 1
return result
#dictionary of 20 classes
class_dictionary = {}
class_dictionary['person'] = numpy_one_hot(0, num_classes)
class_dictionary['bird'] = numpy_one_hot(1, num_classes)
class_dictionary['cat'] = numpy_one_hot(2, num_classes)
class_dictionary['cow'] = numpy_one_hot(3, num_classes)
class_dictionary['dog'] = numpy_one_hot(4, num_classes)
class_dictionary['horse'] = numpy_one_hot(5, num_classes)
class_dictionary['sheep'] = numpy_one_hot(6, num_classes)
class_dictionary['aeroplane'] = numpy_one_hot(7, num_classes)
class_dictionary['bicycle'] = numpy_one_hot(8, num_classes)
class_dictionary['boat'] = numpy_one_hot(9, num_classes)
class_dictionary['bus'] = numpy_one_hot(10, num_classes)
class_dictionary['car'] = numpy_one_hot(11, num_classes)
class_dictionary['motorbike'] = numpy_one_hot(12, num_classes)
class_dictionary['train'] = numpy_one_hot(13, num_classes)
class_dictionary['bottle'] = numpy_one_hot(14, num_classes)
class_dictionary['chair'] = numpy_one_hot(15, num_classes)
class_dictionary['diningtable'] = numpy_one_hot(16, num_classes)
class_dictionary['pottedplant'] = numpy_one_hot(17, num_classes)
class_dictionary['sofa'] = numpy_one_hot(18, num_classes)
class_dictionary['tvmonitor'] = numpy_one_hot(19, num_classes)
图1是生成数据时的输出。
图1标签保留ture和fales是因为它们两个就是1和0[3]。
}
结语:
{
3天更新一篇完整的博客对我来说有忙不过来,所以这次干脆就把这次的目标实现分开,也算尽量3天一更。
因为想传到国外的网站,使用最近开始试着用英文注释。但是由于之前的注释都是中文,我又不想都改了,所有代码中有中文注释也有英文注释。
[1]Progressive Neural Architecture Search (https://arxiv.org/abs/1712.00559)
[2]https://github.com/chenxi116/PNASNet.TF
[3]https://www.cnblogs.com/qiaojushuang/p/8524811.html
}