First, here is the PyTorch implementation of the Transformer's MultiHeadAttention; the details of this code are then analyzed point by point.
import torch
import torch.nn as nn

# clones() and attention() are helper functions defined outside this class;
# a sketch of both follows the code below.

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0  # analysis point 1
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        # Shapes: query = key = value ---> [batch_size, max_length, embedding_dim=512]
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)  # analysis point 2
        nbatches = query.size(0)
        # Step 1: multiply q, k, v by the Wq, Wk, Wv matrices
        #         shape: Wq = Wk = Wv -----> [512, 512]
        # Step 2: split the resulting Q, K, V along the last dimension
        #         shape: [batch_size, max_length, 8, 64]
        # Step 3: move the head dimension to position 1
        #         shape: [batch_size, 8, max_length, 64]
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]  # analysis point 3
        # The shape stays [batch_size, 8, max_length, 64] through attention
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        # Restore the original shape:
        #   swap dimensions back: [batch_size, max_length, 8, 64]
        #   merge the heads:      [batch_size, max_length, 512]
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)  # analysis point 4
        # Finally multiply by the large output matrix WO, shape: [512, 512]
        return self.linears[-1](x)
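The class above calls two helpers, clones() and attention(), that are not defined in this post. The sketch below shows how they are typically written in the Annotated-Transformer-style code this snippet follows; treat the exact bodies as an assumption rather than part of the original.

import copy
import math
import torch
import torch.nn as nn

def clones(module, N):
    "Produce N identical copies of a layer, each with its own parameters."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def attention(query, key, value, mask=None, dropout=None):
    "Scaled dot-product attention; q/k/v have shape [batch_size, h, max_length, d_k]."
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

With these in place, a quick smoke test (batch size and sequence length are arbitrary illustrative values):

mha = MultiHeadedAttention(h=8, d_model=512)
x = torch.rand(2, 10, 512)    # [batch_size, max_length, embedding_dim]
out = mha(x, x, x)            # self-attention: query = key = value
print(out.shape)              # torch.Size([2, 10, 512])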
Analysis point 1: assert d_model % h == 0
The assert statement
Python's assert evaluates an expression and raises an exception when the expression is false.
Syntax:
assert expression
which is equivalent to (this form is easier to understand):
if not expression:
    raise AssertionError
assert can also be followed by an error message:
assert expression [, arguments]
which is equivalent to:
if not expression:
    raise AssertionError(arguments)
Example:
assert True   # no output; the program simply continues
assert False
# output
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-13-a871fdc9ebee> in <module>
----> 1 assert False
AssertionError:
import sys
assert ('linux' in sys.platform), "This code can only run on Linux"
# checks whether the platform the code is running on is Linux
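In __init__ the same mechanism guards against a model size that cannot be split evenly across the heads. A minimal check using the sizes from the comments above (d_model=512, h=8); the failing value h=7 is purely illustrative:

d_model, h = 512, 8
assert d_model % h == 0, "d_model must be divisible by the number of heads"
print(d_model // h)           # 64, the per-head dimension d_k

d_model, h = 512, 7
# assert d_model % h == 0     # uncommenting this raises AssertionError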
Analysis point 2: mask = mask.unsqueeze(1)
This adds a dimension to mask; the argument 1 means the new dimension is inserted at index 1 (counting from 0).
b=torch.rand(2,5)
b
# output
tensor([[0.6956, 0.4611, 0.2149, 0.2581, 0.6836],
[0.6159, 0.4464, 0.2467, 0.2504, 0.8744]])
b.shape,b.unsqueeze(0),b.unsqueeze(0).shape  # insert a new dimension at index 0
# output
(torch.Size([2, 5]), tensor([[[0.6956, 0.4611, 0.2149, 0.2581, 0.6836],
[0.6159, 0.4464, 0.2467, 0.2504, 0.8744]]]), torch.Size([1, 2, 5]))
b.shape,b.unsqueeze(1),b.unsqueeze(1).shape  # insert a new dimension at index 1
# output
(torch.Size([2, 5]), tensor([[[0.6956, 0.4611, 0.2149, 0.2581, 0.6836]],
[[0.6159, 0.4464, 0.2467, 0.2504, 0.8744]]]), torch.Size([2, 1, 5]))
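In the forward pass this unsqueeze gives the mask a singleton head dimension, so one mask broadcasts over all h heads when the attention scores are masked. A minimal sketch with the 8-head, 512-dim shapes used above (the masked_fill call mirrors what the attention helper is assumed to do):

import torch

batch_size, h, max_length = 2, 8, 10
scores = torch.rand(batch_size, h, max_length, max_length)    # Q @ K^T / sqrt(d_k)

mask = torch.ones(batch_size, 1, max_length)                  # [batch_size, 1, max_length]
mask = mask.unsqueeze(1)                                      # [batch_size, 1, 1, max_length]

# The singleton dimensions broadcast over the h heads and the query positions,
# so a single mask serves every head without being copied h times.
masked = scores.masked_fill(mask == 0, -1e9)
print(masked.shape)                                           # torch.Size([2, 8, 10, 10])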
Analysis point 3: for l, x in zip(self.linears, (query, key, value))
This pairs up self.linears[0] with query, self.linears[1] with key, and self.linears[2] with value, binding each pair to l and x, and applies l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) to each pair.
It is equivalent to taking, in turn,
l,x=self.linears[0],query
l,x=self.linears[1],key
l,x=self.linears[2],value
and executing l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) for each pair.
Note that zip stops at the shorter of the two sequences, so the fourth linear layer self.linears[3] is not consumed here; it is used later as the output projection self.linears[-1].
An example (note that a has four elements but only three pairs are produced):
a=[1,2,3,4]
b=torch.zeros(5,2,2)
c=torch.ones(5,2,2)
d=torch.rand(5,2,2)
print(b,b.shape)
print(c,c.shape)
print(d,d.shape)
print("============================================")
for x,y in zip(a,(b,c,d)):
    print(x,y)
    print("shape:",y.shape)
    print("=======")
Output:
tensor([[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]]) torch.Size([5, 2, 2])
tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]]) torch.Size([5, 2, 2])
tensor([[[0.0764, 0.8718],
[0.3432, 0.0081]],
[[0.8416, 0.9806],
[0.0932, 0.2501]],
[[0.7480, 0.3873],
[0.8147, 0.6484]],
[[0.6723, 0.1186],
[0.4056, 0.6158]],
[[0.6319, 0.5724],
[0.7458, 0.6811]]]) torch.Size([5, 2, 2])
============================================
1 tensor([[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]])
shape: torch.Size([5, 2, 2])
=======
2 tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]])
shape: torch.Size([5, 2, 2])
=======
3 tensor([[[0.0764, 0.8718],
[0.3432, 0.0081]],
[[0.8416, 0.9806],
[0.0932, 0.2501]],
[[0.7480, 0.3873],
[0.8147, 0.6484]],
[[0.6723, 0.1186],
[0.4056, 0.6158]],
[[0.6319, 0.5724],
[0.7458, 0.6811]]])
shape: torch.Size([5, 2, 2])
=======
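Back in the forward pass, each l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) turns a [batch_size, max_length, 512] tensor into [batch_size, 8, max_length, 64]. A minimal shape walkthrough (the batch size and sequence length are arbitrary illustrative values):

import torch
import torch.nn as nn

batch_size, max_length, d_model, h = 2, 10, 512, 8
d_k = d_model // h                                    # 64

x = torch.rand(batch_size, max_length, d_model)
linear = nn.Linear(d_model, d_model)                  # plays the role of one of self.linears[0..2]

projected = linear(x)                                 # [2, 10, 512]
split     = projected.view(batch_size, -1, h, d_k)    # [2, 10, 8, 64]
per_head  = split.transpose(1, 2)                     # [2, 8, 10, 64]

print(projected.shape, split.shape, per_head.shape)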
Analysis point 4: x.transpose(1, 2).contiguous()
Reference: narrow(), view(), expand() and transpose()
For example, when transposing with transpose(), PyTorch does not create a new, transposed tensor; it only changes some of the tensor's metadata (the offset and strides) so that they describe the transposed view. The transposed tensor and the original tensor share the same memory!
x = torch.randn(3, 2)
print(x,x.shape)
y = x.transpose(0, 1)
y,y.shape
# output
tensor([[-1.9441, 1.5522],
[ 0.5396, -1.1500],
[ 1.3438, -2.3227]]) torch.Size([3, 2])
(tensor([[-1.9441, 0.5396, 1.3438],
[ 1.5522, -1.1500, -2.3227]]), torch.Size([2, 3]))
To verify that x and y share the same memory, we modify the first element of x and find that the data in y changes as well.
x[0, 0] = 111
print(y)
print(x)
# output
tensor([[111.0000, 0.5396, 1.3438],
[ 1.5522, -1.1500, -2.3227]])
tensor([[111.0000, 1.5522],
[ 0.5396, -1.1500],
[ 1.3438, -2.3227]])
Calling contiguous() forces a copy of the tensor, so that its memory layout is exactly the same as that of a tensor created from scratch.
x = torch.randn(3, 2)
print(x,x.shape)
y = x.transpose(0, 1)
y=y.contiguous()
print(y,y.shape)
x[0, 0] = 111
print(y)
print(x)
# output
tensor([[ 2.2892, -0.0997],
[-0.0294, 0.1934],
[ 0.7963, -0.3681]]) torch.Size([3, 2])
tensor([[ 2.2892, -0.0294, 0.7963],
[-0.0997, 0.1934, -0.3681]]) torch.Size([2, 3])
tensor([[ 2.2892, -0.0294, 0.7963],
[-0.0997, 0.1934, -0.3681]])
tensor([[ 1.1100e+02, -9.9713e-02],
[-2.9446e-02, 1.9339e-01],
[ 7.9626e-01, -3.6809e-01]])
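This is exactly why the forward pass calls .contiguous() before .view(): view() requires a compatible (contiguous) memory layout, and the tensor coming out of transpose(1, 2) does not have one. A minimal sketch with the multi-head shapes used above:

import torch

batch_size, h, max_length, d_k = 2, 8, 10, 64
x = torch.rand(batch_size, h, max_length, d_k)       # stands in for the output of attention()

y = x.transpose(1, 2)                                # [2, 10, 8, 64], shares memory with x
print(y.is_contiguous())                             # False

try:
    y.view(batch_size, -1, h * d_k)                  # view() on a non-contiguous tensor
except RuntimeError as e:
    print("view() failed:", e)

merged = y.contiguous().view(batch_size, -1, h * d_k)
print(merged.shape)                                  # torch.Size([2, 10, 512])

An alternative would be reshape(), which copies only when it has to, but the original code makes the copy explicit with contiguous().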