1. Dataset preparation
import openke
from openke.data import TrainDataLoader, TestDataLoader
train_dataloader = TrainDataLoader(
    in_path = "./benchmarks/FB15K237_tiny/",  # expects train2id.txt, entity2id.txt, relation2id.txt here
    nbatches = 100,            # batches per epoch
    threads = 8,               # worker threads for the underlying sampler
    sampling_mode = 'normal',  # standard negative sampling
    bern_flag = 1,             # Bernoulli corruption probabilities (TransH-style)
    neg_ent = 25,              # negative entities sampled per positive triple
    neg_rel = 0                # no relation corruption
)
test_dataloader = TestDataLoader("./benchmarks/FB15K237_tiny/", "link")
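It helps to peek at what the sampler actually yields before building the model. The sketch below is a quick sanity check, not part of the pipeline; it assumes the batch layout that TransE.forward() consumes in section 2 (numpy arrays under 'batch_h', 'batch_t', 'batch_r', 'batch_y', plus a 'mode' string).
import numpy as np
# Pull one batch from the sampler and report the shape of each field.
batch = next(iter(train_dataloader))
for key, value in batch.items():
    if isinstance(value, np.ndarray):
        print(key, value.shape, value.dtype)
    else:
        print(key, value)  # 'mode' is a plain string such as 'normal'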
2. Model construction
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransE(nn.Module):
def __init__(self,ent_tot,rel_tot,dim = 100,p_norm = 1,norm_flag = True,margin = None,epsilon = None):
super(TransE,self).__init__()
self.ent_tot = ent_tot
self.rel_tot = rel_tot
self.dim = dim
self.margin = margin
self.epsilon = epsilon
self.norm_flag = norm_flag
self.p_norm = p_norm
self.ent_embeddings = nn.Embedding(self.ent_tot,self.dim)
self.rel_embeddings = nn.Embedding(self.rel_tot,self.dim)
        if margin is None or epsilon is None:
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
self.embedding_range = nn.Parameter(
torch.Tensor([(self.margin + self.epsilon) / self.dim]),requires_grad = False
)
nn.init.uniform_(
tensor = self.ent_embeddings.weight.data,
a = -self.embedding_range.item(),
b = self.embedding_range.item()
)
nn.init.uniform_(
tensor = self.rel_embeddings.weight.data,
a = -self.embedding_range.item(),
b = self.embedding_range.item()
)
        if margin is not None:
self.margin = nn.Parameter(torch.Tensor([margin]))
self.margin.requires_grad = False
self.margin_flag = True
else:
self.margin_flag = False
    def _calc(self, h, t, r, mode):
        # Optionally L2-normalize embeddings before scoring.
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)
        # In 'head_batch' / 'tail_batch' mode, many candidate entities are
        # scored against each triple, so reshape for broadcasting.
        if mode != 'normal':
            h = h.view(-1, r.shape[0], h.shape[-1])
            t = t.view(-1, r.shape[0], t.shape[-1])
            r = r.view(-1, r.shape[0], r.shape[-1])
        if mode == 'head_batch':
            score = h + (r - t)
        else:
            score = (h + r) - t
        # TransE score: ||h + r - t||_p ; smaller means more plausible.
        score = torch.norm(score, self.p_norm, -1).flatten()
        return score
def forward(self,data):
batch_h = data['batch_h']
batch_t = data['batch_t']
batch_r = data['batch_r']
mode = data['mode']
h = self.ent_embeddings(batch_h)
t = self.ent_embeddings(batch_t)
r = self.rel_embeddings(batch_r)
score = self._calc(h,t,r,mode)
if self.margin_flag:
return self.margin - score
else:
return score
def regularization(self,data):
batch_h = data['batch_h']
batch_t = data['batch_t']
batch_r = data['batch_r']
h = self.ent_embeddings(batch_h)
t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2)) / 3
return regul
    def predict(self, data):
        score = self.forward(data)
        # forward() returns margin - score when margin_flag is set,
        # so undo that shift here to expose the raw distance.
        if self.margin_flag:
            score = self.margin - score
        return score.cpu().data.numpy()
def save_checkpoint(self, path):
torch.save(self.state_dict(), path)
'''
# Smoke test: push one sampled batch through an untrained model.
model = TransE(train_dataloader.get_ent_tot(), train_dataloader.get_rel_tot())
example = list(train_dataloader)[0]
for key in example:
    if not isinstance(example[key], str):
        example[key] = torch.LongTensor(example[key])
model(example)
'''
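The idea behind the scoring function above is the translation assumption h + r ≈ t: a triple is scored by the p-norm distance ||h + r - t||_p, so smaller scores mean more plausible triples. A minimal standalone illustration with toy vectors (not tied to the loader):
import torch
# h + r lands (almost exactly) on t_true, so its distance is near 0; t_false is far away.
h = torch.tensor([0.1, 0.2, 0.3])
r = torch.tensor([0.4, 0.1, 0.0])
t_true = torch.tensor([0.5, 0.3, 0.3])
t_false = torch.tensor([0.9, 0.9, 0.9])
score = lambda h, r, t: torch.norm(h + r - t, p = 1, dim = -1)
print(score(h, r, t_true).item())   # ~0.0
print(score(h, r, t_false).item())  # ~1.6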
3. Loss function construction
import torch
import torch.nn as nn
import torch.nn.functional as F
class MarginLoss(nn.Module):
    def __init__(self, adv_temperature = None, margin = 6.0):
        super(MarginLoss, self).__init__()
        self.margin = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False
        if adv_temperature is not None:
            # Temperature for self-adversarial weighting of negatives (RotatE-style).
            self.adv_temperature = nn.Parameter(torch.Tensor([adv_temperature]))
            self.adv_temperature.requires_grad = False
            self.adv_flag = True
        else:
            self.adv_flag = False
    def get_weights(self, n_score):
        # Lower (better) negative scores get larger weights: harder negatives count more.
        return F.softmax(-n_score * self.adv_temperature, dim = -1).detach()
    def forward(self, p_score, n_score):
        if self.adv_flag:
            return (self.get_weights(n_score) * torch.max(p_score - n_score, -self.margin)).sum(dim = -1).mean() + self.margin
        else:
            return (torch.max(p_score - n_score, -self.margin)).mean() + self.margin
def predict(self,p_score,n_score):
score = self.forward(p_score,n_score)
return score.cpu().data.numpy()
'''
# Smoke test (uses NegativeSampling, defined in section 4):
model = TransE(train_dataloader.get_ent_tot(), train_dataloader.get_rel_tot())
example = list(train_dataloader)[0]
for key in example:
    if not isinstance(example[key], str):
        example[key] = torch.LongTensor(example[key])
loss = MarginLoss()
negativeSampling = NegativeSampling(model, loss, batch_size = train_dataloader.batch_size)
negativeSampling(example)
'''
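A quick numeric check of the margin loss, using the distance convention above (lower score = more plausible):
import torch
loss_fn = MarginLoss(margin = 6.0)
p_score = torch.tensor([[1.0]])        # one positive triple, small distance
n_score = torch.tensor([[5.0, 9.0]])   # two negatives, larger distances
# max(p - n, -margin) + margin == max(margin + p - n, 0), averaged:
# (max(6 + 1 - 5, 0) + max(6 + 1 - 9, 0)) / 2 = (2 + 0) / 2 = 1.0
print(loss_fn(p_score, n_score).item())  # 1.0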
4. Negative sampling construction
class NegativeSampling(nn.Module):
def __init__(self,model = None,loss = None,batch_size = 256,regul_rate = 0.0,l3_regul_rate = 0.0):
super(NegativeSampling,self).__init__()
self.model = model
self.loss = loss
self.batch_size = batch_size
self.regul_rate = regul_rate
self.l3_regul_rate = l3_regul_rate
    def _get_positive_score(self, score):
        # The sampler lays scores out as [positives ; negatives]:
        # the first batch_size entries belong to the positive triples.
        positive_score = score[:self.batch_size]
        positive_score = positive_score.view(-1, self.batch_size).permute(1, 0)
        return positive_score
    def _get_negative_score(self, score):
        # The remaining neg_ent * batch_size entries are the negatives,
        # reshaped to (batch_size, neg_ent) for broadcasting in the loss.
        negative_score = score[self.batch_size:]
        negative_score = negative_score.view(-1, self.batch_size).permute(1, 0)
        return negative_score
def forward(self,data):
score = self.model(data)
p_score = self._get_positive_score(score)
n_score = self._get_negative_score(score)
loss_res = self.loss(p_score,n_score)
if self.regul_rate != 0:
loss_res += self.regul_rate * self.model.regularization(data)
        if self.l3_regul_rate != 0:
            # Only used by models that implement l3_regularization(); the TransE above does not.
            loss_res += self.l3_regul_rate * self.model.l3_regularization()
return loss_res
def save_checkpoint(self, path):
torch.save(self.state_dict(), path)
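With the loader's neg_ent = 25, the wrapped model sees batch_size positives followed by 25 * batch_size negatives, so after the reshapes above p_score has shape (batch_size, 1) and n_score has shape (batch_size, 25). A sketch of that shape check (assuming the section-1 loader settings):
import torch
wrapper = NegativeSampling(
    model = TransE(train_dataloader.get_ent_tot(), train_dataloader.get_rel_tot()),
    loss = MarginLoss(),
    batch_size = train_dataloader.batch_size
)
batch = list(train_dataloader)[0]
for key in batch:
    if not isinstance(batch[key], str):
        batch[key] = torch.LongTensor(batch[key])
score = wrapper.model(batch)
print(wrapper._get_positive_score(score).shape)  # torch.Size([batch_size, 1])
print(wrapper._get_negative_score(score).shape)  # torch.Size([batch_size, 25])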
5. Model training
import torch
from torch.autograd import Variable
import torch.optim as optim
import os
from tqdm import tqdm
class Trainer(object):
def __init__(self,model = None,
data_loader = None,
train_times = 1000,
alpha = 0.5,
use_gpu = True,
opt_method = 'sgd',
save_steps = None,
checkpoint_dir = None
):
        self.work_threads = 8
        self.train_times = train_times
        self.opt_method = opt_method
        self.optimizer = None
        self.lr_decay = 0          # consumed by the Adagrad branch in run()
        self.weight_decay = 0
        self.alpha = alpha
self.model = model
self.data_loader = data_loader
self.use_gpu = use_gpu
self.save_steps = save_steps
self.checkpoint_dir = checkpoint_dir
def to_var(self,x,use_gpu):
if use_gpu:
return Variable(torch.from_numpy(x).cuda())
else:
return Variable(torch.from_numpy(x))
def train_one_step(self,data):
self.optimizer.zero_grad()
loss = self.model({
'batch_h':self.to_var(data['batch_h'],self.use_gpu),
'batch_t':self.to_var(data['batch_t'],self.use_gpu),
'batch_r':self.to_var(data['batch_r'],self.use_gpu),
'batch_y':self.to_var(data['batch_y'],self.use_gpu),
'mode':data['mode']
})
loss.backward()
self.optimizer.step()
return loss.item()
def run(self):
if self.use_gpu:
self.model.cuda()
        if self.optimizer is not None:
pass
elif self.opt_method == 'Adagrad' or self.opt_method == 'adagrad':
self.optimizer = optim.Adagrad(
self.model.parameters(),
lr = self.alpha,
lr_decay = self.lr_decay,
weight_decay = self.weight_decay
)
elif self.opt_method == 'Adadelta' or self.opt_method == 'adadelta':
self.optimizer = optim.Adadelta(
self.model.parameters(),
lr = self.alpha,
weight_decay = self.weight_decay
)
elif self.opt_method == 'Adam' or self.opt_method == 'adam':
self.optimizer = optim.Adam(
self.model.parameters(),
lr = self.alpha,
weight_decay = self.weight_decay
)
else:
self.optimizer = optim.SGD(
self.model.parameters(),
lr = self.alpha,
weight_decay = self.weight_decay
)
print('Finish initializing...')
training_range = tqdm(range(self.train_times))
for epoch in training_range:
res = 0.0
for data in self.data_loader:
loss = self.train_one_step(data)
res += loss
training_range.set_description('Epoch %d | loss: %f' % (epoch,res))
            if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0:
                print('Epoch %d has finished, saving...' % epoch)
                self.model.save_checkpoint(self.checkpoint_dir + '-' + str(epoch) + '.ckpt')
transe = TransE(train_dataloader.get_ent_tot(),train_dataloader.get_rel_tot())
loss = MarginLoss()
model = NegativeSampling(transe,loss,batch_size=train_dataloader.batch_size)
trainer = Trainer(model = model, data_loader = train_dataloader)
trainer.run()
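The test_dataloader built in section 1 has not been used yet. If the stock OpenKE Tester is available (this sketch assumes the standard openke.config.Tester API and a hypothetical checkpoint path; adjust for your setup), link prediction on the test split takes a few more lines:
from openke.config import Tester
# Save the trained embeddings, then report MR / MRR / Hits@N on the test split.
transe.save_checkpoint('./checkpoint/transe.ckpt')  # hypothetical path
tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True)
tester.run_link_prediction(type_constrain = False)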