Get To The Point: Summarization with Pointer-Generator Networks代码分析

代码有三种选项:basic attention model / point generator model / pointer-generator + coverage model

改进都在decoder端,在model.py里是通过self._add_decoder实现的,在这个函数里定义了cell和prev_coverage(只有test的时候才有),再调用封装好的在tf基础上改进的attention_decoder.py函数。

该函数接口为:

decoder_inputs: [dec_step, batch_size, emb_dim].
initial_state: [batch_size x cell.state_size].
encoder_states: [batch_size x attn_length x attn_size].
enc_padding_mask: [batch_size x attn_length]
initial_state_attention: We set this to False for train/eval mode (because we call attention_decoder once for all decoder steps) and True for decode mode (because we call attention_decoder once for each decoder step).
prev_coverage: This is only not None in decode mode when using coverage.

train

定义encoder_features

Get To The Point: Summarization with Pointer-Generator Networks代码分析_第1张图片

循环处理输入的decoder_inputs(dec_step个cell):

  1. 把input和context_vector结合在一起(time step 0时context_vector为零向量)
  2. cell_output, state = cell(x, state)
  3. 调用子函数attention获得下一个time step的context_vector, attn_dist, coverage。具体获得方法为:直接转置decoder_state得到decoder_features,在time step为0时 计算
    a即为attn_dist

    在之后计算

    coverage是attn_dist的加和

    context vector是attn_dist*encoder_states
  4. 计算
  5. cell_outputcontext_vector联合起来得到当前dec step的输出

最终该封装attention_decoder的输出是

state #最后一个cell的state
attn_dists所有attn_dist的list
p_gens同上
coverage 最后一个cell的coverage 

对于所有outputs(decoder_outputs),线性变换得到在词汇表上的分布。

根据
得到最终的vocab分布。

test

有一个placeholder self.prev_coverage [batch_size,]
decoder部分和之前不同的是,在循环开始前,用prev_coverage预先计算context_vector, _, coverage = attention(initial_state, coverage)
在循环内部不再更新coverage,其他部分没有区别。
在得到最终vocab分布之后,不是直接取得分最高的词,而是用beam search。

Beam search

在语句decoder.decode()中,调用beam_search.run_beam_search得到best_hyp。
不用想也知道是靠多重循环构成的。也用了之前的model里各种self.xxx,但是decoder的单词是一个个的,而不是一下生成整句。

  1. run_encoder得到enc_states和dec_initial_state。
  2. 定义存储每一个hypothesis的结构hyps list,包含四个相同的初始hyp(最开始只有一个)
  3. 进入循环,step代表生成单词的个数,results存储所有生成的句子(四个string)
  4. 循环:
    记录batch_size个latest_tokens、states、prev_coverage,调用decode_onestep,得到[beam_size, 2*beam_size]的topk_ids, topk_log_probs等。
    进入第二个循环,循环num_orig_hyps次,对于每个原始假设,再遍历2*beam_size,更新hyp的参数,all_hyps按顺序记录num_orig_hyps*2*beam_size个hyp。
    第二个循环结束后,选prob最大的beam_size个hyp:如果生成了完整的句子就放进results里,否则就存hyp进hyps
  5. 结束循环后选择可能性最大的result里的句子
    整体思想是保留四个最好的结果,用这个四个去发散寻找下一个单词(424个结果中找最好的4个)

你可能感兴趣的:(Get To The Point: Summarization with Pointer-Generator Networks代码分析)