在mnist上尝试triplet loss (mxnet)

triplet loss

使用Triplet Loss损失函数在MNIST上学习相似度度量
triplet loss的核心包括三个部分

  1. anchor/positive/negative
    代表三个输入图像,尺寸相同。训练的目标是令anchor和positive的距离最小化,同时令anchor和negative的距离最大化。以人脸识别(FaceRec)为例,anchor和positive一般来自同一个人,而negative来自另一个人。
  2. shared models
    三个输入共用同一个卷积模型(权重共享),模型输入是单幅图像,输出是一维特征向量
  3. loss
    $$L_i = \lVert f(x_i^a) - f(x_i^p) \rVert_2^2 - \lVert f(x_i^a) - f(x_i^n) \rVert_2^2 + \alpha$$
    $$L = \sum_{i=1}^{N} \max(L_i, 0)$$
    其中 $\alpha$ 是margin(间隔)超参数

实验结果

这里给出一个在MNIST数据集上尝试triplet loss的例子。为了减少计算量,实际只采用了一部分数据,但已经可以看出效果。
初始时类别分布(一共十个类别,一种颜色代表一个类别)
在mnist上尝试triplet loss (mxnet)_第1张图片
训练10个epoch后的效果
在mnist上尝试triplet loss (mxnet)_第2张图片
训练90个epoch后的效果
在mnist上尝试triplet loss (mxnet)_第3张图片
可以看到同一类样本逐渐聚集,不同类之间的距离逐渐增大。

关键代码

  1. mnist组织成三元组的代码

class TripletMNIST(gluon.data.Dataset):
    """Dataset of random (anchor, positive, negative) MNIST image triplets.

    Images are collected from ``<dataset_root>/train`` (``fortrain`` true) or
    ``<dataset_root>/test``. Every ``.jpg`` file is labelled by the integer
    name of the directory that directly contains it (``train/3/x.jpg`` -> 3).
    When the module-level ``smallset_num`` is > 0, at most that many images
    are kept per class.

    ``__getitem__`` ignores ``idx`` and draws a fresh random triplet on every
    call, so ``__len__`` (the number of images kept) only sets how many
    triplets make up one epoch.

    Args:
        fortrain: read the ``train`` split if true, otherwise ``test``.
        dataset_root: directory containing the ``train``/``test`` folders.
        resize: size tuple passed to ``cv2.resize``.
    """

    def __init__(self, fortrain, dataset_root="C:/dataset/mnist/", resize=sample_size):
        super(TripletMNIST, self).__init__()
        self.data_pairs = {}  # label -> list of image paths
        self.total = 0
        self.resize = resize
        if fortrain:
            ds_root = os.path.join(dataset_root, 'train')
        else:
            ds_root = os.path.join(dataset_root, "test")
        for rdir, pdirs, names in os.walk(ds_root):
            for name in names:
                basename, ext = os.path.splitext(name)
                if ext != ".jpg":
                    continue
                fullpath = os.path.join(rdir, name)
                # The containing directory's name is the label. os.path.basename
                # honours the platform's separators; the former split('\\')
                # only worked on Windows.
                label = int(os.path.basename(rdir))
                if smallset_num > 0 and (label in self.data_pairs) and len(self.data_pairs[label]) >= smallset_num:
                    continue
                self.data_pairs.setdefault(label, []).append(fullpath)
                self.total += 1
        self.class_num = len(self.data_pairs.keys())
        # Sorted label list so a random index in [0, class_num) maps to a real
        # label even when labels are not exactly 0..class_num-1.
        self.labels = sorted(self.data_pairs)

    def __len__(self):
        return self.total

    def _load_image(self, path):
        """Read *path* as a colour image, resize, scale to [0, 1], return float32 CHW."""
        img = cv2.imread(path, 1)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError("failed to read image: {}".format(path))
        img = cv2.resize(img, self.resize)
        img = np.float32(img) / 255
        return np.transpose(img, (2, 0, 1))

    def __getitem__(self, idx):
        """Return a random ``(anchor, positive, negative)`` triplet; ``idx`` is ignored.

        Anchor and positive are drawn from the same class; the negative is
        drawn from a different class.

        Raises:
            ValueError: if fewer than two classes were found.
            IOError: if an image file cannot be read.
        """
        if self.class_num < 2:
            raise ValueError(
                "TripletMNIST needs at least 2 classes to draw negatives, "
                "found {}".format(self.class_num))
        rds = np.random.randint(0, 10000, size=5)
        anchor_k = rds[0] % self.class_num
        neg_k = rds[3] % self.class_num
        # Shift the negative class by one if it collides with the anchor's.
        if neg_k == anchor_k:
            neg_k = (neg_k + 1) % self.class_num
        anchor_paths = self.data_pairs[self.labels[anchor_k]]
        neg_paths = self.data_pairs[self.labels[neg_k]]

        img_anchor = self._load_image(anchor_paths[rds[1] % len(anchor_paths)])
        img_pos = self._load_image(anchor_paths[rds[2] % len(anchor_paths)])
        img_neg = self._load_image(neg_paths[rds[4] % len(neg_paths)])
        return (img_anchor, img_pos, img_neg)
  2. 训练代码
       
def train_net(net, train_iter, valid_iter, feat_iter, batch_size, trainer, num_epochs, lr_sch, save_prefix):
    """Train *net* with the triplet loss.

    Each batch from *train_iter* is an ``(anchor, pos, neg)`` tuple. The three
    parts are concatenated along axis 0 so the shared network runs a single
    forward pass; the output is then split back into three equal parts.
    Per sample, with squared L2 distances d(a, p) and d(a, n):

        loss = max(d(a, p) - d(a, n) + alpha, 0)

    and the batch loss is the mean over samples.

    Relies on module-level names defined elsewhere: ``ctx`` (device context),
    ``alpha`` (margin), ``feat_dim`` and ``show_feat`` (2-D feature plot,
    called every 10 epochs when ``feat_dim == 2``).

    Args:
        net: shared embedding network.
        train_iter: iterable of ``(anchor, pos, neg)`` batches.
        valid_iter: not used by this function (kept for interface compatibility).
        feat_iter: data passed to ``show_feat``.
        batch_size: nominal batch size. The split uses the actual
            ``anchor.shape[0]``, so a smaller final batch is handled correctly.
        trainer: gluon Trainer updating ``net``'s parameters.
        num_epochs: number of passes over ``train_iter``.
        lr_sch: callable mapping the global iteration count to a learning rate.
        save_prefix: not used by this function (kept for interface compatibility).
    """
    iter_num = 0
    for epoch in range(num_epochs):
        t0 = time.time()
        train_loss = []
        for batch in train_iter:
            iter_num += 1
            trainer.set_learning_rate(lr_sch(iter_num))
            anchor, pos, neg = batch
            # Use the real batch size: the last batch may be smaller than
            # batch_size, and slicing by the nominal size would then mix
            # anchor/positive/negative rows.
            n = anchor.shape[0]
            X = nd.concat(anchor, pos, neg, dim=0).as_in_context(ctx)
            with mx.autograd.record(True):
                out = net(X)
                out_anchor = out[0:n]
                out_pos = out[n:2 * n]
                out_neg = out[2 * n:3 * n]
                dist_pos = ((out_anchor - out_pos) ** 2).sum(axis=1)
                dist_neg = ((out_anchor - out_neg) ** 2).sum(axis=1)
                loss = nd.relu(dist_pos - dist_neg + alpha).mean()
            loss.backward()
            # .item() works whether mean() yields shape (1,) or a 0-d array.
            train_loss.append(loss.asnumpy().item())
            trainer.step(1)  # loss is already averaged over the batch
            nd.waitall()
        if (epoch % 10) == 0 and feat_dim == 2:
            show_feat(epoch, net, feat_iter)
        mean_loss = np.asarray(train_loss).mean() if train_loss else float('nan')
        print("epoch {} lr {:>.5} loss {:>.5} cost {:>.3}sec".format(epoch, trainer.learning_rate,
                                                             mean_loss, time.time() - t0))
    

你可能感兴趣的:(mxnet,codes)