# Domain generalization loss based on center alignment (基于中心对齐的领域泛化损失)

class Center_loss(nn.Module):
    """Centroid-alignment loss between domains.

    For each domain present in ``domain_index`` ids ``0 .. n_domain - 1``, the
    per-class mean feature (centroid) of the batch is computed. The loss is
    the sum, over every pair of domains, of the mean-squared error between
    their ``(n_class, d)`` centroid matrices.

    A class that has no sample in a domain gets an all-zero centroid for that
    domain: its feature sum is zero and its sample count is clamped to 1 so
    the division stays defined.
    """

    def __init__(self, src_class, n_domain=3):
        """Configure the loss.

        Args:
            src_class: Number of classes. Labels are used as row indices into
                a buffer of this size, so they must lie in ``[0, src_class)``.
            n_domain: Number of domains compared pairwise; samples whose
                ``domain_index`` is outside ``0 .. n_domain - 1`` are ignored.
                Defaults to 3.
        """
        super().__init__()

        self.n_class = src_class
        self.n_domain = n_domain
        # MSELoss holds no parameters or buffers, so it needs no device
        # placement; it simply runs on whatever device its inputs are on.
        self.MSELoss = nn.MSELoss()  # mean of (x - y)^2

    def _class_centroids(self, feature, label):
        """Return the (n_class, d) matrix of per-class mean features."""
        # Buffers are created from `feature` so they share its device/dtype.
        sums = feature.new_zeros(self.n_class, feature.shape[1])
        sums = sums.index_add(0, label, feature)

        counts = feature.new_zeros(self.n_class)
        counts = counts.index_add(0, label, feature.new_ones(label.shape))
        # A sample count of 0 would divide by zero; empty classes keep a
        # zero centroid instead.
        counts = counts.clamp(min=1)

        return sums / counts.view(self.n_class, 1)

    def forward(self, feature, label, domain_index):
        """Compute the summed pairwise MSE between per-domain centroids.

        Args:
            feature: Tensor of shape ``(N, d)``, one feature row per sample.
            label: Integer tensor of shape ``(N,)`` holding class ids.
            domain_index: Tensor of shape ``(N,)`` holding each sample's
                domain id.

        Returns:
            A 0-dim tensor on ``feature``'s device.
        """
        # Index tensors must be int64 and live next to the features.
        label = label.to(device=feature.device, dtype=t.long)

        centroids = []
        for domain in range(self.n_domain):
            in_domain = domain_index == domain
            centroids.append(
                self._class_centroids(feature[in_domain], label[in_domain])
            )

        semantic_loss = feature.new_zeros(())
        for i in range(self.n_domain):
            for j in range(i + 1, self.n_domain):
                semantic_loss = semantic_loss + self.MSELoss(
                    centroids[i], centroids[j]
                )

        return semantic_loss


   

# Related tags from the source page (你可能感兴趣的): 域泛化, python, 开发语言, 后端