In the previous section we walked through the pix2code pipeline end to end, so by now you should be able to run it yourself.
Before discussing algorithmic improvements to pix2code, there are a few details of the training and inference process worth examining more closely.
Last time we got as far as create_binary_representation, which converts each token into a one-hot encoding:
def create_binary_representation(self):
    # dict.iteritems() was removed in Python 3, so pick the right iterator.
    if sys.version_info >= (3,):
        items = self.vocabulary.items()
    else:
        items = self.vocabulary.iteritems()
    # Each token gets a vector of length self.size with a single 1 at the
    # token's index: its one-hot encoding.
    for key, value in items:
        binary = np.zeros(self.size)
        binary[value] = 1
        self.binary_vocabulary[key] = binary
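create_binary_representation depends on self.vocabulary (token to index) and on the reverse map self.token_lookup (index to token) that we will use later to decode predictions. Both are filled in as tokens are first encountered; in the pix2code Vocabulary class the bookkeeping looks essentially like this:

def append(self, token):
    # Assign the next free index to an unseen token and record the
    # reverse mapping so predictions can be decoded back to tokens.
    if token not in self.vocabulary:
        self.vocabulary[token] = self.size
        self.token_lookup[self.size] = token
        self.size += 1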
The resulting one-hot encodings are shown below; there are 20 tokens, so each token is a 20-dimensional vector:
<START>-> 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
<END>-> 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
(space)-> 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
stack-> 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
{-> 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
(newline)-> 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
row-> 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
switch-> 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
}-> 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
label-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
,-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.
btn-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
radio-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
footer-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.
btn-home-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.
btn-notifications-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.
check-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.
slider-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.
btn-dashboard-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.
btn-search-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.
The table above, stored in the words.vocab file under the bin directory, is our vocabulary. The model predicts a probability for each of these 20 entries, and we then look the winning index back up in this table to recover the token.
The related constants are defined as follows:
START_TOKEN = "<START>"  # marks the start of a token sequence
END_TOKEN = "<END>"      # marks the end of a token sequence
PLACEHOLDER = " "        # pads the context window
SEPARATOR = '->'         # separates token and encoding in words.vocab
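SEPARATOR is what splits each line of words.vocab into a token and its encoding. As a quick illustration of the file format, a hypothetical helper for reading one line back (parse_vocab_line is our own name, not part of pix2code) could look like:

import numpy as np

def parse_vocab_line(line, separator='->'):
    # Split "stack-> 0., 0., 0., 1., ..." into the token and its one-hot vector.
    token, encoding = line.rstrip().split(separator, 1)
    vector = np.array([float(x) for x in encoding.split(',')])
    return token, vector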
With one-hot encoding explained, the training logic should be fairly clear by now. How to run inference with the trained model may still raise questions, so let's walk through the inference process.
First we load the model and read the input image:
# meta_dataset.npy stores the input shape and output size used during training.
# Note: newer NumPy versions may need np.load(..., allow_pickle=True) here.
meta_dataset = np.load("{}/meta_dataset.npy".format(trained_weights_path))
input_shape = meta_dataset[0]
output_size = meta_dataset[1]

# Rebuild the network with the same dimensions and load the trained weights.
model = pix2code(input_shape, output_size, trained_weights_path)
model.load(trained_model_name)

sampler = Sampler(trained_weights_path, input_shape, output_size, CONTEXT_LENGTH)

# Strip the extension from the file name and preprocess the input screenshot.
file_name = basename(input_path)[:basename(input_path).find(".")]
evaluation_img = Utils.get_preprocessed_img(input_path, IMAGE_SIZE)
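For context, sample.py in the pix2code repo takes the weights directory, model name, input image, output path, and search method as command-line arguments, so an invocation looks roughly like this (paths are illustrative):

python3 ./sample.py ../bin pix2code ../datasets/android/eval_set/00EAE181-AF3B-4CBD-928A-561FF6F4345F.png ../code greedy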
Next comes the core search logic. There are two search modes: greedy, and beam search with a specified width:
if search_method == "greedy":
    result, _ = sampler.predict_greedy(model, np.array([evaluation_img]))
    print("Result greedy: {}".format(result))
else:
    beam_width = int(search_method)
    print("Search with beam width: {}".format(beam_width))
    result, _ = sampler.predict_beam_search(model, np.array([evaluation_img]), beam_width=beam_width)
    print("Result beam: {}".format(result))
We will use datasets/android/eval_set/00EAE181-AF3B-4CBD-928A-561FF6F4345F.png as an example to illustrate the search process.
Let's look at how the greedy mode works first:
def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=True):
    # Fill the context window with placeholder tokens, then append the start token.
    current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
    current_context.append(self.voc.vocabulary[START_TOKEN])
    if require_sparse_label:
        current_context = Utils.sparsify(current_context, self.output_size)
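For reference, Utils.sparsify just turns a list of token indices into a matrix of one-hot rows; in the pix2code source it is essentially:

@staticmethod
def sparsify(label_vector, output_size):
    # One one-hot row of length output_size per token index.
    sparse_vector = []
    for label in label_vector:
        sparse_label = np.zeros(output_size)
        sparse_label[label] = 1
        sparse_vector.append(sparse_label)
    return np.array(sparse_vector)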
current_context is first filled with the placeholder token and then one-hot encoded. The result looks like this (the placeholder is index 2, the start token index 0):
[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 ... (45 more identical placeholder rows) ...
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
That is 47 placeholder rows followed by a single row for the start token, 48 rows in total, matching the context length.
Then we continue reading the code:
    predictions = START_TOKEN
    out_probas = []

    for i in range(0, sequence_length):
        if verbose:
            print("predicting {}/{}...".format(i, sequence_length))
If you want to watch the progress, turn the verbose switch on. The output then looks like this:
predicting 0/150...
stack
predicting 1/150...
stack{
predicting 2/150...
stack{
predicting 3/150...
stack{
row
predicting 4/150...
stack{
row{
predicting 5/150...
stack{
row{
predicting 6/150...
stack{
row{
btn-dashboard
(The steps whose printed string looks unchanged are the model predicting whitespace or newline tokens, which are invisible in the output.) Now we reach the prediction step itself:
        probas = model.predict(input_img, np.array([current_context]))
        prediction = np.argmax(probas)
        out_probas.append(probas)
probas holds the model's predicted probabilities over the 20 tokens, given the current image and context.
At the first step, for example, it might look like this:
[3.5372774e-10 1.5484280e-15 3.7052250e-14 9.9999964e-01 4.1679780e-07
2.5872908e-14 3.7647421e-15 4.0136690e-14 7.4600746e-14 2.8354120e-13
3.3076447e-13 2.4662578e-15 2.2618040e-13 2.2254247e-14 4.2132672e-10
2.3889449e-10 1.8945574e-13 1.2084879e-13 1.7556612e-10 3.4896117e-10]
Here index 3 (counting from 0) carries 99.99% of the probability mass, so prediction comes out as 3, which the vocabulary maps to stack.
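To make the argmax step concrete, here it is on exactly the vector printed above:

import numpy as np

probas = np.array([3.5372774e-10, 1.5484280e-15, 3.7052250e-14, 9.9999964e-01,
                   4.1679780e-07, 2.5872908e-14, 3.7647421e-15, 4.0136690e-14,
                   7.4600746e-14, 2.8354120e-13, 3.3076447e-13, 2.4662578e-15,
                   2.2618040e-13, 2.2254247e-14, 4.2132672e-10, 2.3889449e-10,
                   1.8945574e-13, 1.2084879e-13, 1.7556612e-10, 3.4896117e-10])
print(np.argmax(probas))  # -> 3, the index of the "stack" token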
Next we fold the new prediction into current_context, since the RNN predicts from a sequence. We build a new_context by shifting everything in current_context up one position: row 0 is dropped, and rows 1 through the last are kept:
        new_context = []
        for j in range(1, self.context_length):
            new_context.append(current_context[j])
Then we append the freshly predicted token as the last row:
        if require_sparse_label:
            sparse_label = np.zeros(self.output_size)
            sparse_label[prediction] = 1
            new_context.append(sparse_label)
        else:
            new_context.append(prediction)
        current_context = new_context
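As an aside, in the sparse-label case the whole shift-and-append could be collapsed into a single NumPy call; a sketch of the equivalent:

# Drop the oldest row and append the new one-hot row at the end.
current_context = np.vstack([np.array(current_context)[1:], sparse_label])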
Next we look up the word that prediction maps to. Remember the vocabulary table from the start of this section? That is exactly the table being consulted:
        predictions += self.voc.token_lookup[prediction]
Finally, we check whether the predicted token is END_TOKEN; if so, the prediction loop ends:
        if self.voc.token_lookup[prediction] == END_TOKEN:
            break

    return predictions, out_probas
So far we have used the greedy method: at every step we simply take the token with the highest probability. But the real goal is not to maximize each individual step; it is to maximize the probability of the whole sequence, and the greedy choice does not necessarily achieve that. So we introduce a new method: beam search.
Where greedy keeps only the single most likely token at each step, beam search keeps a configurable number of candidates, the beam width, alive at every step.
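A toy example of how greedy loses (made-up numbers): if the model assigns 0.5 to token A and 0.4 to token B at step 1, greedy commits to A. But if the best continuation after A only has probability 0.3 while the best continuation after B has 0.8, the sequence probabilities are 0.5 × 0.3 = 0.15 versus 0.4 × 0.8 = 0.32, and greedy picked the worse path. A beam of width 2 keeps both A and B alive long enough to notice:

# Toy numbers: compare greedy's sequence score with the true best sequence.
step1 = {"A": 0.5, "B": 0.4}
best_continuation = {"A": 0.3, "B": 0.8}

greedy_first = max(step1, key=step1.get)                              # "A"
greedy_score = step1[greedy_first] * best_continuation[greedy_first]  # 0.15

best_score = max(step1[t] * best_continuation[t] for t in step1)      # 0.32
print(greedy_score < best_score)  # True: greedy is not optimal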
The first part of predict_beam_search is identical to the greedy version, except for its last line, which creates a BeamSearch object:
def predict_beam_search(self, model, input_img, beam_width=3, require_sparse_label=True, sequence_length=150):
    predictions = START_TOKEN
    out_probas = []

    # Same context initialization as in the greedy case.
    current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
    current_context.append(self.voc.vocabulary[START_TOKEN])
    if require_sparse_label:
        current_context = Utils.sparsify(current_context, self.output_size)

    beam = BeamSearch(beam_width=beam_width)
Then we call the recursive routine that builds the beam search tree:
    self.recursive_beam_search(model, input_img, current_context, beam, beam.root, sequence_length)
Building the tree starts the same way: read out the predicted probabilities. The difference is that greedy only needed the argmax, whereas here we keep the entire distribution:
def recursive_beam_search(self, model, input_img, current_context, beam, current_node, sequence_length):
    probas = model.predict(input_img, np.array([current_context]))

    # Keep (token index, its probability, the full distribution) for every token.
    predictions = []
    for i in range(0, len(probas)):
        predictions.append((i, probas[i], probas))
Next we create the tree nodes:
    nodes = []
    for i in range(0, len(predictions)):
        prediction = predictions[i][0]
        score = predictions[i][1]
        output_probas = predictions[i][2]
        nodes.append(Node(prediction, score, output_probas))
    beam.add_nodes(current_node, nodes)
Once the nodes are in place, we prune. The recursion ends either when the remaining sequence length reaches 1 (the maximum depth has been hit) or when the root's highest-scoring child corresponds to END_TOKEN:
    if beam.is_valid():
        beam.prune_leaves()
        if sequence_length == 1 or self.voc.token_lookup[beam.root.max_child().key] == END_TOKEN:
            return
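beam.prune_leaves() is where the beam width takes effect: only the beam_width leaves with the best cumulative path scores survive. Conceptually it does something like the following (a sketch of the idea, not pix2code's actual implementation; cumulative_score is a hypothetical attribute):

import heapq

def prune(leaves, beam_width):
    # A leaf's cumulative score is the product of probabilities along its
    # root-to-leaf path; keep the beam_width best leaves, drop the rest.
    return heapq.nlargest(beam_width, leaves, key=lambda leaf: leaf.cumulative_score)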
Then we grow a subtree from each surviving leaf. Note again that the RNN's input is a sequence, so the shifted context has to be prepared before recursing into the subtree:
    for node in beam.get_leaves():
        prediction = node.key

        # Shift the context window and append this leaf's token as a one-hot
        # row, exactly as in the greedy case.
        new_context = []
        for j in range(1, self.context_length):
            new_context.append(current_context[j])
        sparse_label = np.zeros(self.output_size)
        sparse_label[prediction] = 1
        new_context.append(sparse_label)

        self.recursive_beam_search(model, input_img, new_context, beam, node, sequence_length - 1)
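After the recursion returns, predict_beam_search still has to turn the tree into a token string. Conceptually that is a walk from the root that always descends into the best child, decoding each key through the vocabulary; a sketch of the idea (not necessarily pix2code's exact code):

    # Follow the highest-scoring child at each level and decode its token.
    current_node = beam.root
    while current_node.max_child() is not None:
        current_node = current_node.max_child()
        predictions += self.voc.token_lookup[current_node.key]
    return predictions, out_probas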
Let's again use …/datasets/android/eval_set/00EAE181-AF3B-4CBD-928A-561FF6F4345F.png as our example.
The ground truth is:
stack {
row {
btn, switch
}
row {
check
}
}
footer {
btn-dashboard, btn-search
}
At the first level, the tree output is as follows:
0 key= 3 ,value= 0.9999996423721313 ,level= 1
parent= root
0 key= 4 ,value= 4.1679780338199635e-07 ,level= 1
parent= root
0 key= 14 ,value= 4.2132672350980727e-10 ,level= 1
parent= root
So the three most likely tokens at the first level are:
stack
{
switch
Naturally, stack wins by an overwhelming margin.
At the second level, the results are:
0 key= 4 ,value= 0.9999775886614515 ,level= 2
parent= 3
0 key= 5 ,value= 1.4085913299997205e-05 ,level= 2
parent= 3
0 key= 18 ,value= 5.58778656594962e-06 ,level= 2
parent= 3
All three nodes are still children of the stack node, because its weight is so dominant.
The second-level beam search candidates are:
stack{
stack <newline>
stack btn-notifications
At the third level, beam search's advantage over greedy appears: one of the three top candidates is not on the greedy path (its parent is the level-2 newline node, key 5, rather than the { node, key 4):
0 key= 5 ,value= 0.9999684097518723 ,level= 3
parent= 4
0 key= 11 ,value= 6.167459579019292e-06 ,level= 3
parent= 4
0 key= 5 ,value= 1.4085780645431549e-05 ,level= 3
parent= 5
Decoded into tokens, these are:
stack { <newline>
stack { ,
stack <newline> <newline>
The fourth level:
0 key= 6 ,value= 0.9999655488193034 ,level= 4
parent= 5
0 key= 6 ,value= 6.167428699843352e-06 ,level= 4
parent= 11
0 key= 6 ,value= 1.4085740345689864e-05 ,level= 4
parent= 5
Decoded:
stack { <newline> row
stack { , row
stack <newline> <newline> row
Finally, let's look at level 22. Candidate 1, with probability 14.67%:
0 key= 5 ,value= 0.14667593772000911 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{slider
switch,btn-notifications
}
Candidate 2, with probability 8.1%:
0 key= 5 ,value= 0.08098176016926652 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{
switch,btn-notifications
}
Candidate 3, with probability 5.4%:
0 key= 5 ,value= 0.05431855658408171 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{
slider
,btn-notifications
}}
Compare this with the greedy result:
Result greedy: stack{
row{
btn-dashboard,footer
}
row{
slider
}
}
check{
switch,btn-notifications
}
Below are the results once the training loss has dropped to around 0.03.
The ground truth is:
stack {
row {
btn, switch
}
row {
check
}
}
footer {
btn-dashboard, btn-search
}
The greedy result now matches it exactly (up to whitespace):
Result greedy: stack{
row{
btn,switch
}
row{
check
}
}
footer{
btn-dashboard,btn-search
}
And beam search with width 3 agrees:
Result beam: stack{
row{
btn,switch
}
row{
check
}
}
footer{
btn-dashboard,btn-search
}