Pytorch堆叠多个损失造成内存爆炸

这几天跑代码的时候,跑着跑着就显示被killed掉(整个人都不好了)。查系统日志发现是内存不够(out of memory),没……办……法……了,直接放弃!当然这是不可能的,笔者怎么可能是个轻言放弃的人呢,哈哈。

言归正传,笔者用的设备是3T的硬盘,跑的程序batch_size=1024,共划分了2000个batch,每跑一个batch内存占用率就会升高,0.5%左右,无奈之下只能一句一句debug,最后后发现是损失累加造成的,如下放所示,代码共计算了三个损失:BPR Loss , Reg Loss , InfoNCE Loss,不能直接累加作为total_loss!而是通过.item()将损失值取出,再累加。

# BPR Loss
bpr_loss = -torch.sum(F.logsigmoid(sup_logits)) 

# Reg Loss
reg_loss = l2_loss(
self.lightgcn.user_embeddings(bat_users),
self.lightgcn.item_embeddings(bat_pos_items),
self.lightgcn.item_embeddings(bat_neg_items),
)
                
# InfoNCE Loss
clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
infonce_loss = torch.sum(clogits_user + clogits_item)
    
loss = bpr_loss + self.ssl_reg * infonce_loss + self.reg * reg_loss

total_loss = total_loss + loss.item() 
total_bpr_loss += bpr_loss.item()
total_reg_loss += self.reg * reg_loss.item()
               
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

你可能感兴趣的:(pytorch,人工智能,python)