Twin Contrastive Learning for Online Clustering

Paper link

arxiv.org/pdf/2210.11680.pdf

This paper performs online clustering via twin contrastive learning, which is quite relevant to my research direction. Below are the two contrastive loss functions that make up the twin contrastive objective.
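
Both losses below instantiate the same InfoNCE (NT-Xent) objective. Schematically (my notation, following the paper), for an anchor $z_i$ whose positive is the other augmented view $z_{i'}$:

$$
\ell_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_{i'})/\tau\right)}{\sum_{j \neq i} \exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}
$$

The instance-level loss applies this to the rows of the feature matrix (one row per augmented sample), while the cluster-level loss applies it to the columns of the cluster-assignment matrix (one column per cluster) and adds an entropy regularizer to avoid degenerate clusterings.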

Instance-level contrastive loss

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

import diffdist.functional  # differentiable all_gather, used on the distributed path


class InstanceLoss(nn.Module):
    """
    Contrastive loss with distributed data parallel support
    """

    LARGE_NUMBER = 1e4

    def __init__(self, tau=0.5, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed
    # z is the concatenation of the two augmented views' projections along the
    # batch dimension (assumed to be L2-normalized by the caller)
    def forward(self, z, get_map=False):
        n = z.shape[0]
        assert n % self.multiplier == 0

        # dividing both operands of z @ z.t() by sqrt(tau) scales the logits by 1 / tau
        z = z / np.sqrt(self.tau)
        # distributed case: gather features from all processes
        if self.distributed:
            z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
            # all_gather fills the list as [<proc0>, <proc1>, ...]
            # TODO: try to rewrite it with pytorch official tools
            z_list = diffdist.functional.all_gather(z_list, z)
            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
            z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)]
            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<aug0>, <aug1>, ...] as expected below
            z_sorted = []
            for m in range(self.multiplier):
                for i in range(dist.get_world_size()):
                    z_sorted.append(z_list[i * self.multiplier + m])
            z = torch.cat(z_sorted, dim=0)
            n = z.shape[0]

        logits = z @ z.t()
        # mask out the diagonal (self-similarity of each row)
        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER

        logprob = F.log_softmax(logits, dim=1)

        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
        m = self.multiplier
        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n
        # remove labels pointing to the example itself, i.e. (i, i)
        labels = labels.reshape(n, m)[:, 1:].reshape(-1)

        loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1)

        return loss
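
To make the expected input concrete, here is a minimal usage sketch of my own (batch size, dimensions, and variable names are illustrative, not from the repo):

import torch
import torch.nn.functional as F

batch_size, feat_dim = 4, 128
# projections of two augmentations of the same batch, L2-normalized
z_aug0 = F.normalize(torch.randn(batch_size, feat_dim), dim=1)
z_aug1 = F.normalize(torch.randn(batch_size, feat_dim), dim=1)

criterion = InstanceLoss(tau=0.5, multiplier=2, distributed=False)
loss = criterion(torch.cat([z_aug0, z_aug1], dim=0))  # input shape (2 * batch_size, feat_dim)

# with n = 8 and m = 2, the label construction pairs row i with row (i + 4) % 8,
# i.e. every sample with its own other augmented view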

Cluster-level contrastive loss

class ClusterLoss(nn.Module):
    """
    Contrastive loss with distributed data parallel support
    """

    LARGE_NUMBER = 1e4

    def __init__(self, tau=1.0, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed
    # c is the concatenation of the two views' cluster-assignment probabilities
    # (softmax outputs of the cluster head) along the batch dimension
    def forward(self, c, get_map=False):
        n = c.shape[0]
        assert n % self.multiplier == 0

        # c = c / np.sqrt(self.tau)
        # distributed case: gather cluster assignments from all processes
        if self.distributed:
            c_list = [torch.zeros_like(c) for _ in range(dist.get_world_size())]
            # all_gather fills the list as [<proc0>, <proc1>, ...]
            c_list = diffdist.functional.all_gather(c_list, c)
            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc1_aug0>, <proc1_aug1>, ...]
            c_list = [chunk for x in c_list for chunk in x.chunk(self.multiplier)]
            # sort it to [<proc0_aug0>, <proc1_aug0>, ...] that simply means [<aug0>, <aug1>, ...] as expected below
            c_sorted = []
            for m in range(self.multiplier):
                for i in range(dist.get_world_size()):
                    c_sorted.append(c_list[i * self.multiplier + m])
            c_aug0 = torch.cat(
                c_sorted[: int(self.multiplier * dist.get_world_size() / 2)], dim=0
            )
            c_aug1 = torch.cat(
                c_sorted[int(self.multiplier * dist.get_world_size() / 2) :], dim=0
            )
        else:
            # single-process case: the first half of the batch is view 0,
            # the second half is view 1
            c_aug0, c_aug1 = c.chunk(self.multiplier)

        # entropy regularizer: pushes each view's marginal cluster distribution
        # towards uniform, preventing collapse into a single cluster
        p_i = c_aug0.sum(0).view(-1)
        p_i /= p_i.sum()
        en_i = np.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()
        p_j = c_aug1.sum(0).view(-1)
        p_j /= p_j.sum()
        en_j = np.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()
        en_loss = en_i + en_j

        # transpose so that each row is one cluster's assignment vector over the
        # batch; the contrast below is then between clusters, not instances
        c = torch.cat((c_aug0.t(), c_aug1.t()), dim=0)
        n = c.shape[0]

        c = F.normalize(c, p=2, dim=1) / np.sqrt(self.tau)

        logits = c @ c.t()
        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER

        logprob = F.log_softmax(logits, dim=1)

        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
        m = self.multiplier
        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n // m, n)) % n
        # remove labels pointing to the example itself, i.e. (i, i)
        labels = labels.reshape(n, m)[:, 1:].reshape(-1)

        loss = -logprob[np.repeat(np.arange(n), m - 1), labels].sum() / n / (m - 1)

        return loss + en_loss
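
Correspondingly, a minimal sketch for the cluster-level loss (again my own names; each row of c is a softmax distribution over clusters):

import torch

batch_size, num_clusters = 4, 10
# softmax outputs of the cluster head for the two views
c_aug0 = torch.randn(batch_size, num_clusters).softmax(dim=1)
c_aug1 = torch.randn(batch_size, num_clusters).softmax(dim=1)

criterion = ClusterLoss(tau=1.0, multiplier=2, distributed=False)
loss = criterion(torch.cat([c_aug0, c_aug1], dim=0))
# loss = cluster-level contrastive term + entropy regularizer (en_loss)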
