How the contrastive-learning loss InfoNCE is implemented in different recommendation models

1. The SGL version (Self-supervised Graph Learning for Recommendation)
SGL augments the data by perturbing the graph structure, so every node gets two augmented views. Views generated from the same node are treated as positive pairs $\left\{ (z_{u}^{\prime}, z_{u}^{\prime\prime}) \mid u \in U \right\}$, while views generated from different nodes are treated as negative pairs $\left\{ (z_{u}^{\prime}, z_{v}^{\prime\prime}) \mid u, v \in U, u \ne v \right\}$. Following SimCLR, the authors adopt the InfoNCE loss.
The auxiliary supervision of positive pairs encourages the consistency between different views of the same node for prediction, while the supervision of negative pairs enforces the divergence among different nodes.
$$
L_{ssl}^{user} = \sum_{u \in U} -\log \frac{\exp\left( s(z_{u}^{\prime}, z_{u}^{\prime\prime}) / \tau \right)}{\sum_{v \in U} \exp\left( s(z_{u}^{\prime}, z_{v}^{\prime\prime}) / \tau \right)}
$$
Here $s(\cdot)$ measures the similarity between two vectors (the authors use cosine similarity), and $\tau$ is the temperature of the softmax. The item-side contrastive loss $L_{ssl}^{item}$ is obtained in the same way, and the overall self-supervised objective is $L_{ssl} = L_{ssl}^{user} + L_{ssl}^{item}$.
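As a warm-up, here is a minimal sketch that implements the user-side formula literally; it is not taken from the SGL repository. It assumes z1 and z2 are the user embeddings produced by the two augmented views, with shape [num_users, dim], and tau is the temperature.

import torch
import torch.nn.functional as F

def sgl_user_infonce(z1, z2, tau=0.2):
    # Cosine similarity: L2-normalize, then take dot products
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    pos = (z1 * z2).sum(dim=1)      # s(z_u', z_u''), shape [num_users]
    sim = z1 @ z2.t()               # s(z_u', z_v'') for every pair, [num_users, num_users]
    # -log( exp(pos/tau) / sum_v exp(sim/tau) ), summed over all users
    return -torch.log(torch.exp(pos / tau) / torch.exp(sim / tau).sum(dim=1)).sum()

The SGL-style code below takes a different route: the forward pass only returns similarity "logits", and the loss is then assembled with a logsumexp, which is shown to be equivalent after the snippet.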

# LightGCN forward pass (the graph-convolution part)
    def forward(self, sub_graph1, sub_graph2, users, items, neg_items):
        user_embeddings, item_embeddings = self._forward_gcn(self.norm_adj)
        user_embeddings1, item_embeddings1 = self._forward_gcn(sub_graph1)
        user_embeddings2, item_embeddings2 = self._forward_gcn(sub_graph2)

        # Normalize embeddings learnt from sub-graph to construct SSL loss
        user_embeddings1 = F.normalize(user_embeddings1, dim=1)
        item_embeddings1 = F.normalize(item_embeddings1, dim=1)
        user_embeddings2 = F.normalize(user_embeddings2, dim=1)
        item_embeddings2 = F.normalize(item_embeddings2, dim=1)
        # First apply L2 (2-norm) normalization to the view representations


        user_embs = F.embedding(users, user_embeddings)
        item_embs = F.embedding(items, item_embeddings)
        neg_item_embs = F.embedding(neg_items, item_embeddings)
        user_embs1 = F.embedding(users, user_embeddings1)
        item_embs1 = F.embedding(items, item_embeddings1)
        user_embs2 = F.embedding(users, user_embeddings2)
        item_embs2 = F.embedding(items, item_embeddings2)
        # Look up the embeddings of the batch users, positive items and negative items


        # Next, compute the losses. inner_product is a helper that takes the
        # row-wise dot product, i.e. (a * b).sum(dim=1).
        sup_pos_ratings = inner_product(user_embs, item_embs)       # [batch_size]
        sup_neg_ratings = inner_product(user_embs, neg_item_embs)   # [batch_size]
        sup_logits = sup_pos_ratings - sup_neg_ratings              # [batch_size]
        # These pairwise logits feed the BPR loss, which is the main supervised objective

        # Contrastive (SSL) logits, computed separately for the user side and the item side
        tot_ratings_user = torch.matmul(user_embs1,torch.transpose(user_embeddings2, 0, 1)) # [batch_size,num_users]
        pos_ratings_user = inner_product(user_embs1, user_embs2)    # [batch_size]


        pos_ratings_item = inner_product(item_embs1, item_embs2)    # [batch_size]
        tot_ratings_item = torch.matmul(item_embs1, torch.transpose(item_embeddings2, 0, 1))  # [batch_size, num_items]

        ssl_logits_user = tot_ratings_user - pos_ratings_user[:, None]                  # [batch_size, num_users]
        ssl_logits_item = tot_ratings_item - pos_ratings_item[:, None]                  # [batch_size, num_items]
        return sup_logits, ssl_logits_user, ssl_logits_item


    sup_logits, ssl_logits_user, ssl_logits_item = self.lightgcn(
        sub_graph1, sub_graph2, bat_users, bat_pos_items, bat_neg_items
    )
    # InfoNCE Loss
    clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
    clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
    infonce_loss = torch.sum(clogits_user + clogits_item)
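At first glance this does not look like the formula for $L_{ssl}^{user}$, but it is the same quantity. Since ssl_logits_user[u, v] $= s(z_u^{\prime}, z_v^{\prime\prime}) - s(z_u^{\prime}, z_u^{\prime\prime})$, we have

$$
\mathrm{logsumexp}_v\!\left(\frac{s(z_u^{\prime}, z_v^{\prime\prime}) - s(z_u^{\prime}, z_u^{\prime\prime})}{\tau}\right)
= -\log \frac{\exp\left( s(z_u^{\prime}, z_u^{\prime\prime}) / \tau \right)}{\sum_{v} \exp\left( s(z_u^{\prime}, z_v^{\prime\prime}) / \tau \right)},
$$

which is exactly the per-user InfoNCE term; summing over the batch and adding the item side gives $L_{ssl}$. The logsumexp form is also numerically safer, since it never exponentiates large scores directly. A quick check of the equivalence (the tensors here are random stand-ins, not from the model):

import torch

torch.manual_seed(0)
tau = 0.2
pos = torch.randn(32)        # plays the role of s(z_u', z_u'') for a batch of 32 users
tot = torch.randn(32, 1000)  # plays the role of s(z_u', z_v'') against 1000 users

loss_a = torch.logsumexp((tot - pos[:, None]) / tau, dim=1).sum()                     # SGL-style
loss_b = (-torch.log(torch.exp(pos / tau) / torch.exp(tot / tau).sum(dim=1))).sum()   # literal InfoNCE
print(torch.allclose(loss_a, loss_b))  # True, up to floating-point error

In the training loop, this infonce_loss is then weighted by an SSL regularization coefficient (ssl_reg in this codebase) and added to the BPR loss built from sup_logits, plus L2 regularization, to form the final objective.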

2. The SimGCL version

    def InfoNCE(self, view1, view2, temperature):
        # Cosine similarity: L2-normalize both views, then use dot products
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
        pos_score = (view1 * view2).sum(dim=-1)                      # positive pairs: matching rows of the two views
        pos_score = torch.exp(pos_score / temperature)
        ttl_score = torch.matmul(view1, view2.transpose(0, 1))       # all pairwise scores within the batch
        ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)    # denominator: sum over in-batch candidates
        cl_loss = -torch.log(pos_score / ttl_score)                  # per-sample InfoNCE term
        return torch.mean(cl_loss)
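Two differences from the SGL code stand out: the loss is averaged rather than summed, and the denominator only contrasts against the other rows passed into the function (in-batch negatives) instead of against every user or item. Below is an illustrative usage sketch; the variable names (user_view1, batch_users, self.cl_rate) and the temperature value are placeholders, not taken verbatim from the SimGCL repository. In SimGCL the two views are typically produced by two forward passes of the encoder with random noise added to the embeddings.

# Illustrative usage: the two views come from two noise-perturbed encoder passes,
# indexed down to the users/items appearing in the current batch.
cl_loss_user = self.InfoNCE(user_view1[batch_users], user_view2[batch_users], 0.2)
cl_loss_item = self.InfoNCE(item_view1[batch_items], item_view2[batch_items], 0.2)
cl_loss = self.cl_rate * (cl_loss_user + cl_loss_item)   # weighted, then added to the main BPR loss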
