Musings on Front-End Intelligence (3): A Walkthrough of pix2code Inference

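In the previous installment we walked through the whole pix2code pipeline, so you should be able to run it by now.
Before we talk about improving the pix2code algorithm, there are still a few details of training and inference that deserve a closer look.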

One-hot encoding

Last time we mentioned create_binary_representation, which converts each token into a one-hot encoding.

    def create_binary_representation(self):
        if sys.version_info >= (3,):
            items = self.vocabulary.items()
        else:
            items = self.vocabulary.iteritems()
        for key, value in items:
            binary = np.zeros(self.size)
            binary[value] = 1
            self.binary_vocabulary[key] = binary

The generated one-hot encoding is shown below. There are 20 tokens, so each token is a 20-dimensional vector:

<START>-> 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
<END>-> 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
 -> 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
stack-> 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
{-> 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
-> 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
row-> 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
switch-> 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
}-> 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
label-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
,-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.
btn-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.
radio-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.
footer-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.
btn-home-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.
btn-notifications-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.
check-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.
slider-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.
btn-dashboard-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.
btn-search-> 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.

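The table above is stored in the words.vocab file in the bin directory; it is our vocabulary. The model predicts a probability for each of these 20 entries, and we then look the winning index up in the vocabulary to recover the actual token.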
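As a rough illustration of this round trip, here is a minimal sketch that builds one-hot vectors for a small vocabulary and maps a probability vector back to a token with argmax. The five-token list and the helper names one_hot and decode are assumptions made for this example; they are not taken from the pix2code source, and the real vocabulary has 20 entries.

```python
import numpy as np

# Hypothetical five-token subset of the vocabulary, for illustration only.
tokens = ["stack", "{", "row", "switch", "}"]
vocabulary = {token: index for index, token in enumerate(tokens)}
token_lookup = {index: token for token, index in vocabulary.items()}


def one_hot(index, size):
    """Return a vector of zeros with a single 1 at position `index`."""
    vector = np.zeros(size)
    vector[index] = 1
    return vector


# Token -> one-hot vector, the same idea as binary_vocabulary.
binary_vocabulary = {t: one_hot(i, len(tokens)) for t, i in vocabulary.items()}


def decode(probabilities):
    """Map a probability (or one-hot) vector back to its token."""
    return token_lookup[int(np.argmax(probabilities))]


# A made-up prediction in which index 2 ("row") is the most likely token.
example = decode(np.array([0.05, 0.10, 0.70, 0.10, 0.05]))
```

Decoding is just the inverse lookup: argmax picks the index, and the index-to-token table turns it back into text.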

The relevant constants are defined as follows:

START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "
SEPARATOR = '->'

Inference

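With one-hot encoding explained, the training logic should be fairly clear. How the trained model is then used for inference may still be puzzling, so let's go through that process next.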

First we load the model and read the image:

meta_dataset = np.load("{}/meta_dataset.npy".format(trained_weights_path))
input_shape = meta_dataset[0]
output_size = meta_dataset[1]

model = pix2code(input_shape, output_size, trained_weights_path)
model.load(trained_model_name)

sampler = Sampler(trained_weights_path, input_shape, output_size, CONTEXT_LENGTH)

file_name = basename(input_path)[:basename(input_path).find(".")]
evaluation_img = Utils.get_preprocessed_img(input_path, IMAGE_SIZE)

Next comes the core search logic. There are two search modes: greedy, and beam search with a specified beam width:

if search_method == "greedy":
    result, _ = sampler.predict_greedy(model, np.array([evaluation_img]))
    print("Result greedy: {}".format(result))
else:
    beam_width = int(search_method)
    print("Search with beam width: {}".format(beam_width))
    result, _ = sampler.predict_beam_search(model, np.array([evaluation_img]), beam_width=beam_width)
    print("Result beam: {}".format(result))

Greedy prediction

We use datasets/android/eval_set/00EAE181-AF3B-4CBD-928A-561FF6F4345F.png as the example to illustrate the search process.

Let's first look at how greedy mode works:

    def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=True):
        current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
        current_context.append(self.voc.vocabulary[START_TOKEN])
        if require_sparse_label:
            current_context = Utils.sparsify(current_context, self.output_size)

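current_context is first filled with context_length - 1 placeholders followed by START_TOKEN, and is then converted to one-hot encoding. With the 48 rows printed here, that gives 47 placeholder rows and one final START_TOKEN row: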

[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
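The matrix above can be reproduced with a few lines of NumPy. This is only a sketch: to_one_hot_rows is a hypothetical stand-in for Utils.sparsify (whose source is not shown here), and the index 2 for the placeholder, the index 0 for the start token, and the context length of 48 are read off the printed rows rather than taken from the code.

```python
import numpy as np

CONTEXT_LENGTH = 48      # assumed: number of rows printed above
OUTPUT_SIZE = 20         # vocabulary size
PLACEHOLDER_INDEX = 2    # assumed: position of the 1 in the first 47 rows
START_INDEX = 0          # assumed: position of the 1 in the last row


def to_one_hot_rows(indices, output_size):
    """Turn a list of token indices into a (len(indices), output_size) matrix."""
    matrix = np.zeros((len(indices), output_size))
    matrix[np.arange(len(indices)), indices] = 1
    return matrix


def initial_context():
    """Placeholders everywhere except the last slot, which holds the start token."""
    indices = [PLACEHOLDER_INDEX] * (CONTEXT_LENGTH - 1) + [START_INDEX]
    return to_one_hot_rows(indices, OUTPUT_SIZE)


context = initial_context()
```

Every row contains exactly one 1, so the matrix is just a fixed-length window of token indices written in one-hot form.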

Then we continue:

        predictions = START_TOKEN
        out_probas = []

        for i in range(0, sequence_length):
            if verbose:
                print("predicting {}/{}...".format(i, sequence_length))

If you want to watch the progress, turn on the verbose flag. With it on, the output looks like this:

predicting 0/150...
stack
predicting 1/150...
stack{
predicting 2/150...
stack{

predicting 3/150...
stack{
row
predicting 4/150...
stack{
row{
predicting 5/150...
stack{
row{

predicting 6/150...
stack{
row{
btn-dashboard

Now we get to the prediction itself:

            probas = model.predict(input_img, np.array([current_context]))
            prediction = np.argmax(probas)
            out_probas.append(probas)

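probas holds the probabilities that the trained model assigns to each of the 20 tokens given the current context.
For example, suppose the first step produces this: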

[3.5372774e-10 1.5484280e-15 3.7052250e-14 9.9999964e-01 4.1679780e-07
 2.5872908e-14 3.7647421e-15 4.0136690e-14 7.4600746e-14 2.8354120e-13
 3.3076447e-13 2.4662578e-15 2.2618040e-13 2.2254247e-14 4.2132672e-10
 2.3889449e-10 1.8945574e-13 1.2084879e-13 1.7556612e-10 3.4896117e-10]

Here index 3 (counting from 0) has a probability above 99.99%, so prediction is 3, which the vocabulary above maps to stack.
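To make this step concrete, the sketch below feeds the probability vector printed above into np.argmax. It also ranks the three most likely indices, which is what a beam search of width 3 would keep instead of the single best one. The variable names are chosen for this example only.

```python
import numpy as np

# The first-step probabilities printed above, copied verbatim.
probas = np.array([
    3.5372774e-10, 1.5484280e-15, 3.7052250e-14, 9.9999964e-01, 4.1679780e-07,
    2.5872908e-14, 3.7647421e-15, 4.0136690e-14, 7.4600746e-14, 2.8354120e-13,
    3.3076447e-13, 2.4662578e-15, 2.2618040e-13, 2.2254247e-14, 4.2132672e-10,
    2.3889449e-10, 1.8945574e-13, 1.2084879e-13, 1.7556612e-10, 3.4896117e-10,
])

# Greedy choice: the single most likely index.
prediction = int(np.argmax(probas))           # 3

# Indices sorted from most to least likely, truncated to the top three.
top3 = np.argsort(probas)[::-1][:3].tolist()  # [3, 4, 14]
```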

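Next we append the new prediction to current_context, because the RNN predicts from a sequence. We first create new_context and shift every row of current_context up by one: row 0 is dropped, and rows 1 through the last one are kept: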

            new_context = []
            for j in range(1, self.context_length):
                new_context.append(current_context[j])

Then the token we just predicted is appended as the last row:

            if require_sparse_label:
                sparse_label = np.zeros(self.output_size)
                sparse_label[prediction] = 1
                new_context.append(sparse_label)
            else:
                new_context.append(prediction)

            current_context = new_context
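The drop-the-first-row, append-at-the-end update is a sliding window. A compact way to express the same idea with NumPy slicing is sketched below; shift_context is a hypothetical helper written for this article, not a function from pix2code, and it covers only the one-hot (require_sparse_label) case.

```python
import numpy as np


def shift_context(current_context, prediction, output_size):
    """Drop the oldest row and append the one-hot row for `prediction`."""
    sparse_label = np.zeros(output_size)
    sparse_label[prediction] = 1
    # Rows 1..end move up by one; the new row becomes the last row.
    return np.vstack([np.asarray(current_context)[1:], sparse_label])


# Tiny example: a window of 3 rows over a 4-token vocabulary.
window = np.array([[0, 0, 1, 0],
                   [0, 0, 1, 0],
                   [1, 0, 0, 0]], dtype=float)
shifted = shift_context(window, prediction=3, output_size=4)
```

The window length never changes, so the model always sees an input of the same shape.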

Next we look up the token that prediction corresponds to. This is the vocabulary introduced at the start of this article:

            predictions += self.voc.token_lookup[prediction]

Finally, we check whether the predicted token is END_TOKEN; if so, the prediction loop ends.

            if self.voc.token_lookup[prediction] == END_TOKEN:
                break

        return predictions, out_probas
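Putting the pieces together, the whole greedy loop can be exercised without a trained network. The sketch below is an assumption-laden toy: the seven-token vocabulary, the scripted fake_predict function standing in for model.predict, and the short context length are all invented for illustration. Unlike the original method, it leaves the end token out of the returned list.

```python
import numpy as np

# Hypothetical toy vocabulary; indices follow list order.
TOKENS = ["<START>", "<END>", " ", "stack", "{", "\n", "}"]
# Scripted "model": after token index k, token index SCRIPT[k] is most likely.
SCRIPT = {0: 3, 3: 4, 4: 6, 6: 1}


def fake_predict(context):
    """Stand-in for model.predict: returns a next-token distribution."""
    last = int(np.argmax(context[-1]))
    probas = np.full(len(TOKENS), 0.01)
    probas[SCRIPT[last]] = 1.0 - 0.01 * (len(TOKENS) - 1)
    return probas


def greedy_decode(predict, context_length=4, max_steps=10):
    size = len(TOKENS)
    # Placeholders (index 2) followed by the start token (index 0).
    context = np.eye(size)[[2] * (context_length - 1) + [0]]
    output = []
    for _ in range(max_steps):
        prediction = int(np.argmax(predict(context)))
        # Slide the window: drop the oldest row, append the new one-hot row.
        context = np.vstack([context[1:], np.eye(size)[prediction]])
        if TOKENS[prediction] == "<END>":
            break
        output.append(TOKENS[prediction])
    return output


decoded = greedy_decode(fake_predict)
```

With this script the loop emits stack, {, and } before hitting the end token, well short of max_steps.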

Beam search prediction

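The previous section covered the greedy method, which takes the highest-probability token at every step. Our real goal, however, is not to maximize each individual step but to maximize the probability of the whole sequence, that is, the product of the per-step probabilities. The sequence found greedily is not necessarily the most probable one, so we introduce another method: beam search.
Greedy keeps only the single most likely token at each step, whereas beam search lets you set a width and keeps that many candidates at each step.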
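A two-step toy example shows why the greedy path can miss the most probable sequence. The distributions FIRST and SECOND below are made up for this illustration and have nothing to do with the pix2code model.

```python
from itertools import product

# Made-up distributions: P(first token) and P(second token | first token).
FIRST = {"A": 0.6, "B": 0.4}
SECOND = {"A": {"x": 0.55, "y": 0.45}, "B": {"x": 0.9, "y": 0.1}}


def sequence_probability(first, second):
    """Probability of the full two-token sequence."""
    return FIRST[first] * SECOND[first][second]


def greedy_sequence():
    """Pick the locally best token at each step."""
    first = max(FIRST, key=FIRST.get)
    second = max(SECOND[first], key=SECOND[first].get)
    return first, second


def best_sequence():
    """Exhaustively score every sequence and return the most probable one."""
    return max(product(FIRST, ["x", "y"]), key=lambda s: sequence_probability(*s))


greedy = greedy_sequence()   # ("A", "x"), probability 0.6 * 0.55
best = best_sequence()       # ("B", "x"), probability 0.4 * 0.9
```

Greedy commits to A because 0.6 beats 0.4, but the sequence starting with B ends up more probable overall. A beam of width 2 would keep both A and B alive after the first step and so would find it.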

The first part of beam search is the same as in greedy mode, except that the last statement creates a BeamSearch object:

    def predict_beam_search(self, model, input_img, beam_width=3, require_sparse_label=True, sequence_length=150):
        predictions = START_TOKEN
        out_probas = []

        current_context = [self.voc.vocabulary[PLACEHOLDER]] * (self.context_length - 1)
        current_context.append(self.voc.vocabulary[START_TOKEN])
        if require_sparse_label:
            current_context = Utils.sparsify(current_context, self.output_size)

        beam = BeamSearch(beam_width=beam_width)

Next we call the recursive logic that builds the beam search tree:

        self.recursive_beam_search(model, input_img, current_context, beam, beam.root, sequence_length)

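Tree building starts, as before, by getting the probabilities predicted by the trained model. In greedy mode the argmax alone was enough; this time we keep all of them: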

    def recursive_beam_search(self, model, input_img, current_context, beam, current_node, sequence_length):
        probas = model.predict(input_img, np.array([current_context]))

        predictions = []
        for i in range(0, len(probas)):
            predictions.append((i, probas[i], probas))

Next we create the tree nodes:

        nodes = []
        for i in range(0, len(predictions)):
            prediction = predictions[i][0]
            score = predictions[i][1]
            output_probas = predictions[i][2]
            nodes.append(Node(prediction, score, output_probas))

        beam.add_nodes(current_node, nodes)

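Once the nodes are created, the tree is pruned. After pruning, the recursion stops if the remaining sequence length has reached 1, or if the highest-scoring child of the root corresponds to END_TOKEN.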

        if beam.is_valid():
            beam.prune_leaves()
            if sequence_length == 1 or self.voc.token_lookup[beam.root.max_child().key] == END_TOKEN:
                return

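Then we build a subtree for each leaf node. Note again that the RNN takes a sequence as input, so the sequence has to be prepared before descending into the subtree: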

            for node in beam.get_leaves():
                prediction = node.key

                new_context = []
                for j in range(1, self.context_length):
                    new_context.append(current_context[j])
                sparse_label = np.zeros(self.output_size)
                sparse_label[prediction] = 1
                new_context.append(sparse_label)

                self.recursive_beam_search(model, input_img, new_context, beam, node, sequence_length - 1)
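The BeamSearch and Node classes are not shown in this article, so the following is only a sketch of what one expand-and-prune round might look like. It assumes that a node's score is its parent's score multiplied by the conditional probability of the new token; expand_and_prune and toy_probas are hypothetical names, and the flat list of leaves replaces the tree for brevity.

```python
import heapq


def expand_and_prune(leaves, next_probas, beam_width):
    """One beam search round.

    leaves:      list of (sequence_tuple, score) pairs kept so far
    next_probas: function mapping a sequence to a list of next-token probabilities
    Returns the beam_width highest-scoring extended sequences, best first.
    """
    candidates = []
    for sequence, score in leaves:
        for token, proba in enumerate(next_probas(sequence)):
            # Assumed scoring rule: cumulative product along the path.
            candidates.append((sequence + (token,), score * proba))
    return heapq.nlargest(beam_width, candidates, key=lambda c: c[1])


def toy_probas(sequence):
    """Made-up two-token model used only to exercise the function."""
    if not sequence:
        return [0.6, 0.4]
    return [0.55, 0.45] if sequence[-1] == 0 else [0.9, 0.1]


level1 = expand_and_prune([((), 1.0)], toy_probas, beam_width=2)
level2 = expand_and_prune(level1, toy_probas, beam_width=2)
```

In this toy run the sequence (1, 0) overtakes (0, 0) at the second level, even though token 0 was the better first choice.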

We again use …/datasets/android/eval_set/00EAE181-AF3B-4CBD-928A-561FF6F4345F.png as the example.

The ground truth is:

stack {
row {
btn, switch
}
row {
check
}
}
footer {
btn-dashboard, btn-search
}

At level 1, the tree output is:

 0 key= 3 ,value= 0.9999996423721313 ,level= 1
parent= root
 0 key= 4 ,value= 4.1679780338199635e-07 ,level= 1
parent= root
 0 key= 14 ,value= 4.2132672350980727e-10 ,level= 1
parent= root

The three most likely sequences at level 1 are:
stack
{
switch

Unsurprisingly, stack wins by an overwhelming margin.
At level 2, the result is:

 0 key= 4 ,value= 0.9999775886614515 ,level= 2
parent= 3
 0 key= 5 ,value= 1.4085913299997205e-05 ,level= 2
parent= 3
 0 key= 18 ,value= 5.58778656594962e-06 ,level= 2
parent= 3

All three nodes still descend from the stack node, because its score is so dominant.
The beam search candidates at level 2 are:
stack{
stack newline
stack btn-notifications
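The printed values appear to be cumulative: a level-2 value looks like the level-1 value of its parent multiplied by the conditional probability of the new token. Under that assumption, which the tree code shown here does not confirm, the conditional probability of key 4 following key 3 can be recovered by division:

```python
# Values copied from the tree output above.
level1_key3 = 0.9999996423721313   # key=3 at level 1
level2_key4 = 0.9999775886614515   # key=4 at level 2, parent=3

# Assumed relation: level2 = level1 * P(key 4 | key 3).
conditional = level2_key4 / level1_key3   # roughly 0.99998
```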

At level 3 the advantage of beam search over greedy starts to show: one of the top three candidates is not on the greedy path:

 0 key= 5 ,value= 0.9999684097518723 ,level= 3
parent= 4
 0 key= 11 ,value= 6.167459579019292e-06 ,level= 3
parent= 4
 0 key= 5 ,value= 1.4085780645431549e-05 ,level= 3
parent= 5

Translated into tokens:
stack { newline
stack { ,
stack newline newline

Level 4:

 0 key= 6 ,value= 0.9999655488193034 ,level= 4
parent= 5
 0 key= 6 ,value= 6.167428699843352e-06 ,level= 4
parent= 11
 0 key= 6 ,value= 1.4085740345689864e-05 ,level= 4
parent= 5

Translated into tokens:
stack { newline row
stack { , row
stack newline newline row

Finally, let's look at level 22. Case 1, with a probability of 14.67%:

 0 key= 5 ,value= 0.14667593772000911 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{slider
switch,btn-notifications
}

Case 2, with a probability of 8.1%:

 0 key= 5 ,value= 0.08098176016926652 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{
	


switch,btn-notifications
}

Case 3, with a probability of 5.4%:

 0 key= 5 ,value= 0.05431855658408171 ,level= 22
stack{
row{
btn-dashboard,footer
}
row{
slider


,btn-notifications
}}

For comparison, here is the greedy result:

Result greedy: stack{
row{
btn-dashboard,footer
}
row{
slider
}
}
check{
switch,btn-notifications
}

Comparing the results

Below are the results when the training loss is around 0.03.
The ground truth is:

stack {
row {
btn, switch
}
row {
check
}
}
footer {
btn-dashboard, btn-search
}

Greedy result:

Result greedy: stack{
row{
btn,switch
}
row{
check
}
}
footer{
btn-dashboard,btn-search
}

Result of beam search with width 3:

Result beam: stack{
row{
btn,switch
}
row{
check
}
}
footer{
btn-dashboard,btn-search
}
