pytorch 获取最后一个非0值 截取

之前的一个文章写过tensorflow如何实现,有可能对大家有帮助,这里用pytorch实现了一下,大概意思就是,输入1个[batch_size , seq_len ]的矩阵,目的是获取每一行中的最后一个非零元素,例如[ [1,2,3,0,0,0], [4,5,0,0,0,0]],期望获取[[3],[5]], 这个功能在srgnn之类的模型中有用到,就是获取每个交互序列的最后一个有意义值。注意序列中,0都是排在末尾的。

import torch
import time

if __name__ == "__main__":
    # 手工构造
    item_emb=torch.randint(1,10,(10,5))
    zero_emb=torch.zeros(10,4)
    item_emb = torch.cat((item_emb.float(), zero_emb), dim=1)
    
    # 获取每个序列中最后一个非0元素的位置,用sum实现
    last_index = torch.where(item_emb > 0, torch.full_like(item_emb, 1), item_emb)
    last_index = last_index.sum(dim=1)
    
    # 列表表达式,数量特别大时,效率更高
    time1 = time.time()
    last_item = [int(item_emb[index][int(x.item()) - 1].item()) for index, x in enumerate(last_index)]
    time2 = time.time()
    print((time2 - time1))

    # for循环
    time3 = time.time()
    last_item_ = []
    for index, x in enumerate(last_index):
        last_item_.append(int(item_emb[index][int(x.item()) - 1].item()))
    time4 = time.time()
    print(time4 - time3)

你可能感兴趣的:(python,深度学习)