class Metric_loss(nn.Module):
    """Centroid-based inter-/intra-class dispersion terms for source features.

    Args:
        src_class: number of source classes; labels must lie in
            ``[0, src_class)`` because they index per-class buffers.
    """

    def __init__(self, src_class):
        super(Metric_loss, self).__init__()
        # forward() sizes its per-class buffers with this.
        self.n_class = src_class

    def forward(self, s_feature, s_labels):
        """Return ``(inter_loss, intra_loss)`` for a labelled feature batch.

        Args:
            s_feature: ``(n, d)`` feature tensor.
            s_labels: ``(n,)`` integer class labels.

        Returns:
            inter_loss: L1 distance between every class centroid and the
                global centroid, summed over classes and dims, divided by d.
            intra_loss: L1 distance between each feature and its own class
                centroid, summed over samples and dims.
        """
        n, d = s_feature.shape
        # Allocate on the input's device/dtype instead of hard-coding .cuda().
        opts = dict(device=s_feature.device, dtype=s_feature.dtype)

        # Number of samples per class. Clamp to 1 so that empty classes do
        # not divide by zero; their centroid stays at the origin.
        counts = torch.zeros(self.n_class, **opts).index_add(
            0, s_labels, torch.ones(n, **opts))
        counts = torch.clamp(counts, min=1)

        # Class centroids: per-class feature sum divided by class size.
        sums = torch.zeros(self.n_class, d, **opts).index_add(0, s_labels, s_feature)
        s_centroid = sums / counts.view(self.n_class, 1)

        # Centroid of each sample's own class, gathered in one indexing op.
        temp = s_centroid[s_labels]

        # NOTE(review): the original code read `s_all_centroid` without ever
        # assigning it. The global feature mean is assumed here -- confirm.
        s_all_centroid = s_feature.mean(dim=0, keepdim=True).repeat(self.n_class, 1)

        inter_loss = torch.norm(s_all_centroid - s_centroid, p=1, dim=0).sum() / d
        intra_loss = torch.norm(temp - s_feature, p=1, dim=0).sum()
        return inter_loss, intra_loss
# NOTE(review): usage snippet for BatchHardTripletSelector + TripletLoss.
# At module level this runs at import time, before BatchHardTripletSelector
# and TripletLoss (both defined further down) exist. Also, `feature` and
# `src_label` are not defined anywhere in this view.
# TODO: move this into the training step where those batch tensors exist.
selector = BatchHardTripletSelector()
# Returns (anchors, hardest positives, hardest negatives) for the batch.
anchor, pos, neg = selector(feature, src_label)
# Margin-based triplet loss (margin=1); .cuda() moves the module to the GPU.
triplet_loss = TripletLoss(margin=1).cuda()
triplet = triplet_loss(anchor, pos, neg)
class TripletLoss(nn.Module):
    """Compute normal triplet loss or soft-margin triplet loss given triplets.

    Args:
        margin: when None, use the soft-margin formulation
            ``SoftMarginLoss(d(a, n) - d(a, p), 1)``. Otherwise use
            ``nn.TripletMarginLoss`` with this margin and L2 distance.
    """

    def __init__(self, margin=None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:  # use soft-margin
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin=margin, p=2)

    def forward(self, anchor, pos, neg):
        """Return the scalar triplet loss.

        Inputs are batches of embeddings, one row per triplet (assumed
        ``(N, D)``; distances are taken over dim 1).
        """
        if self.margin is None:
            ap_dist = torch.norm(anchor - pos, 2, dim=1).view(-1)
            an_dist = torch.norm(anchor - neg, 2, dim=1).view(-1)
            # Target +1: the loss is minimised when an_dist exceeds ap_dist.
            # ones_like keeps the target on the same device and dtype.
            y = torch.ones_like(an_dist)
            return self.Loss(an_dist - ap_dist, y)
        return self.Loss(anchor, pos, neg)
class BatchHardTripletSelector(object):
    """Select batch-hard triplets from a batch of embeddings.

    For each anchor, take the farthest sample with the same label as the
    positive and the closest sample with a different label as the negative.
    """

    def __init__(self, *args, **kwargs):
        super(BatchHardTripletSelector, self).__init__()

    def __call__(self, embeds, labels):
        """Return ``(embeds, pos, neg)`` where row i of pos/neg pairs with anchor i.

        Anchors with no other same-label sample in the batch use themselves
        as the positive; previously they silently took sample 0.
        NOTE(review): if an anchor has no different-label sample, every
        negative distance is +inf and argmin falls back to index 0.
        """
        # Pairwise distance matrix, detached from the graph and moved to host.
        dist_mtx = pdist_torch(embeds, embeds).detach().cpu().numpy()
        # Labels as an (n, 1) host array, detached from autograd.
        labels = labels.contiguous().cpu().numpy().reshape((-1, 1))
        num = labels.shape[0]
        dia_inds = np.diag_indices(num)  # indices of the diagonal
        lb_eqs = labels == labels.T

        # Hardest positive: same label, excluding the anchor itself.
        is_pos = lb_eqs.copy()
        is_pos[dia_inds] = False
        dist_same = dist_mtx.copy()
        dist_same[~is_pos] = -np.inf  # never pick non-positives
        pos_idxs = np.argmax(dist_same, axis=1)
        no_pos = ~is_pos.any(axis=1)
        pos_idxs[no_pos] = np.arange(num)[no_pos]

        # Hardest negative: closest sample with a different label.
        dist_diff = dist_mtx.copy()
        dist_diff[lb_eqs] = np.inf  # never pick same-label samples
        neg_idxs = np.argmin(dist_diff, axis=1)

        pos = embeds[pos_idxs].contiguous().view(num, -1)
        neg = embeds[neg_idxs].contiguous().view(num, -1)
        return embeds, pos, neg