transe 简单代码实现

用于对知识图谱中的实体、关系基于TransE算法训练获取向量
结果为:两个文本文件,即entityVector.txt和relationVector.txt
但是数据集没办法上传,如果有需要联系我哦。

# -*- coding: utf-8 -*-
"""
@description: 增加了对代码的一些注解
"""
from random import uniform, sample
from numpy import *
from copy import deepcopy

class TransE:
    def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
        """
        目标函数的常数——margin
        学习率——learningRate
        向量维度——dim
        实体列表——entityList(读取文本文件,实体+id)
        关系列表——relationList(读取文本文件,关系 + id)
        三元关系列表——tripleList(读取文本文件,实体 + 实体 + 关系)
        损失值——loss
        距离公式——L1        
        """
        self.margin = margin
        self.learingRate = learingRate
        self.dim = dim#向量维度
        self.entityList = entityList#一开始,entityList是entity的list;
        # 初始化后,变为字典,key是entity,values是其向量(使用narray)。
        self.relationList = relationList#理由同上
        self.tripleList = tripleList#理由同上
        self.loss = 0
        self.L1 = L1

    def initialize(self):
        '''
        初始化向量
        '''
        entityVectorList = {}
        relationVectorList = {}
        for entity in self.entityList:  # 对entityList进行遍历
            n = 0
            entityVector = []
            while n < self.dim:
                ram = init(self.dim)  #调用init函数,返回一个实数类似1.3266
                entityVector.append(ram)   # 将ram 添加到实体向量中
                n += 1
            entityVector = norm(entityVector)  #调用norm函数,单位化
            entityVectorList[entity] = entityVector
        print("entityVector初始化完成,数量是%d"%len(entityVectorList))
        for relation in self. relationList:
            n = 0
            relationVector = []
            while n < self.dim:   # 循环dim次
                ram = init(self.dim)   #调用init函数,返回一个实数类似1.3266
                relationVector.append(ram)   # 将ram 添加到关系向量中
                n += 1
            relationVector = norm(relationVector)  #归一化
            relationVectorList[relation] = relationVector
        print("relationVectorList初始化完成,数量是%d"%len(relationVectorList))
        self.entityList = entityVectorList
        self.relationList = relationVectorList

    def transE(self, cI = 20):
        print("训练开始")
        for cycleIndex in range(cI):
            Sbatch = self.getSample(150)    #随机选取150个元素
            Tbatch = []     # 初始空 元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))}
            for sbatch in Sbatch:
                tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch))   #{((h,r,t),(h',r,t'))}
                if(tripletWithCorruptedTriplet not in Tbatch):
                    Tbatch.append(tripletWithCorruptedTriplet)
            self.update(Tbatch)
            if cycleIndex % 100 == 0:
                print("第%d次循环"%cycleIndex)
                print(self.loss)

                self.writeRelationVector(r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt")
                self.writeEntilyVector(r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt")

                # self.writeRelationVector("f:\\relationVector.txt")
                # self.writeEntilyVector("f:\\entityVector.txt")
                self.loss = 0

    def getSample(self, size):
        #—随机选取部分三元关系,Sbatch
        # sample(序列a,n)
        # 功能:从序列a中随机抽取n个元素,并将n个元素生以list形式返回。
        return sample(self.tripleList, size)

    def getCorruptedTriplet(self, triplet):
        '''
        training triplets with either the head or tail replaced by a random entity (but not both at the same time)
         #随机替换三元组的实体,h、t中任意一个被替换,但不同时替换。
        :param triplet:
        :return corruptedTriplet:
        '''
        i = uniform(-1, 1)  #uniform(a, b)#随机生成a,b之间的数,左闭右开。
        if i < 0:#小于0,打坏三元组的第一项
            while True:
                entityTemp = sample(self.entityList.keys(), 1)[0] #从entityList.key()中sample一个元素,以列表行驶返回第一个元素
                if entityTemp != triplet[0]:
                    break
            corruptedTriplet = (entityTemp, triplet[1], triplet[2])
        else:#大于等于0,打坏三元组的第二项
            while True:
                entityTemp = sample(self.entityList.keys(), 1)[0]
                if entityTemp != triplet[1]:
                    break
            corruptedTriplet = (triplet[0], entityTemp, triplet[2])
        return corruptedTriplet

    def update(self, Tbatch):
        copyEntityList = deepcopy(self.entityList) # 深拷贝  作为一个独立的存在 不会改变原来的值
        copyRelationList = deepcopy(self.relationList)
        
        for tripletWithCorruptedTriplet in Tbatch:
            # [((h,t,r),(h',t',r)),(())]
            headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]
            #tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
            tailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]]
            relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]]
            headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]]
            tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]]
            
            headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]
            #tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
            tailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]]
            relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]]
            headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]]
            tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]]
            
            if self.L1:
                # 计算正常情况下的误差
                distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
                # 计算随机情况下的误差
                distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch ,  relationVectorBeforeBatch)
            else:
                distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
                distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch ,  relationVectorBeforeBatch)
            # margin loss = max(0, margin + pos - neg)
            eg = self.margin + distTriplet - distCorruptedTriplet
            if eg > 0: #[function]+ 是一个取正值的函数
                self.loss += eg
                if self.L1:
                    # tempos = 2 * lr * (t - h - r)
                    tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
                    tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)
                    tempPositiveL1 = []
                    tempNegtativeL1 = []
                    for i in range(self.dim):#不知道有没有pythonic的写法(比如列表推倒或者numpy的函数)?
                        if tempPositive[i] >= 0:
                            tempPositiveL1.append(1)
                        else:
                            tempPositiveL1.append(-1)
                        if tempNegtative[i] >= 0:
                            tempNegtativeL1.append(1)
                        else:
                            tempNegtativeL1.append(-1)
                    tempPositive = array(tempPositiveL1)  
                    tempNegtative = array(tempNegtativeL1)

                else:
                    tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
                    tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)
    
                headEntityVector = headEntityVector + tempPositive
                tailEntityVector = tailEntityVector - tempPositive
                relationVector = relationVector + tempPositive - tempNegtative
                headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative
                tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative

                #只归一化这几个刚更新的向量,而不是按原论文那些一口气全更新了
                copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector)
                copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector)
                copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector)
                copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet)
                copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet)
                
        self.entityList = copyEntityList
        self.relationList = copyRelationList
        
    def writeEntilyVector(self, dir):
        print("写入实体")
        entityVectorFile = open(dir, 'w')
        for entity in self.entityList.keys():
            entityVectorFile.write(entity+"\t")
            entityVectorFile.write(str(self.entityList[entity].tolist()))
            entityVectorFile.write("\n")
        entityVectorFile.close()

    def writeRelationVector(self, dir):
        print("写入关系")
        relationVectorFile = open(dir, 'w')
        for relation in self.relationList.keys():
            relationVectorFile.write(relation + "\t")
            relationVectorFile.write(str(self.relationList[relation].tolist()))
            relationVectorFile.write("\n")
        relationVectorFile.close()

def init(dim):
    # uniform() 方法将随机生成下一个实数,它在[x, y]范围内。
    return uniform(-6/(dim**0.5), 6/(dim**0.5))

def distanceL1(h, t ,r):
    """
    trans e
    :param h:  head embendding
    :param t:   tail 
    :param r:  relation
    :return: 返回绝对误差和
    """
    s = h + r - t
    sum = fabs(s).sum()  # fabs() 方法返回数字的绝对值,如math.fabs(-10) 返回10.0。
    return sum

def distanceL2(h, t, r):
    """
    trans r
    :param h: 
    :param t: 
    :param r: 
    :return: 返回误差平方和
    """
    s = h + r - t
    sum = (s*s).sum()
    return sum
 
def norm(list):
    '''
    归一化
    :param 向量
    :return: 返回元素除以平方和后的数组
    '''
    var = linalg.norm(list)
    #x_norm=np.linalg.norm(x, ord=None, axis=None, keepdims=False)
    # 求范数  默认情况下,是求整体的矩阵元素平方和,再开根号。
    i = 0
    while i < len(list):
        list[i] = list[i]/var   #list中每一元素/var
        i += 1
    return array(list)

def openDetailsAndId(dir,sp="\t"):
    """
    :param dir: 路径  文件内容 皆为 /m/06rf7  0  其中entity 14951个
    :param sp: 
    :return: 返回idNum,名字列表
    """
    idNum = 0
    list = []
    with open(dir) as file:
        lines = file.readlines()  # 读取文件所有行
        for line in lines:    # 一行一行
            DetailsAndId = line.strip().split(sp)
            #strip(str)只能删除开头或是结尾的字符或是字符串
            # split(str) 按str分割 返回的是一个列表
            list.append(DetailsAndId[0])
            # 将名字添加到list
            idNum += 1
    return idNum, list

def openTrain(dir,sp="\t"):
    """
    /m/027rn   /m/06cx9   /location/country/form_of_government
    :param dir: 
    :param sp: 
    :return: 返回num 和关系总列表
    """
    num = 0
    list = []
    with open(dir) as file:
        lines = file.readlines()
        for line in lines:
            triple = line.strip().split(sp)
            if(len(triple)<3):  # 如果triple内没有三个元素,则结束本次循环
                continue
            list.append(tuple(triple))  # 将返回的三元列表 添加到list列表中
            num += 1
    return num, list

if __name__ == '__main__':
    dirEntity = r"F:\pycharm的项目\transe\data\FB15k\entity2id.txt"

    entityIdNum, entityList = openDetailsAndId(dirEntity)
    dirRelation = r"F:\pycharm的项目\transe\data\FB15k\relation2id.txt"

    relationIdNum, relationList = openDetailsAndId(dirRelation)
    dirTrain = r"F:\pycharm的项目\transe\data\FB15k\train.txt"

    tripleNum, tripleList = openTrain(dirTrain)
    print("打开TransE")
    # 在这里调用transE函数时,dim可以重新传参,
    transE = TransE(entityList,relationList,tripleList, margin=1, dim = 10)
    print("TranE初始化")
    transE.initialize()
    transE.transE(15000)

    transE.writeRelationVector(r"F:\pycharm的项目\transe\data\FB15k\relationVector10.txt")
    transE.writeEntilyVector(r"F:\pycharm的项目\transe\data\FB15k\entityVector10.txt")

你可能感兴趣的:(python)