关于transformer-xl中rel-shift实现的解读


关于transformer-xl中rel-shift实现的解读_第1张图片

关于transformer-xl中rel-shift实现的解读_第2张图片 

方法

抽象地看,我们要做的事情就是,给定一个矩阵,每行都进行左移,而移动的个数随行数递增而递减。

我目前想到的一种方法是使用gather,将想要的index提前定好,然后使用Pytorch的gather就能够实现。

而transformer-xl实现了另一种更好的方法:_rel_shift

def _rel_shift(self, x, zero_triu=False):
    # x: q,k,bs,n_head
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                           device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)

    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

    x = x_padded[1:].view_as(x)

    return x

第一步是,将x的第一列填上padding,此时x.size()=q,k+1,bs,n_head,接下来将其重新reshape,则变成了x.size()=k+1,q,bs,n_head,最后将第一行去掉,变成x.size()=k,q,bs,n_head,再将其reshape回x原来的样子。

为什么这么做实现了我们想要的左移的功能?我们应该从一维的角度去理解。因为实际上在内存中所有元素都是按照一维去排列的。

原来的矩阵:
关于transformer-xl中rel-shift实现的解读_第3张图片

实际上就是有q个key按照一行去排列。

在做完padding之后,则:
关于transformer-xl中rel-shift实现的解读_第4张图片

实际上就是在每个key前面插入了0。

接下来view,实际上数据的先后顺序还是没有变(因为不是transpose):
关于transformer-xl中rel-shift实现的解读_第5张图片

实际上只是强行将该行切成一个一个q而已。

那么最后一个操作,将第一行丢掉,实际上就是要把原来的x的第一行强行左移q-1个(因为有padding)。那么为什么后面的行能够左移的个数依次减少?别忘了padding,第一行左移了q-1个,但第二个key前面也有一个padding,所以相当于将其向右推了一格;第三个又有一个padding,就在原来的基础上又推了一格,也即推了两格。因此最后达到了我们想要的目的。

实际上要理解该方法,需要牢牢把握数据存储的本质是一整行。

该方法没有数据的拷贝,全部都是view操作,因此更高效。

不得不佩服想到该方法的人的工程能力,同时也感谢戴宁带我理解该方法的本质,一开始我是死活不理解的。以后或许可以将该思想灵活应用到其他方面。

  • 本文作者: 林泽辉
  • 本文链接: http://www.linzehui.me/2019/05/07/代码相关/关于transformer-xl中rel-shift实现的解读/
  • 版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 3.0 许可协议。转载请注明出处!

你可能感兴趣的:(nlp,bert)