论文链接
https://arxiv.org/pdf/2210.11680.pdf
这篇文章利用双路对比学习实现了在线聚类，对我的研究方向有一定帮助。以下是该双路对比学习所使用的两个对比损失函数（实例级与簇级）的实现。
实例对比损失
class InstanceLoss(nn.Module):
    """Instance-level contrastive loss with optional distributed support.

    Each row of the input is the embedding of one augmented view. The batch
    is laid out as ``multiplier`` equally sized groups stacked along dim 0,
    so the positives of row ``i`` are the rows
    ``(i + k * n // multiplier) % n`` for ``k = 1 .. multiplier - 1``; every
    other row (except ``i`` itself) acts as a negative.

    Args:
        tau: Temperature. Embeddings are divided by ``sqrt(tau)``, which
            divides every pairwise dot product by ``tau``.
        multiplier: Number of augmented views per sample. The loss is
            divided by ``multiplier - 1``, so it is only defined for
            ``multiplier >= 2``.
        distributed: When true, embeddings are gathered from every process
            before the loss is computed.
    """

    # Written onto the diagonal of the similarity matrix so that, after the
    # softmax, a row puts (numerically) no probability mass on itself.
    LARGE_NUMBER = 1e4

    def __init__(self, tau=0.5, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed

    def forward(self, z, get_map=False):
        """Return the scalar contrastive loss for the stacked embeddings.

        Args:
            z: 2-D tensor whose first dimension is divisible by
                ``multiplier``. Per the original author's note this is the
                concatenation of the augmented views. Presumably the rows
                are already L2-normalised by the caller -- this method does
                not normalise them (TODO confirm against the caller).
            get_map: Unused; kept so existing call sites keep working.

        Returns:
            0-dim tensor: mean negative log-probability of the positives.

        Raises:
            AssertionError: if ``z.shape[0]`` is not divisible by
                ``multiplier``.
        """
        n = z.shape[0]
        assert n % self.multiplier == 0

        z = z / np.sqrt(self.tau)

        # Distributed handling: rebuild the global batch in view-major order.
        if self.distributed:
            world_size = dist.get_world_size()
            gathered = [torch.zeros_like(z) for _ in range(world_size)]
            # all_gather fills the list as [proc0, proc1, ...]
            # TODO: try to rewrite it with pytorch official tools
            gathered = diffdist.functional.all_gather(gathered, z)
            # split it into [proc0_aug0, proc0_aug1, ..., proc1_aug0, ...]
            chunks = [
                chunk
                for part in gathered
                for chunk in part.chunk(self.multiplier)
            ]
            # reorder to [proc0_aug0, proc1_aug0, ..., proc0_aug1, ...],
            # i.e. [batch_aug0, batch_aug1, ...] as the label code expects
            z_sorted = [
                chunks[rank * self.multiplier + view]
                for view in range(self.multiplier)
                for rank in range(world_size)
            ]
            z = torch.cat(z_sorted, dim=0)
            n = z.shape[0]

        logits = z @ z.t()
        # Mask the diagonal so a sample is never its own candidate.
        diag = torch.arange(n, device=logits.device)
        logits[diag, diag] = -self.LARGE_NUMBER

        logprob = F.log_softmax(logits, dim=1)

        # For anchor i the positives are (i + k * n / m) % n with k = 1..m-1;
        # k = 0 would point at the anchor itself and is therefore skipped.
        m = self.multiplier
        offsets = (np.arange(1, m) * n) // m
        positives = ((np.arange(n)[:, None] + offsets[None, :]) % n).reshape(-1)
        anchors = np.repeat(np.arange(n), m - 1)

        loss = -logprob[anchors, positives].sum() / n / (m - 1)
        return loss
簇级对比损失
class ClusterLoss(nn.Module):
    """Cluster-level contrastive loss with optional distributed support.

    The input is split into two augmented views. Each view is transposed so
    that every *column* of the input (one cluster) becomes a row, and the
    contrastive loss is then computed between these rows: the same cluster
    seen through the other view is the positive, all other clusters are
    negatives. A regulariser ``log(K) + sum(p * log(p))`` is added per view,
    where ``p`` is the normalised column sum; it is zero when ``p`` is
    uniform and positive otherwise.

    Args:
        tau: Temperature. The L2-normalised cluster rows are divided by
            ``sqrt(tau)``, which divides every pairwise dot product by
            ``tau``.
        multiplier: Number of augmented views stacked along dim 0. The
            views are split into two halves, so this is expected to be 2.
        distributed: When true, inputs are gathered from every process
            before the loss is computed.
    """

    # Written onto the diagonal of the similarity matrix so that, after the
    # softmax, a row puts (numerically) no probability mass on itself.
    LARGE_NUMBER = 1e4

    def __init__(self, tau=1.0, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed

    def forward(self, c, get_map=False):
        """Return the cluster contrastive loss plus the entropy regulariser.

        Args:
            c: 2-D tensor whose first dimension is divisible by
                ``multiplier``; per the original author's note it is built
                the same way as the instance-loss input (augmented views
                concatenated). Presumably each row holds soft cluster
                assignment probabilities -- TODO confirm. The column sums
                must be strictly positive because their logarithm is taken.
            get_map: Unused; kept so existing call sites keep working.

        Returns:
            0-dim tensor: contrastive loss + entropy regulariser.

        Raises:
            AssertionError: if ``c.shape[0]`` is not divisible by
                ``multiplier``.
        """
        n = c.shape[0]
        assert n % self.multiplier == 0

        # Distributed handling. The single-process case is treated as a
        # world of size 1 so that the split into views, the entropy term and
        # the transposition below are computed on both paths (previously
        # they only existed in the distributed branch, which made the
        # default distributed=False configuration fail with a NameError).
        if self.distributed:
            world_size = dist.get_world_size()
            gathered = [torch.zeros_like(c) for _ in range(world_size)]
            # all_gather fills the list as [proc0, proc1, ...]
            gathered = diffdist.functional.all_gather(gathered, c)
        else:
            world_size = 1
            gathered = [c]

        # split it into [proc0_aug0, proc0_aug1, ..., proc1_aug0, ...]
        chunks = [
            chunk
            for part in gathered
            for chunk in part.chunk(self.multiplier)
        ]
        # reorder to [proc0_aug0, proc1_aug0, ..., proc0_aug1, ...],
        # i.e. [batch_aug0, batch_aug1, ...]
        c_sorted = [
            chunks[rank * self.multiplier + view]
            for view in range(self.multiplier)
            for rank in range(world_size)
        ]
        half = int(self.multiplier * world_size / 2)
        c_aug0 = torch.cat(c_sorted[:half], dim=0)
        c_aug1 = torch.cat(c_sorted[half:], dim=0)

        # Entropy regulariser on the normalised column sums of each view.
        p_i = c_aug0.sum(0).view(-1)
        p_i = p_i / p_i.sum()
        en_i = (p_i * torch.log(p_i)).sum() + float(np.log(p_i.size(0)))
        p_j = c_aug1.sum(0).view(-1)
        p_j = p_j / p_j.sum()
        en_j = (p_j * torch.log(p_j)).sum() + float(np.log(p_j.size(0)))
        en_loss = en_i + en_j

        # Rows are now clusters: first all clusters of view 0, then view 1.
        c = torch.cat((c_aug0.t(), c_aug1.t()), dim=0)
        n = c.shape[0]

        c = F.normalize(c, p=2, dim=1) / np.sqrt(self.tau)

        logits = c @ c.t()
        # Mask the diagonal so a cluster is never its own candidate.
        diag = torch.arange(n, device=logits.device)
        logits[diag, diag] = -self.LARGE_NUMBER

        logprob = F.log_softmax(logits, dim=1)

        # For anchor i the positives are (i + k * n / m) % n with k = 1..m-1;
        # k = 0 would point at the anchor itself and is therefore skipped.
        m = self.multiplier
        offsets = (np.arange(1, m) * n) // m
        positives = ((np.arange(n)[:, None] + offsets[None, :]) % n).reshape(-1)
        anchors = np.repeat(np.arange(n), m - 1)

        loss = -logprob[anchors, positives].sum() / n / (m - 1)
        return loss + en_loss