# Transformer模型中位置编码的实现

import numpy as np

"""
	FGGen模型中的PosEncoding
"""
def GetPosEncodingMatrix(max_len, d_emb):
	'''
	对应文章中的使用sine和consine function去做positional encodings
	PE(pos,2i)= sin(pos/10000^2i/dmodel)
	'''
	pos_enc = np.array([
		[pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
		if pos != 0 else np.zeros(d_emb)
			for pos in range(max_len)
			])
	pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2]) # dim 2i
	pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2]) # dim 2i+1
	return pos_enc


def GetPosEncodingMatrix_myself(max_len, d_emb):
	"""Straightforward reference implementation of the sinusoidal positional encoding.

	Builds the raw angle table ``pos / 10000 ** (2 * (j // 2) / d_emb)`` row
	by row, then takes the sine of even columns (dim 2i) and the cosine of
	odd columns (dim 2i+1).  Row 0 is skipped by the sin/cos step; since its
	angles are all ``0 / x == 0`` it stays all zeros.  Kept deliberately
	simple so it can serve as an independent cross-check for
	``GetPosEncodingMatrix``.

	Args:
		max_len: Number of positions (rows).  Values <= 0 give an empty
			``(0, d_emb)`` matrix instead of raising.
		d_emb: Embedding size (number of columns).

	Returns:
		A float64 ndarray of shape ``(max(max_len, 0), d_emb)``.
	"""
	n_rows = max(max_len, 0)
	angles = [
		[pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
		for pos in range(n_rows)
	]
	# reshape keeps the result 2-D even when there are no rows; a bare
	# np.array([]) is 1-D and would break the column slicing below.
	pos_enc = np.array(angles, dtype=np.float64).reshape(n_rows, d_emb)
	pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
	pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
	return pos_enc


if __name__ =="__main__":
	value1 = GetPosEncodingMatrix(60, 512)
	value2 = GetPosEncodingMatrix_myself(60, 512)
	equal_cnt = 0
	for i in range(value1.shape[0]):
		for j in range(value1.shape[1]):
			if value1[i][j] == value2[i][j]:
				equal_cnt += 1
			else:
				print("Error")
	print(equal_cnt)

# 你可能感兴趣的:(transformer,机器学习,python)