Using the triplet loss function for similarity computation on MNIST
The core of the triplet loss involves three parts: an anchor, a positive sample from the same class as the anchor, and a negative sample from a different class. The loss pushes the anchor-positive distance to be smaller than the anchor-negative distance by at least a margin alpha.
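Before the MXNet code, here is a minimal NumPy sketch of that computation (the function name and the margin value 0.2 are my choices, not from the post): for an anchor a, positive p and negative n, the per-sample loss is max(0, ||a - p||^2 - ||a - n||^2 + alpha).

import numpy as np

def triplet_loss(anchor, pos, neg, alpha=0.2):
    # anchor / pos / neg: (batch, feat_dim) embedding arrays
    d_ap = np.sum((anchor - pos) ** 2, axis=1)   # squared anchor-positive distance
    d_an = np.sum((anchor - neg) ** 2, axis=1)   # squared anchor-negative distance
    return np.maximum(0.0, d_ap - d_an + alpha).mean()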
Below is an example of trying the triplet loss on the MNIST dataset. To reduce the amount of computation, only a subset of the data is actually used, which is already enough to see the effect.
Initial class distribution (ten classes in total, one color per class)
The distribution after 10 epochs
After 90 epochs
You can see that samples of the same class gradually cluster together, while the distances between different classes keep growing.
import os
import cv2
import time
import numpy as np
import mxnet as mx
from mxnet import nd, gluon

# sample_size and smallset_num are assumed to be defined earlier in the post,
# e.g. sample_size = (28, 28) and smallset_num = 200 (per-class cap).
class TripletMNIST(gluon.data.Dataset):
    def __init__(self, fortrain, dataset_root="C:/dataset/mnist/", resize=sample_size):
        super(TripletMNIST, self).__init__()
        self.data_pairs = {}   # label -> list of image paths
        self.total = 0
        self.resize = resize
        if fortrain:
            ds_root = os.path.join(dataset_root, "train")
        else:
            ds_root = os.path.join(dataset_root, "test")
        for rdir, pdirs, names in os.walk(ds_root):
            for name in names:
                basename, ext = os.path.splitext(name)
                if ext != ".jpg":
                    continue
                fullpath = os.path.join(rdir, name)
                # the parent directory name is the class label
                label = int(os.path.basename(rdir))
                # cap the number of samples per class to keep the experiment small
                if smallset_num > 0 and (label in self.data_pairs) and len(self.data_pairs[label]) >= smallset_num:
                    continue
                self.data_pairs.setdefault(label, []).append(fullpath)
                self.total += 1
        self.class_num = len(self.data_pairs)
        return

    def __len__(self):
        return self.total

    def __getitem__(self, idx):
        # idx is ignored: every call samples a random (anchor, positive, negative)
        # triplet, where anchor and positive share a class and the negative does not
        rds = np.random.randint(0, 10000, size=5)
        rd_anchor_cls, rd_anchor_idx = rds[0], rds[1]
        rd_anchor_cls = rd_anchor_cls % self.class_num
        rd_anchor_idx = rd_anchor_idx % len(self.data_pairs[rd_anchor_cls])
        rd_pos_cls, rd_pos_idx = rd_anchor_cls, rds[2]
        rd_pos_idx = rd_pos_idx % len(self.data_pairs[rd_pos_cls])
        rd_neg_cls, rd_neg_idx = rds[3], rds[4]
        rd_neg_cls = rd_neg_cls % self.class_num
        if rd_neg_cls == rd_pos_cls:
            rd_neg_cls = (rd_neg_cls + 1) % self.class_num
        rd_neg_idx = rd_neg_idx % len(self.data_pairs[rd_neg_cls])
        img_anchor = cv2.imread(self.data_pairs[rd_anchor_cls][rd_anchor_idx], 1)
        img_pos = cv2.imread(self.data_pairs[rd_pos_cls][rd_pos_idx], 1)
        img_neg = cv2.imread(self.data_pairs[rd_neg_cls][rd_neg_idx], 1)
        img_anchor = cv2.resize(img_anchor, self.resize)
        img_pos = cv2.resize(img_pos, self.resize)
        img_neg = cv2.resize(img_neg, self.resize)
        # scale to [0, 1]
        img_anchor = np.float32(img_anchor) / 255
        img_pos = np.float32(img_pos) / 255
        img_neg = np.float32(img_neg) / 255
        # HWC -> CHW
        img_anchor = np.transpose(img_anchor, (2, 0, 1))
        img_pos = np.transpose(img_pos, (2, 0, 1))
        img_neg = np.transpose(img_neg, (2, 0, 1))
        return (img_anchor, img_pos, img_neg)
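The dataset plugs into a standard Gluon DataLoader. A minimal usage sketch follows; batch_size is an assumed value, and sample_size / smallset_num are assumed to be set earlier in the post, before the class definition.

batch_size = 64    # assumed value, not from the post

train_ds = TripletMNIST(fortrain=True)
train_iter = gluon.data.DataLoader(train_ds, batch_size=batch_size,
                                   shuffle=True, last_batch='discard')
# last_batch='discard' matters here: train_net below slices the network output
# by batch_size, so every batch must contain exactly batch_size triplets.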
# uses the globals ctx, alpha, feat_dim and show_feat (defined elsewhere in the post)
def train_net(net, train_iter, valid_iter, feat_iter, batch_size, trainer, num_epochs, lr_sch, save_prefix):
    iter_num = 0
    for epoch in range(num_epochs):
        t0 = time.time()
        train_loss = []
        for batch in train_iter:
            iter_num += 1
            trainer.set_learning_rate(lr_sch(iter_num))
            anchor, pos, neg = batch
            # stack anchor / positive / negative along dim 0 so one forward pass handles all three
            X = nd.concat(anchor, pos, neg, dim=0)
            out = X.as_in_context(ctx)
            with mx.autograd.record(True):
                out = net(out)
                # split the embeddings back into the three groups
                out_anchor = out[0:batch_size]
                out_pos = out[batch_size:batch_size * 2]
                out_neg = out[batch_size * 2:batch_size * 3]
                # squared anchor-positive and anchor-negative distances
                loss_anchor_pos = (out_anchor - out_pos) ** 2
                loss_anchor_neg = (out_anchor - out_neg) ** 2
                # triplet loss: max(0, d(a, p) - d(a, n) + alpha), averaged over the batch
                loss = loss_anchor_pos - loss_anchor_neg
                loss = nd.relu(loss.sum(axis=1) + alpha).mean()
            loss.backward()
            train_loss.append(loss.asnumpy()[0])
            trainer.step(1)
        nd.waitall()
        # plot the 2-D embeddings every 10 epochs (these plots are the figures above)
        if (epoch % 10) == 0 and feat_dim == 2:
            show_feat(epoch, net, feat_iter)
        print("epoch {} lr {:>.5} loss {:>.5} cost {:>.3}sec".format(
            epoch, trainer.learning_rate,
            np.asarray(train_loss).mean(), time.time() - t0))
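The network, trainer and learning-rate schedule are not shown in this snippet. Below is a minimal wiring sketch: the small convolutional embedding net, the margin alpha, the step-decay schedule and the epoch count are all my assumptions, not the post's actual settings, and show_feat (the 2-D scatter that produced the figures above) is assumed to be defined elsewhere in the post.

ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
feat_dim = 2       # 2-D embedding so it can be plotted directly
alpha = 0.2        # triplet margin, assumed value

# a small embedding network for illustration, not the one from the post
net = gluon.nn.HybridSequential()
net.add(gluon.nn.Conv2D(32, kernel_size=3, activation='relu'),
        gluon.nn.MaxPool2D(),
        gluon.nn.Conv2D(64, kernel_size=3, activation='relu'),
        gluon.nn.MaxPool2D(),
        gluon.nn.Flatten(),
        gluon.nn.Dense(feat_dim))
net.initialize(mx.init.Xavier(), ctx=ctx)

trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 1e-3})
lr_sch = lambda it: 1e-3 * (0.5 ** (it // 2000))   # simple step decay by iteration

train_net(net, train_iter, None, train_iter, batch_size,
          trainer, num_epochs=100, lr_sch=lr_sch, save_prefix='triplet_mnist')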