torch.gather的三维实例

函数解析

torch.gather(input,dim,index)
先从二维开始:

>>> a = torch.arange(9).reshape(3,3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
>>> index= torch.LongTensor([[2,1,0]]) # 这里index和input必须有一样的维度...
>>> torch.gather(a,0,index)
tensor([[6, 4, 2]])

torch.gather的三维实例_第1张图片
这里为了简单把index只做了13,如果是23,就类似:
torch.gather的三维实例_第2张图片

这里dim的意义是在哪个维度上去取值,这意味着:对index和input来说,除了这个维度以外,其他的维度必须一样。

高维的gather

如果维度成为三维或者更高维,也是类似的。
torch.gather的三维实例_第3张图片


手动broadcast

先上一段简单的实例代码。

>>> a
tensor([[ 0.9918,  0.4911,  1.4912, -1.8491],
        [ 0.1257, -0.4406,  0.3371,  0.1205],
        [ 0.3064, -0.8198,  1.2851,  0.2486]])
>>> b
tensor([[0, 1],
        [1, 2],
        [2, 2]])
>>> a.unsqueeze(1).expand(3,2,4).gather(dim=0,index=b.unsqueeze(2).expand(3,2,4))
tensor([[[ 0.9918,  0.4911,  1.4912, -1.8491],
         [ 0.1257, -0.4406,  0.3371,  0.1205]],

        [[ 0.1257, -0.4406,  0.3371,  0.1205],
         [ 0.3064, -0.8198,  1.2851,  0.2486]],

        [[ 0.3064, -0.8198,  1.2851,  0.2486],
         [ 0.3064, -0.8198,  1.2851,  0.2486]]])

上面这个例子是我自己在写一个东西的时候要用。简单描述一下需求,a是input,b是index,但是和gather定义不一样的是b选择的是一整行而不是单个元素
以实例为例,b中的0是在a中取出一行。tf.gather(a,b)就得到了。

本来我要做的事情是把tf的代码改成pytorch,相对来说tf的代码显得简单得多,因为tf有自动的broadcast功能,而pytorch必须自己手动扩展维度,在三维的时候tf代码就很简单了。


其实这就是embedding,不知道pytorch的embedding怎么实现的,所以更加简洁优雅的写法:

emb = torch.nn.Embedding(3,4)
emb.weight = a
emb(b)

你可能感兴趣的:(pytorch,torch,gather)