torch.gather(input,dim,index)
先从二维开始:
>>> a = torch.arange(9).reshape(3,3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> index= torch.LongTensor([[2,1,0]]) # 这里index和input必须有一样的维度...
>>> torch.gather(a,0,index)
tensor([[6, 4, 2]])
这里dim的意义是在哪个维度上去取值,这意味着:对index和input来说,除了这个维度以外,其他的维度必须一样。
先上一段简单的实例代码。
>>> a
tensor([[ 0.9918, 0.4911, 1.4912, -1.8491],
[ 0.1257, -0.4406, 0.3371, 0.1205],
[ 0.3064, -0.8198, 1.2851, 0.2486]])
>>> b
tensor([[0, 1],
[1, 2],
[2, 2]])
>>> a.unsqueeze(1).expand(3,2,4).gather(dim=0,index=b.unsqueeze(2).expand(3,2,4))
tensor([[[ 0.9918, 0.4911, 1.4912, -1.8491],
[ 0.1257, -0.4406, 0.3371, 0.1205]],
[[ 0.1257, -0.4406, 0.3371, 0.1205],
[ 0.3064, -0.8198, 1.2851, 0.2486]],
[[ 0.3064, -0.8198, 1.2851, 0.2486],
[ 0.3064, -0.8198, 1.2851, 0.2486]]])
上面这个例子是我自己在写一个东西的时候要用。简单描述一下需求,a是input,b是index,但是和gather定义不一样的是b选择的是一整行而不是单个元素。
以实例为例,b中的0是在a中取出一行。tf.gather(a,b)就得到了。
本来我要做的事情是把tf的代码改成pytorch,相对来说tf的代码显得简单得多,因为tf有自动的broadcast功能,而pytorch必须自己手动扩展维度,在三维的时候tf代码就很简单了。
其实这就是embedding,不知道pytorch的embedding怎么实现的,所以更加简洁优雅的写法:
emb = torch.nn.Embedding(3,4)
emb.weight = a
emb(b)