关于torch.index_select()和torch.gather()函数的使用和区别

前言:

因为和人大合作一个项目,人大小哥哥给我原来的代码做了个简化,因此想记录一下,关于torch.gather()这个函数,感觉突然通了

应用场景

主要是在input:(batch_size,seq_len,embedding_dim)作为输入,进入gru以后返回也是(batch_size,seq_len,embedding_dim),但是由于有padding_id,只想拿到第item_list_len返回的那个隐藏层。

代码

def forward(self, interaction):
        #TODO behavior_list_emb = concat(item,catgory)
        item_list_emb = self.item_list_embedding(interaction[self.ITEM_ID_LIST])
        position_list_emb = self.position_list_embedding(interaction[self.POSITION_ID])
        behavior_list_emb = item_list_emb + position_list_emb
        short_term_intent_temp, _ = self.gru_layers(behavior_list_emb)
        short_term_intent_temp = self.gather_indexes(short_term_intent_temp, interaction[self.ITEM_LIST_LEN] - 1)
        predict_behavior_emb = self.layer_norm(short_term_intent_temp)
        return predict_behavior_emb

    def gather_indexes(self, gru_output, gather_index):
        "Gathers the vectors at the spexific positions over a minibatch"
        gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, self.embedding_size)
        output_tensor = gru_output.gather(dim=1, index=gather_index)
        return output_tensor.squeeze(1)

简单的一个demo的例子解释上面的code:

再贴两个写的很好的博客链接(以后供自己review,hhh):
先看第一个大概理解
第二个看
最后看官网的例子,基本就可以理解了

你可能感兴趣的:(pytorch,NLP,pytorch,index_select,torch.gather)