GLMP Code: Detailed Annotations

GLMP is short for the paper GLOBAL-TO-LOCAL MEMORY POINTER NETWORKS
FOR TASK-ORIENTED DIALOGUE.
Below are detailed annotations of its code (the code has been verified to run).

3.1 Model

3.1.1 ContextRNN

class ContextRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout, n_layers=1):
        # Initialize hyperparameters
        super(ContextRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        # nn.Dropout: takes a float, the probability of zeroing each element
        self.dropout_layer = nn.Dropout(dropout)
        # nn.Embedding: (vocabulary size, embedding dimension, output zeros for PAD_token)
        # Note that the embedding dimension equals the hidden dimension here
        self.embedding = nn.Embedding(input_size, hidden_size, padding_idx=PAD_token)
        # PyTorch GRU module, configured as a bidirectional GRU
        self.gru = nn.GRU(hidden_size, hidden_size,
                          n_layers, dropout=dropout, bidirectional=True)
        # Projects the concatenated bidirectional states back down to hidden_size
        self.W = nn.Linear(2*hidden_size, hidden_size)

    def get_state(self, bsz):
        """Initialize the GRU hidden state: (num_directions=2, batch, hidden_size)."""
        return _cuda(torch.zeros(2, bsz, self.hidden_size))

    def forward(self, input_seqs, input_lengths, hidden=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        # contiguous() returns a tensor laid out contiguously in memory; view() reshapes it
        # without copying data (a tensor must be contiguous before it can be viewed).
        # Together they flatten the (position, token) dimensions so every word id can be
        # embedded, after which the token embeddings of each position are summed back up.
        embedded = self.embedding(input_seqs.contiguous().view(input_seqs.size(0),
                                                               -1).long()) 
        embedded = embedded.view(input_seqs.size()+(embedded.size(-1),))
        # Sum over the tokens of each memory tuple (dim 2)
        embedded = torch.sum(embedded, 2).squeeze(2) 
        embedded = self.dropout_layer(embedded)
        # Initialize the hidden state
        hidden = self.get_state(input_seqs.size(1))
        if input_lengths:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=False)
        outputs, hidden = self.gru(embedded, hidden)
        if input_lengths:
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=False)
        # Merge the forward and backward final states, then project back to hidden_size
        hidden = self.W(torch.cat((hidden[0], hidden[1]), dim=1)).unsqueeze(0)
        outputs = self.W(outputs)
        return outputs.transpose(0,1), hidden
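
For intuition, here is a minimal shape check for ContextRNN (a sketch, not from the repo: it assumes PAD_token and _cuda from the project's utils are in scope, and all sizes are made up):

# Dialogue history comes in as (max_len, batch, memory-tuple size): each position
# holds several word ids whose embeddings are summed into one vector.
rnn = ContextRNN(input_size=1000, hidden_size=128, dropout=0.1)
input_seqs = torch.randint(1, 1000, (7, 2, 4))
input_lengths = [7, 5]                      # sorted longest-first, as the data loader provides
outputs, hidden = rnn(input_seqs, input_lengths)
print(outputs.shape)                        # torch.Size([2, 7, 128])  batch-first encoder outputs
print(hidden.shape)                         # torch.Size([1, 2, 128])  merged bidirectional state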

3.1.2 ExternalKnowledge

class ExternalKnowledge(nn.Module):
    def __init__(self, vocab, embedding_dim, hop, dropout):
        # ExternalKnowledge works like an end-to-end memory network
        super(ExternalKnowledge, self).__init__()
        self.max_hops = hop  # number of memory hops
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout) 
        for hop in range(self.max_hops+1):  # one embedding matrix per hop (adjacent weight tying)
            # nn.Embedding: (vocabulary size, embedding dimension, output zeros for PAD_token)
            C = nn.Embedding(vocab, embedding_dim, padding_idx=PAD_token)
            # Fill the weights from a normal distribution with mean 0 and std 0.1
            C.weight.data.normal_(0, 0.1)
            # add_module registers a child module under the given name
            self.add_module("C_{}".format(hop), C)
        # AttrProxy lets the hop embeddings be looked up as self.C[i], i.e. the module named C_i
        self.C = AttrProxy(self, "C_")
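        # For reference, AttrProxy is defined elsewhere in the repo; a minimal
        # equivalent (a sketch, not the verbatim source) looks roughly like:
        #     class AttrProxy(object):
        #         def __init__(self, module, prefix):
        #             self.module, self.prefix = module, prefix
        #         def __getitem__(self, i):
        #             return getattr(self.module, self.prefix + str(i))
        # so self.C[hop] simply resolves to the registered submodule C_{hop}.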
        # Softmax over the memory dimension
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        # nn.Conv1d: (in_channels, out_channels, kernel_size, padding)
        self.conv_layer = nn.Conv1d(embedding_dim, embedding_dim, 5, padding=2)

    def add_lm_embedding(self, full_memory, kb_len, conv_len, hiddens):
        # Add the dialogue-history RNN hidden states into the memory at the positions
        # following the KB triples (i.e. into the conversation part of each memory)
        for bi in range(full_memory.size(0)):
            start, end = kb_len[bi], kb_len[bi]+conv_len[bi]
            full_memory[bi, start:end, :] = full_memory[bi, start:end, :] + hiddens[bi, :conv_len[bi], :]
        return full_memory

    def load_memory(self, story, kb_len, conv_len, hidden, dh_outputs):
        # Forward multiple hop mechanism
        # squeeze(0) drops the leading dimension of size 1
        u = [hidden.squeeze(0)]
        story_size = story.size()
        self.m_story = []
        for hop in range(self.max_hops):  # loop over the hops to compute the attention weights
            # Compute c^k_i: embed every token of each external-knowledge triple and
            # sum them into a single bag-of-words vector per memory position
            embed_A = self.C[hop](story.contiguous().view(story_size[0], -1))#.long()) # b * (m * s) * e
            embed_A = embed_A.view(story_size+(embed_A.size(-1),)) # b * m * s * e
            # torch.sum over dim 2 collapses the tokens of each triple
            embed_A = torch.sum(embed_A, 2).squeeze(2) # b * m * e
            if not args["ablationH"]:
                embed_A = self.add_lm_embedding(embed_A, kb_len, conv_len, dh_outputs)
            embed_A = self.dropout_layer(embed_A)
            
            if(len(list(u[-1].size()))==1): 
                u[-1] = u[-1].unsqueeze(0) ## used for bsz = 1.
            # Expand q^k so it can be multiplied element-wise with the memory
            u_temp = u[-1].unsqueeze(1).expand_as(embed_A)
            # Attention logits over the memory positions
            prob_logit = torch.sum(embed_A*u_temp, 2)
            # p^k_i, as given in the paper
            prob_   = self.softmax(prob_logit)
            
            # Repeat the same embedding steps with the next hop's matrix to get c^{k+1}_i
            embed_C = self.C[hop+1](story.contiguous().view(story_size[0], -1).long())
            embed_C = embed_C.view(story_size+(embed_C.size(-1),)) 
            embed_C = torch.sum(embed_C, 2).squeeze(2)
            if not args["ablationH"]:
                embed_C = self.add_lm_embedding(embed_C, kb_len, conv_len, dh_outputs)

            # Expand p^k_i so it can weight c^{k+1}_i
            prob = prob_.unsqueeze(2).expand_as(embed_C)
            # Weighted sum gives o^k
            o_k  = torch.sum(embed_C*prob, 1)
            # q^{k+1} = q^k + o^k
            u_k = u[-1] + o_k
            u.append(u_k)
            self.m_story.append(embed_A)
        self.m_story.append(embed_C)
        # Return the global memory pointer G = sigmoid(last-hop logits) and the final
        # query q^{K+1}, which serves as the KB readout for the decoder
        return self.sigmoid(prob_logit), u[-1]

    def forward(self, query_vector, global_pointer):
        # u holds the query vector at each hop
        u = [query_vector]
        # Loop over the hops to obtain the final query result
        for hop in range(self.max_hops):
            # Retrieve the memory embedding c^k stored in m_story by load_memory
            m_A = self.m_story[hop] 
            if not args["ablationG"]:
                # Scale the memory by the global pointer (the global contextual representation)
                m_A = m_A * global_pointer.unsqueeze(2).expand_as(m_A) 
            if(len(list(u[-1].size()))==1): 
                u[-1] = u[-1].unsqueeze(0) ## used for bsz = 1.
            u_temp = u[-1].unsqueeze(1).expand_as(m_A)
            prob_logits = torch.sum(m_A*u_temp, 2)
            prob_soft   = self.softmax(prob_logits)
            m_C = self.m_story[hop+1] 
            if not args["ablationG"]:
                m_C = m_C * global_pointer.unsqueeze(2).expand_as(m_C)
            prob = prob_soft.unsqueeze(2).expand_as(m_C)
            o_k  = torch.sum(m_C*prob, 1)
            u_k = u[-1] + o_k
            u.append(u_k)
        return prob_soft, prob_logits
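
Written out in the paper's notation (matching the code above), each hop k updates the query with an attention over the memory positions i:

$$p^k_i = \mathrm{Softmax}\big((q^k)^{\top} c^k_i\big), \qquad o^k = \sum_i p^k_i \, c^{k+1}_i, \qquad q^{k+1} = q^k + o^k.$$

load_memory additionally returns the global memory pointer built from the last hop's logits,

$$G = (g_1, \dots, g_n), \qquad g_i = \mathrm{Sigmoid}\big((q^K)^{\top} c^K_i\big),$$

while forward runs the same hops at decoding time, except that (unless ablationG is set) each memory embedding is first scaled by its global pointer value, $c^k_i \leftarrow c^k_i \times g_i$.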

3.1.3 LocalMemoryDecoder

class LocalMemoryDecoder(nn.Module):
    def __init__(self, shared_emb, lang, embedding_dim, hop, dropout):
        # Initialize the network
        super(LocalMemoryDecoder, self).__init__()
        self.num_vocab = lang.n_words
        self.lang = lang
        self.max_hops = hop
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout) 
        # Keep shared_emb as C (note: not the same C as in ExternalKnowledge);
        # in the GLMP code, shared_emb is the encoder's embedding matrix
        self.C = shared_emb 
        self.softmax = nn.Softmax(dim=1)
        # The sketch RNN generates the response skeleton: sketch tags instead of slot values
        self.sketch_rnn = nn.GRU(embedding_dim, embedding_dim, dropout=dropout)
        self.relu = nn.ReLU()
        self.projector = nn.Linear(2*embedding_dim, embedding_dim)
        self.conv_layer = nn.Conv1d(embedding_dim, embedding_dim, 5, padding=2)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, extKnow, story_size, story_lengths, copy_list, encode_hidden, target_batches, max_target_length, batch_size, use_teacher_forcing, get_decoded_words, global_pointer):
        # Initialize variables for vocab and pointer
        # Output tensors for the vocabulary distributions and the local memory pointer
        all_decoder_outputs_vocab = _cuda(torch.zeros(max_target_length, batch_size, self.num_vocab))
        all_decoder_outputs_ptr = _cuda(torch.zeros(max_target_length, batch_size, story_size[1]))
        decoder_input = _cuda(torch.LongTensor([SOS_token] * batch_size))
        # Mask used at decoding time so the same memory entity is not copied twice
        memory_mask_for_step = _cuda(torch.ones(story_size[0], story_size[1]))
        decoded_fine, decoded_coarse = [], []
        decoded_fine, decoded_coarse = [], []
        
        hidden = self.relu(self.projector(encode_hidden)).unsqueeze(0)
        
        # Start to generate word-by-word
        for t in range(max_target_length):
            # hidden is (re)computed by the next few lines; across time steps
            # the only thing that changes is decoder_input
            embed_q = self.dropout_layer(self.C(decoder_input)) # b * e
            if len(embed_q.size()) == 1: embed_q = embed_q.unsqueeze(0)
            _, hidden = self.sketch_rnn(embed_q.unsqueeze(0), hidden)
            # Use the sketch RNN hidden state as the query vector
            query_vector = hidden[0] 
            
            # Compute P_vocab over the sketch vocabulary
            p_vocab = self.attend_vocab(self.C.weight, hidden.squeeze(0))
            all_decoder_outputs_vocab[t] = p_vocab
            # topk returns (values, indices) of the k largest elements; here only the
            # index of the best-scoring word is needed, so applying a softmax first
            # would not change the argmax (probably why the softmax in attend_vocab
            # is commented out)
            _, topvi = p_vocab.data.topk(1)
            
            # Query the external knowledge using the hidden state of the sketch RNN;
            # extKnow here calls ExternalKnowledge.forward
            prob_soft, prob_logits = extKnow(query_vector, global_pointer)
            # L_t: the local memory pointer over the memory positions
            all_decoder_outputs_ptr[t] = prob_logits

            if use_teacher_forcing:
                # Teacher forcing: feed the ground-truth sketch token as the next input
                decoder_input = target_batches[:,t] 
            else:
                # Otherwise feed back the sketch RNN's own best prediction
                decoder_input = topvi.squeeze()
            
            if get_decoded_words:

                search_len = min(5, min(story_lengths))
                prob_soft = prob_soft * memory_mask_for_step
                # Take the top search_len memory positions as candidate objects for filling a slot
                _, toppi = prob_soft.data.topk(search_len)
                temp_f, temp_c = [], []
                
                for bi in range(batch_size):
                    token = topvi[bi].item() #topvi[:,0][bi].item()  # index of the predicted sketch word
                    temp_c.append(self.lang.index2word[token])  # the sketch word itself
                    
                    if '@' in self.lang.index2word[token]:  # sketch tags contain '@', i.e. this is a slot
                        # For a sketch tag, copy the object word pointed to in memory instead
                        cw = 'UNK'
                        for i in range(search_len):
                            if toppi[:,i][bi] < story_lengths[bi]-1: 
                                cw = copy_list[bi][toppi[:,i][bi].item()]            
                                break
                        temp_f.append(cw)
                        
                        if args['record']:
                            # Zero out this memory position so the same entity is not copied again
                            memory_mask_for_step[bi, toppi[:,i][bi].item()] = 0
                    else:
                        # Not a sketch tag: output the generated word directly
                        temp_f.append(self.lang.index2word[token])

                decoded_fine.append(temp_f)
                decoded_coarse.append(temp_c)

        return all_decoder_outputs_vocab, all_decoder_outputs_ptr, decoded_fine, decoded_coarse

    def attend_vocab(self, seq, cond):
        # Multiply the decoder state by the transposed shared embedding matrix
        # to get a score for every vocabulary word
        scores_ = cond.matmul(seq.transpose(1,0))
        # The paper's formula applies a softmax, but it is commented out here:
        # the argmax used for decoding is unchanged, and the cross-entropy loss
        # used in training works on the unnormalized logits
        # scores = F.softmax(scores_, dim=1)
        return scores_
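
In the paper's notation, at each decoding step t the sketch RNN hidden state $h^d_t$ plays two roles. First, it is scored against the shared embedding matrix $E$ to produce the sketch-vocabulary distribution

$$P^{\mathrm{vocab}}_t = \mathrm{Softmax}\big(E\,(h^d_t)^{\top}\big),$$

where attend_vocab above computes exactly the logits $E\,(h^d_t)^{\top}$. Second, it is used as the query vector passed into ExternalKnowledge.forward, whose last-hop attention over the memory is the local memory pointer $L_t$ that the decoder stores in all_decoder_outputs_ptr.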

3.2 GLMP

3.2.1 Encoder&Decoder

def encode_and_decode(self, data, max_target_length, use_teacher_forcing, get_decoded_words):
        # Build unknown mask for memory 
        # During training, randomly replace some memory words with the unknown word (a dropout-style regularizer)
        if args['unk_mask'] and self.decoder.training:
            story_size = data['context_arr'].size()
            rand_mask = np.ones(story_size)
            bi_mask = np.random.binomial([np.ones((story_size[0],story_size[1]))], 1-self.dropout)[0]
            rand_mask[:,:,0] = rand_mask[:,:,0] * bi_mask
            conv_rand_mask = np.ones(data['conv_arr'].size())
            for bi in range(story_size[0]):
                start, end = data['kb_arr_lengths'][bi],  data['kb_arr_lengths'][bi] + data['conv_arr_lengths'][bi]
                conv_rand_mask[:end-start,bi,:] = rand_mask[bi,start:end,:]
            rand_mask = self._cuda(rand_mask)
            conv_rand_mask = self._cuda(conv_rand_mask)
            conv_story = data['conv_arr'] * conv_rand_mask.long()
            story = data['context_arr'] * rand_mask.long()
        else:
            story, conv_story = data['context_arr'], data['conv_arr']
        
        # Encode dialog history and KB to vectors
        # self.encoder is the ContextRNN module and self.extKnow the ExternalKnowledge module defined above
        dh_outputs, dh_hidden = self.encoder(conv_story, data['conv_arr_lengths'])
        global_pointer, kb_readout = self.extKnow.load_memory(story, data['kb_arr_lengths'], data['conv_arr_lengths'], dh_hidden, dh_outputs)
        # torch.cat concatenates tensors along the given dimension; here the dialogue-history
        # hidden state and the KB readout are joined to initialize the decoder
        encoded_hidden = torch.cat((dh_hidden.squeeze(0), kb_readout), dim=1) 
        
        # Get the words that can be copied from the memory
        # For each example, collect the first word of every memory entry (the object word the pointer can copy)
        batch_size = len(data['context_arr_lengths'])
        self.copy_list = []
        for elm in data['context_arr_plain']:
            elm_temp = [ word_arr[0] for word_arr in elm ]
            self.copy_list.append(elm_temp) 
        
        # The decoder (LocalMemoryDecoder) generates the output response
        outputs_vocab, outputs_ptr, decoded_fine, decoded_coarse = self.decoder.forward(
            self.extKnow, 
            story.size(), 
            data['context_arr_lengths'],
            self.copy_list, 
            encoded_hidden, 
            data['sketch_response'], 
            max_target_length, 
            batch_size, 
            use_teacher_forcing, 
            get_decoded_words, 
            global_pointer) 

        return outputs_vocab, outputs_ptr, decoded_fine, decoded_coarse, global_pointer
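
The unk_mask block above randomly drops memory words during training. A toy sketch of the same numpy trick (the dropout value 0.5 and all sizes below are made up for illustration):

import numpy as np

story_size = (1, 6, 4)                              # (batch, memory positions, tuple size)
rand_mask = np.ones(story_size)
# One Bernoulli draw per memory position: 1 = keep the word, 0 = drop it
bi_mask = np.random.binomial([np.ones((story_size[0], story_size[1]))], 1 - 0.5)[0]
# Only the first slot of each memory tuple (the word itself) can be zeroed;
# multiplying that word id by 0 maps it to index 0, the unknown word in this
# codebase (hence the flag name unk_mask)
rand_mask[:, :, 0] = rand_mask[:, :, 0] * bi_mask
print(rand_mask[0, :, 0])                           # e.g. [1. 0. 1. 1. 0. 1.]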

3.2.2 Evaluate

def evaluate(self, dev, matric_best, early_stop=None):
        print("STARTING EVALUATION")
        # Set to not-training mode to disable dropout
        # train() defaults to True; passing False switches the modules to eval mode (dropout disabled)
        self.encoder.train(False)
        self.extKnow.train(False)
        self.decoder.train(False)  
        
        ref, hyp = [], []
        acc, total = 0, 0
        dialog_acc_dict = {}
        F1_pred, F1_cal_pred, F1_nav_pred, F1_wet_pred = 0, 0, 0, 0
        F1_count, F1_cal_count, F1_nav_count, F1_wet_count = 0, 0, 0, 0
        pbar = tqdm(enumerate(dev),total=len(dev))
        new_precision, new_recall, new_f1_score = 0, 0, 0

        # Load the global entity list used for entity F1 on the KVR dataset
        if args['dataset'] == 'kvr':
            with open('data/KVR/kvret_entities.json') as f:
                global_entity = json.load(f)
                global_entity_list = []
                for key in global_entity.keys():
                    if key != 'poi':
                        global_entity_list += [item.lower().replace(' ', '_') for item in global_entity[key]]
                    else:
                        for item in global_entity['poi']:
                            global_entity_list += [item[k].lower().replace(' ', '_') for k in item.keys()]
                global_entity_list = list(set(global_entity_list))

        for j, data_dev in pbar: 
            # Encode and Decode
            _, _, decoded_fine, decoded_coarse, global_pointer = self.encode_and_decode(data_dev, self.max_resp_len, False, True)
            decoded_coarse = np.transpose(decoded_coarse)
            decoded_fine = np.transpose(decoded_fine)
            for bi, row in enumerate(decoded_fine):  # accumulate the metrics for each example in the batch
                st = ''
                for e in row:
                    if e == 'EOS': break
                    else: st += e + ' '
                st_c = ''
                for e in decoded_coarse[bi]:
                    if e == 'EOS': break
                    else: st_c += e + ' '
                pred_sent = st.lstrip().rstrip()
                pred_sent_coarse = st_c.lstrip().rstrip()
                gold_sent = data_dev['response_plain'][bi].lstrip().rstrip()
                ref.append(gold_sent)
                hyp.append(pred_sent)
                
                if args['dataset'] == 'kvr': 
                    # compute F1 SCORE
                    # Entity F1 over all entities, plus per-domain (calendar / navigation / weather) F1
                    single_f1, count = self.compute_prf(data_dev['ent_index'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_pred += single_f1
                    F1_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_cal'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_cal_pred += single_f1
                    F1_cal_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_nav'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_nav_pred += single_f1
                    F1_nav_count += count
                    single_f1, count = self.compute_prf(data_dev['ent_idx_wet'][bi], pred_sent.split(), global_entity_list, data_dev['kb_arr_plain'][bi])
                    F1_wet_pred += single_f1
                    F1_wet_count += count
                else:
                    # compute Dialogue Accuracy Score
                    # Dialogue accuracy: a dialogue later counts as correct only if all of its responses match
                    current_id = data_dev['ID'][bi]
                    if current_id not in dialog_acc_dict.keys():
                        dialog_acc_dict[current_id] = []
                    if gold_sent == pred_sent:
                        dialog_acc_dict[current_id].append(1)
                    else:
                        dialog_acc_dict[current_id].append(0)

                # compute Per-response Accuracy Score
                # Exact-match accuracy for each generated response
                total += 1
                if (gold_sent == pred_sent):
                    acc += 1

                if args['genSample']:
                    self.print_examples(bi, data_dev, pred_sent, pred_sent_coarse, gold_sent)

        # Set back to training mode
        # Turn dropout back on for continued training
        self.encoder.train(True)
        self.extKnow.train(True)
        self.decoder.train(True)

        bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
        acc_score = acc / float(total)
        print("ACC SCORE:\t"+str(acc_score))

        if args['dataset'] == 'kvr':
            F1_score = F1_pred / float(F1_count)
            print("F1 SCORE:\t{}".format(F1_pred/float(F1_count)))
            print("\tCAL F1:\t{}".format(F1_cal_pred/float(F1_cal_count))) 
            print("\tWET F1:\t{}".format(F1_wet_pred/float(F1_wet_count))) 
            print("\tNAV F1:\t{}".format(F1_nav_pred/float(F1_nav_count))) 
            print("BLEU SCORE:\t"+str(bleu_score))
        else:
            dia_acc = 0
            for k in dialog_acc_dict.keys():
                if len(dialog_acc_dict[k])==sum(dialog_acc_dict[k]):
                    dia_acc += 1
            print("Dialog Accuracy:\t"+str(dia_acc*1.0/len(dialog_acc_dict.keys())))
        
        if (early_stop == 'BLEU'):
            if (bleu_score >= matric_best):
                self.save_model('BLEU-'+str(bleu_score))
                print("MODEL SAVED")
            return bleu_score
        elif (early_stop == 'ENTF1'):
            if (F1_score >= matric_best):
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")  
            return F1_score
        else:
            if (acc_score >= matric_best):
                self.save_model('ACC-{:.4f}'.format(acc_score))
                print("MODEL SAVED")
            return acc_score

3.2.3 Compute_prf

def compute_prf(self, gold, pred, global_entity_list, kb_plain):
        local_kb_word = [k[0] for k in kb_plain]
        TP, FP, FN = 0, 0, 0
        if len(gold)!= 0:
            count = 1
            for g in gold:
                if g in pred:
                    TP += 1
                else:
                    FN += 1
            for p in set(pred):
                if p in global_entity_list or p in local_kb_word:
                    if p not in gold:
                        FP += 1
            # Precision: fraction of predicted entity words that are in the gold set
            precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
            # Recall: fraction of gold entities that appear in the prediction
            recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
            # F1: the harmonic mean of precision and recall
            F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
        else:
            precision, recall, F1, count = 0, 0, 0, 0
        return F1, count
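
A standalone sketch of the same entity-F1 computation on a toy example (the entity names and the entity_vocab set below are made up for illustration; the real method takes its candidate entities from the global entity list and the local KB words, exactly as in the loop above):

gold = ['starbucks', '5_miles']                        # gold entities for one response
pred = 'the nearest coffee shop is starbucks'.split()  # predicted response tokens
entity_vocab = {'starbucks', '5_miles', 'coffee'}      # hypothetical global entities + local KB words

TP = sum(1 for g in gold if g in pred)                                  # 1 (starbucks found)
FN = sum(1 for g in gold if g not in pred)                              # 1 (5_miles missing)
FP = sum(1 for p in set(pred) if p in entity_vocab and p not in gold)   # 1 (coffee is a spurious entity)

precision = TP / (TP + FP) if (TP + FP) != 0 else 0    # 0.5
recall = TP / (TP + FN) if (TP + FN) != 0 else 0       # 0.5
F1 = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
print(F1)                                              # 0.5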
