Transformer算法理解(1)-位置编码

Transformer算法理解(1)-位置编码_第1张图片

学习Transformer按照这个图从左到右,从encoder 到decoder逐层解析,首先一个序列输入inputs经过embeding词嵌入之后,加上了Positional Encoding。这是因为Transformer完全基于self-attention机制,不同于RNN,模型并没有捕捉顺序序列的能力,也就是说无论句子的结构怎么打乱,Transformer都会得到类似的结果。
为了解决这个问题引入了位置编码。作者提供了两种思路:

  • 通过训练学习 positional encoding 向量;
  • 使用公式来计算 positional encoding向量。

试验后发现两种选择的结果是相似的,所以采用了第2种方法,优点是不需要训练参数,而且即使在训练集中没有出现过的句子长度上也能用。

论文给出的公式如下:
Transformer算法理解(1)-位置编码_第2张图片
这么设计的原因是考虑到在NLP任务中,除了单词的绝对位置,单词的相对位置也非常重要。选择正弦曲线函数,是因为位置 k + p k +p k+p的编码可以表示为位置 k k k的编码的线性变化。
至于为啥能推导,可以参考这篇博文,

代码实现:

def get_sinusoid_table(self, seq_len, d_model):
	def get_angle(pos, i, d_model):
		return pos / np.power(10000, (2 * (i//2)) / d_model)
	
	sinusoid_table = np.zeros((seq_len, d_model))
	for pos in range(seq_len):
		for i in range(d_model):
			if i%2 == 0:
				sinusoid_table[pos, i] = np.sin(get_angle(pos, i, d_model))
			else:
				sinusoid_table[pos, i] = np.cos(get_angle(pos, i, d_model))

	return torch.FloatTensor(sinusoid_table)

因此embeding+pos_embeding的整体逻辑如下:

def __init__(self, vocab_size, seq_len, d_model=512, n_layers=6, n_heads=8, p_drop=0.1, d_ff=2048, pad_id=0):
	super(TransformerEncoder, self).__init__()
	self.pad_id = pad_id
	self.sinusoid_table = self.get_sinusoid_table(seq_len+1, d_model) # (seq_len+1, d_model)
	# layers
	self.embedding = nn.Embedding(vocab_size, d_model) #训练得出
	self.pos_embedding = nn.Embedding.from_pretrained(self.sinusoid_table, freeze=True) # from_pretrained 非训练得出, freeze=True表示不更新


def forward(self, inputs):
    # |inputs| : (batch_size, seq_len)
    positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).repeat(inputs.size(0), 1) + 1 #获取位置信息
    position_pad_mask = inputs.eq(self.pad_id)
    positions.masked_fill_(position_pad_mask, 0) #pad处不进行位置编码
    # |positions| : (batch_size, seq_len)

    outputs = self.embedding(inputs) + self.pos_embedding(positions)

参考资料
[1] https://zhuanlan.zhihu.com/p/106644634
[2] https://zhuanlan.zhihu.com/p/398457641
[3] https://zhuanlan.zhihu.com/p/48508221
[4] https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
[5] https://www.cnblogs.com/zingp/p/11696111.html#_label0

你可能感兴趣的:(深度学习,nlp,python,pytorch,transformer)