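# 用于基于 TransE 算法训练知识图谱中实体与关系的向量表示。
# (Trains TransE embeddings for knowledge-graph entities and relations.)
# 结果为两个文本文件:entityVector10.txt(实体向量)和 relationVector10.txt(关系向量)。
# (Output: two text files, entityVector10.txt and relationVector10.txt.)
# 数据集(FB15k)无法随代码上传,如有需要请联系作者。
# (The FB15k dataset is not bundled with the code; contact the author if you need it.)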
# -*- coding: utf-8 -*-
"""
@description: 增加了对代码的一些注解
"""
from random import uniform, sample
from numpy import *
from copy import deepcopy
class TransE:
    """Train TransE embeddings for knowledge-graph entities and relations.

    Triples are tuples laid out as ``(head, tail, relation)`` -- the layout
    ``update`` indexes them with (``[0]`` head, ``[1]`` tail, ``[2]`` relation).
    """

    def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
        """Store the training data and hyper-parameters.

        :param entityList: entity names (e.g. read from entity2id.txt). After
            ``initialize()`` this attribute becomes a dict ``{name: vector}``.
        :param relationList: relation names; likewise becomes ``{name: vector}``.
        :param tripleList: sequence of ``(head, tail, relation)`` tuples.
        :param margin: margin constant of the ranking loss.
        :param learingRate: SGD learning rate (the misspelt name is kept for
            callers that pass it by keyword).
        :param dim: embedding dimension.
        :param L1: use the L1 distance if True, otherwise the squared L2 distance.
        """
        self.margin = margin
        self.learingRate = learingRate
        self.dim = dim
        self.entityList = entityList
        self.relationList = relationList
        self.tripleList = tripleList
        self.loss = 0  # accumulated positive margin loss since the last report
        self.L1 = L1

    def _randomUnitVector(self):
        """Draw ``dim`` components with ``init`` and normalise them with ``norm``."""
        return norm([init(self.dim) for _ in range(self.dim)])

    def initialize(self):
        """Replace the name lists with dicts mapping each name to a random unit vector."""
        entityVectorList = {}
        for entity in self.entityList:
            entityVectorList[entity] = self._randomUnitVector()
        print("entityVector初始化完成,数量是%d" % len(entityVectorList))
        relationVectorList = {}
        for relation in self.relationList:
            relationVectorList[relation] = self._randomUnitVector()
        print("relationVectorList初始化完成,数量是%d" % len(relationVectorList))
        self.entityList = entityVectorList
        self.relationList = relationVectorList
        # random.sample() needs a sequence, so cache the entity names once.
        self._entityNames = list(entityVectorList)

    def transE(self, cI = 20, batchSize = 150,
               relationPath = r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt",
               entityPath = r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt"):
        """Run ``cI`` training cycles of mini-batch SGD.

        Every cycle samples up to ``batchSize`` triples, pairs each with one
        corrupted triple (duplicate pairs are dropped) and calls ``update``.
        Every 100 cycles the accumulated loss is printed, both vector tables are
        written to ``relationPath`` / ``entityPath``, and the loss is reset.
        """
        print("训练开始")
        for cycleIndex in range(cI):
            Tbatch = []  # [((h, t, r), (h', t', r)), ...] in sampling order
            seen = set()
            for triplet in self.getSample(batchSize):
                pair = (triplet, self.getCorruptedTriplet(triplet))
                if pair not in seen:
                    seen.add(pair)
                    Tbatch.append(pair)
            self.update(Tbatch)
            if cycleIndex % 100 == 0:
                print("第%d次循环" % cycleIndex)
                print(self.loss)
                self.writeRelationVector(relationPath)
                self.writeEntilyVector(entityPath)
                self.loss = 0

    def getSample(self, size):
        """Return up to ``size`` distinct random triples from ``tripleList``.

        The size is clamped to the number of triples, so a small training set
        no longer makes ``random.sample`` raise ValueError.
        """
        return sample(self.tripleList, min(size, len(self.tripleList)))

    def _getEntityNames(self):
        """Return a cached list of entity names, rebuilt if the entity count changes."""
        names = getattr(self, "_entityNames", None)
        if names is None or len(names) != len(self.entityList):
            names = self._entityNames = list(self.entityList)
        return names

    def getCorruptedTriplet(self, triplet):
        """Return ``triplet`` with its head or its tail (never both) replaced.

        The slot is chosen with probability 1/2 each, and the replacement is a
        random entity different from the one it replaces.

        :param triplet: ``(head, tail, relation)`` tuple.
        :return: corrupted ``(head', tail', relation)`` tuple.
        :raises ValueError: if no different entity exists to substitute.
        """
        names = self._getEntityNames()
        slot = 0 if uniform(-1, 1) < 0 else 1  # 0 -> corrupt head, 1 -> corrupt tail
        replaced = triplet[slot]
        # Without an alternative entity the rejection loop below would never end.
        if not names or (len(names) == 1 and names[0] == replaced):
            raise ValueError("no entity available to corrupt %r" % (triplet,))
        while True:
            # sample() on a list is O(1) for k=1; on dict keys it fails on 3.11+.
            replacement = sample(names, 1)[0]
            if replacement != replaced:
                break
        if slot == 0:
            return (replacement, triplet[1], triplet[2])
        return (triplet[0], replacement, triplet[2])

    def update(self, Tbatch):
        """Apply one margin-ranking SGD step for each ``(triplet, corrupted)`` pair.

        The loss ``max(0, margin + d(h, t, r) - d(h', t', r))`` and its gradient
        are evaluated with the vectors as they were before this batch
        (``self.entityList`` / ``self.relationList``). Pairs with positive loss
        add it to ``self.loss`` and move the involved vectors. Only the vectors
        touched here are re-normalised with ``norm``, not every entity as in
        the paper.

        :param Tbatch: list of ``((h, t, r), (h', t', r))`` pairs.
        """
        # Shallow copies suffice: each update stores a freshly computed array
        # rather than modifying one in place, so the pre-batch vectors stay intact.
        copyEntityList = dict(self.entityList)
        copyRelationList = dict(self.relationList)
        distance = distanceL1 if self.L1 else distanceL2
        for triplet, corruptedTriplet in Tbatch:
            head, tail, relation = triplet[0], triplet[1], triplet[2]
            corruptedHead, corruptedTail = corruptedTriplet[0], corruptedTriplet[1]
            headVector = self.entityList[head]
            tailVector = self.entityList[tail]
            relationVector = self.relationList[relation]
            corruptedHeadVector = self.entityList[corruptedHead]
            corruptedTailVector = self.entityList[corruptedTail]
            eg = (self.margin
                  + distance(headVector, tailVector, relationVector)
                  - distance(corruptedHeadVector, corruptedTailVector, relationVector))
            if eg <= 0:
                continue  # hinge inactive: no loss, no gradient
            self.loss += eg
            positiveResidual = tailVector - headVector - relationVector
            negativeResidual = corruptedTailVector - corruptedHeadVector - relationVector
            if self.L1:
                # Subgradient of |x| is sign(x) (zero counted as +1). The step is
                # scaled by the learning rate like the L2 branch.
                tempPositive = self.learingRate * where(positiveResidual >= 0, 1.0, -1.0)
                tempNegative = self.learingRate * where(negativeResidual >= 0, 1.0, -1.0)
            else:
                # Gradient of the squared L2 distance is 2 * (h + r - t).
                tempPositive = 2 * self.learingRate * positiveResidual
                tempNegative = 2 * self.learingRate * negativeResidual
            # Accumulate per entity: the corrupted triplet shares its head or tail
            # with the true one, and independent writes would let the last one
            # discard the other update for that shared entity.
            entityDeltas = {}
            for name, delta in ((head, tempPositive),
                                (tail, -tempPositive),
                                (corruptedHead, -tempNegative),
                                (corruptedTail, tempNegative)):
                entityDeltas[name] = entityDeltas.get(name, 0) + delta
            for name, delta in entityDeltas.items():
                copyEntityList[name] = norm(copyEntityList[name] + delta)
            copyRelationList[relation] = norm(copyRelationList[relation] + tempPositive - tempNegative)
        self.entityList = copyEntityList
        self.relationList = copyRelationList

    def _writeVectors(self, dir, vectors):
        """Write one ``name<TAB>[v0, v1, ...]`` line per entry of ``vectors``."""
        with open(dir, 'w', encoding='utf-8') as vectorFile:
            for name, vector in vectors.items():
                vectorFile.write(name + "\t" + str(vector.tolist()) + "\n")

    def writeEntilyVector(self, dir):
        """Write the entity vectors to the file at ``dir``."""
        print("写入实体")
        self._writeVectors(dir, self.entityList)

    def writeRelationVector(self, dir):
        """Write the relation vectors to the file at ``dir``."""
        print("写入关系")
        self._writeVectors(dir, self.relationList)
def init(dim):
    """Return one random embedding component drawn from U(-6/sqrt(dim), 6/sqrt(dim))."""
    bound = 6 / dim ** 0.5
    return uniform(-bound, bound)
def distanceL1(h, t, r):
    """TransE L1 dissimilarity: sum of |h + r - t| over all components.

    :param h: head embedding
    :param t: tail embedding
    :param r: relation embedding
    :return: sum of absolute residuals
    """
    return fabs(h + r - t).sum()
def distanceL2(h, t, r):
    """Squared L2 dissimilarity: sum of (h + r - t)**2 (no square root taken).

    :param h: head embedding
    :param t: tail embedding
    :param r: relation embedding
    :return: sum of squared residuals
    """
    residual = h + r - t
    squared = residual * residual
    return squared.sum()
def norm(list):
    """Return the vector scaled to unit Euclidean (L2) length as a new float array.

    The input is not modified. Integer input is converted to float first, so
    the division is never truncated. A zero vector is returned unchanged (as
    floats) instead of producing NaNs.

    :param list: vector as a sequence or numpy array (name kept for callers)
    :return: float numpy array ``v / ||v||``
    """
    vector = array(list, dtype=float)
    length = linalg.norm(vector)
    if length == 0:
        return vector
    return vector / length
def openDetailsAndId(dir,sp="\t"):
    """Read the name column of a ``name<sp>id`` mapping file.

    Lines look like ``/m/06rf7<TAB>0`` (entity2id.txt has 14951 of them). Only
    the first field of each line is kept, in file order; the id column is
    ignored. Blank or whitespace-only lines are skipped, so a trailing newline
    no longer yields an empty name.

    :param dir: path to the mapping file (read as UTF-8)
    :param sp: field separator
    :return: (number of names read, list of names)
    """
    names = []
    with open(dir, encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            names.append(line.split(sp)[0])
    return len(names), names
def openTrain(dir,sp="\t"):
    """Read training triples, one per line.

    Each line with at least three ``sp``-separated fields becomes a tuple of
    all its fields. For FB15k's train.txt that is (head, tail, relation), e.g.
    ``/m/027rn  /m/06cx9  /location/country/form_of_government``. Shorter lines
    (including blank ones) are skipped. The file is read as UTF-8 rather than
    with the platform-dependent default encoding.

    :param dir: path to the triple file
    :param sp: field separator
    :return: (number of triples read, list of triple tuples)
    """
    triples = []
    with open(dir, encoding="utf-8") as file:
        for line in file:
            fields = line.strip().split(sp)
            if len(fields) < 3:
                continue
            triples.append(tuple(fields))
    return len(triples), triples
if __name__ == '__main__':
    # Load entity names, relation names and training triples (the counts
    # returned alongside each list are not used below).
    entityIdNum, entityList = openDetailsAndId(
        r"F:\pycharm的项目\transe\data\FB15k\entity2id.txt")
    relationIdNum, relationList = openDetailsAndId(
        r"F:\pycharm的项目\transe\data\FB15k\relation2id.txt")
    tripleNum, tripleList = openTrain(
        r"F:\pycharm的项目\transe\data\FB15k\train.txt")
    print("打开TransE")
    # dim may be changed here; the "10" in the output file names below appears
    # to mirror dim = 10, so keep them in sync.
    model = TransE(entityList, relationList, tripleList, margin=1, dim=10)
    print("TranE初始化")
    model.initialize()
    model.transE(15000)
    model.writeRelationVector(r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt")
    model.writeEntilyVector(r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt")