Sample a fixed number of hard negative items per user, then apply SSM (sampled softmax loss)
def batch_softmax_loss_neg(self, user_idx, rec_user_emb, pos_idx, item_emb):
    # Score every item for each batch user on the cosine scale.
    user_emb = F.normalize(rec_user_emb[user_idx], dim=1)
    item_emb = F.normalize(item_emb, dim=1)
    product_scores = torch.matmul(user_emb, item_emb.transpose(0, 1))
    # Positives scored on the same cosine scale as the negatives.
    pos_score = (user_emb * item_emb[pos_idx]).sum(dim=-1)
    pos_score = torch.exp(pos_score / self.temp2)
    # Mask out items the user already interacted with in training.
    train_mask = self.data.ui_adj[user_idx, self.data.user_num:].toarray()
    train_mask = torch.tensor(train_mask).cuda()
    product_scores = product_scores * (1 - train_mask)
    # Keep ranks 401-500 of the top-500 as "hard but not too hard" negatives,
    # which lowers the risk of hitting false negatives at the very top.
    neg_score, indices = product_scores.topk(500, dim=1, largest=True, sorted=True)
    neg_score = torch.exp(neg_score[:, 400:] / self.temp2).sum(dim=-1)
    loss = -torch.log(pos_score / (pos_score + neg_score + 1e-5))
    return torch.mean(loss)
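The intended call site appears (commented out) in train() below; per pairwise batch:

# user_idx / pos_idx come from next_batch_pairwise; embeddings from model(True).
rec_loss = self.batch_softmax_loss_neg(user_idx, rec_user_emb, pos_idx, rec_item_emb)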
def batch_softmax_loss_neg(user_emb, pos_item_emb, neg_item_emb, temperature):
    # Variant that scores pre-sampled negatives.
    # Shapes: user_emb (n, d), pos_item_emb (n, d), neg_item_emb (n, m, d).
    user_emb = F.normalize(user_emb, dim=1)
    pos_item_emb = F.normalize(pos_item_emb, dim=1)
    neg_item_emb = F.normalize(neg_item_emb, dim=-1)  # last dim is the embedding axis of (n, m, d)
    pos_score = (user_emb * pos_item_emb).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)
    # Broadcast each user against its m negatives: (n, 1, d) * (n, m, d) -> (n, m).
    neg_score = (user_emb.unsqueeze(1) * neg_item_emb).sum(dim=-1)
    neg_score = torch.exp(neg_score / temperature).sum(dim=-1)
    loss = -torch.log(pos_score / (pos_score + neg_score + 1e-5))
    return torch.mean(loss)
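Hypothetical usage mirroring the commented-out lines in train() below, where the batch already carries pre-sampled negative indices rec_neg_idx of shape (n, m):

rec_neg_idx = torch.tensor(rec_neg_idx, dtype=torch.int64)
rec_neg_item_emb = rec_item_emb[rec_neg_idx]  # (n, m, d)
rec_loss = batch_softmax_loss_neg(user_emb, pos_item_emb, rec_neg_item_emb, self.temp2)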
Uniformity loss (an incorrect example)
# Incorrect on two counts: the uniformity loss is defined over all pairs of
# points, not just the aligned (user, item) rows, and the Gaussian potential
# uses the squared L2 distance, not the plain norm.
# def cal_uniform_loss(user_emb, item_emb):
#     user_emb, item_emb = F.normalize(user_emb, dim=1), F.normalize(item_emb, dim=1)
#     distance = user_emb - item_emb  # n*d
#     gaussian_potential = torch.exp(-2 * torch.norm(distance, p=2, dim=1))
#     E_gaussian_potential = gaussian_potential.mean()
#     return torch.log(E_gaussian_potential)
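For reference, a minimal sketch of the standard uniformity loss from Wang & Isola (2020), computed over all pairwise squared distances within a single embedding set (the name cal_uniform_loss_fixed is illustrative; t = 2 is the paper's default):

def cal_uniform_loss_fixed(emb, t=2):
    # log E[exp(-t * ||x - y||_2^2)] over all pairs of normalized points.
    emb = F.normalize(emb, dim=1)
    sq_dist = torch.pdist(emb, p=2).pow(2)
    return sq_dist.mul(-t).exp().mean().log()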
DNS (dynamic negative sampling)
def DNSbpr(user_emb, pos_item_emb, neg_item_emb):
    # BPR with dynamic negative sampling: of the m sampled negatives per user,
    # train against the highest-scoring (hardest) one.
    pos_score = torch.mul(user_emb, pos_item_emb).sum(dim=1)
    # (n, 1, d) * (n, m, d) -> (n, m) candidate negative scores.
    ttl_score = (user_emb.unsqueeze(1) * neg_item_emb).sum(dim=-1)
    neg_score = torch.max(ttl_score, dim=1).values
    loss = -torch.log(1e-5 + torch.sigmoid(pos_score - neg_score))
    return torch.mean(loss)
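A hypothetical driver, assuming a pairwise batch (user_idx, pos_idx) and embedding tables user_emb / item_emb; the uniform sampler and m are illustrative, not part of the snippet above:

# Draw m random candidate negatives per user; DNSbpr keeps the hardest.
m = 16
neg_idx = torch.randint(0, item_emb.shape[0], (len(user_idx), m))
loss = DNSbpr(user_emb[user_idx], item_emb[pos_idx], item_emb[neg_idx])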
InfoNCE with margin
def InfoNCE_margin(view1, view2, temperature, margin, b_cos=True):
    if b_cos:
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = (view1 * view2).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)
    # Add the margin to the diagonal (positive) logits of the denominator only,
    # raising the bar the positive score has to clear.
    margin = margin * torch.eye(view1.shape[0], device=view1.device)
    ttl_score = torch.matmul(view1, view2.transpose(0, 1))
    ttl_score += margin
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = -torch.log(pos_score / ttl_score + 1e-5)
    return torch.mean(cl_loss)
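The margin above inflates only the denominator's diagonal term. The more common additive-margin form instead subtracts the margin from the positive logit in both numerator and denominator; a sketch for comparison (InfoNCE_sub_margin is a hypothetical name, not from the source):

def InfoNCE_sub_margin(view1, view2, temperature, margin):
    # Additive-margin softmax: penalize the positive logit itself.
    view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    logits = torch.matmul(view1, view2.transpose(0, 1)) / temperature
    logits -= (margin / temperature) * torch.eye(view1.shape[0], device=view1.device)
    labels = torch.arange(view1.shape[0], device=view1.device)
    # Cross-entropy with diagonal labels is exactly in-batch InfoNCE.
    return F.cross_entropy(logits, labels)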
def InfoNCE_tau(view1, view2, temperature):
    # Standard in-batch InfoNCE with a fixed temperature tau.
    view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = (view1 * view2).sum(dim=-1)
    pos_score = torch.exp(pos_score / temperature)
    ttl_score = torch.matmul(view1, view2.transpose(0, 1))
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = -torch.log(pos_score / ttl_score + 1e-5)
    return torch.mean(cl_loss)
def batch_bpr_loss(user_emb, item_emb):
    # BPR against the average in-batch item score; note the mean also
    # includes each user's own positive item.
    pos_score = torch.mul(user_emb, item_emb).sum(dim=1)
    neg_score = torch.matmul(user_emb, item_emb.transpose(0, 1)).mean(dim=1)
    loss = -torch.log(1e-5 + torch.sigmoid(pos_score - neg_score))
    return torch.mean(loss)
def Dis_softmax(view1, view2, temperature, b_cos=True):
    # Softmax over L2 *distances* rather than similarities. Minimizing
    # log(pos/ttl) shrinks the paired distance relative to all other pairwise
    # distances, so the loss is deliberately not negated.
    if b_cos:
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    N, M = view1.shape
    pos_score = (view1 - view2).norm(p=2, dim=1)
    pos_score = torch.exp(pos_score / temperature)
    # All pairwise distances: entry (i, j) is ||view1[i] - view2[j]||.
    view1 = view1.unsqueeze(1).expand(N, N, M)
    view2 = view2.unsqueeze(0).expand(N, N, M)
    ttl_score = (view1 - view2).norm(p=2, dim=-1)
    ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
    cl_loss = torch.log(pos_score / ttl_score + 1e-5)
    return torch.mean(cl_loss)
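A hypothetical call, pairing each user's unperturbed and perturbed embeddings as the two views (names follow the forward pass below; the temperature 0.2 is illustrative):

dis_loss = Dis_softmax(rec_user_emb[user_idx], cl_user_emb[user_idx], 0.2)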
LightGCN + contrastive learning
def forward(self, perturbed=False):
    ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)
    all_embeddings = []
    all_embeddings_cl = ego_embeddings
    for k in range(self.n_layers):
        ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)
        if perturbed:
            # SimGCL-style perturbation: push each embedding by eps in a
            # random direction that keeps the sign of its coordinates.
            random_noise = torch.rand_like(ego_embeddings)
            ego_embeddings += torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.eps
        all_embeddings.append(ego_embeddings)
        if k == self.layer_cl - 1:
            # Contrastive view: offset by the normalized difference between the
            # layer-1 and layer-0 outputs (assumes layer_cl >= 2, otherwise
            # all_embeddings[1] does not exist yet).
            all_embeddings_cl += F.normalize(all_embeddings[1] - all_embeddings[0], dim=-1) * self.eps
    # LightGCN readout: mean over the outputs of all propagation layers.
    final_embeddings = torch.stack(all_embeddings, dim=1)
    final_embeddings = torch.mean(final_embeddings, dim=1)
    user_all_embeddings, item_all_embeddings = torch.split(final_embeddings, [self.data.user_num, self.data.item_num])
    user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embeddings_cl, [self.data.user_num, self.data.item_num])
    if perturbed:
        return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
    return user_all_embeddings, item_all_embeddings
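train() below calls self.cal_cl_loss, which is not shown in this section. A minimal sketch of the usual SimGCL-style form, contrasting the main and perturbed views with the InfoNCE_tau defined above (the exact signature and the 0.2 temperature are assumptions):

def cal_cl_loss(self, idx, user_view1, user_view2, item_view1, item_view2):
    # Assumed form: InfoNCE between the two views, restricted to the batch.
    u_idx = torch.unique(torch.tensor(idx[0], dtype=torch.long))
    i_idx = torch.unique(torch.tensor(idx[1], dtype=torch.long))
    user_cl_loss = InfoNCE_tau(user_view1[u_idx], user_view2[u_idx], 0.2)  # 0.2: placeholder temperature
    item_cl_loss = InfoNCE_tau(item_view1[i_idx], item_view2[i_idx], 0.2)
    return user_cl_loss + item_cl_loss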
def train(self):
    model = self.model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)
    hot_uidx, hot_iidx = self.select_ui_idx(500, mode='hot')
    cold_uidx, cold_iidx = self.select_ui_idx(500, mode='cold')
    norm_uidx, norm_iidx = self.select_ui_idx(500, mode='norm')
    iters = 10
    # Beta-mixture prior: two components, initialized so most pairs (0.95)
    # fall in the "clean" component.
    alphas_init = torch.tensor([1, 2], dtype=torch.float64).to(device)
    betas_init = torch.tensor([2, 1], dtype=torch.float64).to(device)
    weights_init = torch.tensor([1 - 0.05, 0.05], dtype=torch.float64).to(device)
    for epoch in range(self.maxEpoch):
        # epoch_rec_loss = []
        # Refit the beta mixture each epoch on user scores against 100 random items.
        bmm_model = BetaMixture1D(iters, alphas_init, betas_init, weights_init)
        rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)
        self.bmm_fit(rec_user_emb, rec_item_emb, torch.arange(self.data.user_num),
                     np.random.randint(0, self.data.item_num, 100), bmm_model)
        for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):
            user_idx, pos_idx, rec_neg_idx = batch
            rec_user_emb, rec_item_emb, cl_user_emb, cl_item_emb = model(True)
            user_emb, pos_item_emb = rec_user_emb[user_idx], rec_item_emb[pos_idx]
            # rec_loss = self.batch_softmax_loss_neg(user_idx, rec_user_emb, pos_idx, rec_item_emb)
            # rec_neg_idx = torch.tensor(rec_neg_idx, dtype=torch.int64)
            # rec_neg_item_emb = rec_item_emb[rec_neg_idx]
            # Down-weight likely-noisy positives using the fitted mixture.
            weight = self.getWeightSim(user_emb, pos_item_emb, bmm_model)
            rec_loss = weighted_SSM(user_emb, pos_item_emb, self.temp2, weight)
            cl_loss = self.cl_rate * self.cal_cl_loss([user_idx, pos_idx], rec_user_emb, cl_user_emb, rec_item_emb, cl_item_emb)
            batch_loss = rec_loss + l2_reg_loss(self.reg, user_emb, pos_item_emb) + cl_loss
            # epoch_rec_loss.append(rec_loss.item()), epoch_cl_loss.append(cl_loss.item())
            # Backward and optimize
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            if n % 100 == 0 and n > 0:
                print('training:', epoch + 1, 'batch', n, 'rec_loss:', rec_loss.item(), 'cl_loss', cl_loss.item())
        with torch.no_grad():
            self.user_emb, self.item_emb = self.model()
            hot_emb = torch.cat([self.user_emb[hot_uidx], self.item_emb[hot_iidx]], 0)
            cold_emb = torch.cat([self.user_emb[cold_uidx], self.item_emb[cold_iidx]], 0)
            self.eval_uniform(epoch, hot_emb, cold_emb)
            # Track similarity / embedding magnitude for hot, normal and cold
            # users and items, and report the hot/cold magnitude ratios.
            hot_user_mag = self.cal_sim(epoch, hot_uidx, self.user_emb, self.item_emb, mode='hot')
            self.cal_sim(epoch, norm_uidx, self.user_emb, self.item_emb, mode='norm')
            cold_user_mag = self.cal_sim(epoch, cold_uidx, self.user_emb, self.item_emb, mode='cold')
            hot_item_mag = self.item_magnitude(epoch, hot_iidx, self.item_emb, mode='hot')
            self.item_magnitude(epoch, norm_iidx, self.item_emb, mode='norm')
            cold_item_mag = self.item_magnitude(epoch, cold_iidx, self.item_emb, mode='cold')
            print('training:', epoch + 1, 'U_mag_ratio:', hot_user_mag / cold_user_mag, 'I_mag_ratio:', hot_item_mag / cold_item_mag)
            # self.getTopSimNeg(hot_uidx, self.user_emb, self.item_emb, 100)
            # self.getTopSimNeg(norm_uidx, self.user_emb, self.item_emb, 100)
            # self.getTopSimNeg(cold_uidx, self.user_emb, self.item_emb, 100)
        # epoch_rec_loss = np.array(epoch_rec_loss).mean()
        # self.loss.extend([epoch_rec_loss, epoch_cl_loss, hot_pair_uniform_loss.item(), random_item_uniform_loss.item()])
        # if epoch % 5 == 0:
        #     self.save_emb(epoch, hot_emb, mode='hot')
        #     self.save_emb(epoch, random_emb, mode='random')
        self.fast_evaluation(epoch)
    # self.save_loss()
    self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb
    # self.save_emb(self.bestPerformance[0], hot_emb, mode='best_hot')
    # self.save_emb(self.bestPerformance[0], random_emb, mode='best_random')