caffe实现多标签输入(multilabel、multitask)

本文采用多label的lmdb+Slice Layer的方法

  • 生成多label的lmdb
# coding:utf8
import numpy as np
import os
import caffe
import lmdb
# from PIL import Image
import cv2
import numpy as np
import sys
# Make sure that caffe is on the python path:
caffe_root = 'D:/caffe/caffe-master'
sys.path.insert(0, caffe_root + '/python')

####################pre-treatment############################
#txt with labels eg. (0001.jpg 2 5)
file_input = open('C:/Users/ASUS/Desktop/Siamese_ResNet/Siamese_ResNet_code/label/train.txt', 'r')
img1_list = []
img2_list = []
label_list = []
for line in file_input:
    content = line.strip()         #移除字符串头尾指定的字符(默认为空格)
    content = content.split('\t')      #指定分隔符对字符串进行切片
    img1_list.append(str(content[0]))
    img2_list.append(str(content[1]))
    label_list.append(int(content[2]))
    del content
file_input.close()      #关闭文件
####################train data(images)############################
#your data lmdb path
#注意一定要先删除之前生成的lmdb,因为lmdb会在之前的数据基础上新增数据,而不会先清空
#os.system('rm -rf  ' + your data(images) lmdb path)
# create the lmdb file
# map_size指的是数据库的最大容量,根据需求设置
in_db=lmdb.open('C:/Users/ASUS/Desktop/Siamese_ResNet/Siamese_ResNet_code/siamese_resnet_image_train_lmdb',map_size=int(1e10))
# 创建操作数据库句柄
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(img1_list):     #遍历图像列表
       target_img = np.zeros((6, 72, 72))   #产生6维72*72的多维数组
       im1_file='C:/Users/ASUS/Desktop/Siamese_ResNet/Siamese_ResNet_code/seg_image/'+in_
       im2_file='C:/Users/ASUS/Desktop/Siamese_ResNet/Siamese_ResNet_code/seg_image/'+img2_list[in_idx]
       # im=Image.open(im_file)
       im1 = cv2.imread(im1_file)
       im2 = cv2.imread(im2_file)
       #im = im.resize((w,h),Image.BILINEAR)#放缩图片,分类一般用
       #双线性BILINEAR,分割一般用最近邻NEAREST,**注意准备测试数据时一定要一致**
       # im=np.array(im) # im: (w,h)RGB->(h,w,3)RGB
       # im=im[:,:,::-1]#把im的RGB调整为BGR
       #im1 = im1.transpose((2, 0, 1)) #把height*width*channel调整为channel*height*width
       #im2 = im2.transpose((2, 0, 1))
       target_img[0, 0, 0] = im1
       target_img[3, 0, 0] = im2
       im_dat=caffe.io.array_to_datum(target_img)  ##将图像数据整合为一个数据项
       # '{:0>10d}'.format(in_idx):
       #      lmdb的每一个数据都是由键值对构成的,
       #      因此生成一个用递增顺序排列的定长唯一的key
       in_txn.put('{:0>10d}'.format(in_idx),im_dat.SerializeToString())     #调用句柄,写入内存
       print 'data train: {} [{}/{}]'.format(in_, in_idx+1, len(img1_list))
       #print 'data train: {} [{}/{}]'.format(img2_list[in_idx], in_idx+1, len(img2_list))
       del im1_file, im2_file, im1, im2, im_dat
in_db.close()
print 'train data(images) are done!'

######train data of label################
#your labels lmdb path
in_db = lmdb.open('C:/Users/ASUS/Desktop/Siamese_ResNet/Siamese_ResNet_code/siamese_resnet_label_train_lmdb',map_size=int(1e10))
with in_db.begin(write=True) as in_txn:
    for in_idx, in_ in enumerate(img1_list):
        target_label = np.zeros((1, 1, 1))
        target_label[0, 0, 0] = label_list[in_idx]
        label_data = caffe.io.array_to_datum(target_label)
        in_txn.put('{:0>10d}'.format(in_idx), label_data.SerializeToString())
        print 'label train: {} [{}/{}]'.format(in_, in_idx+1, len(img1_list))
        del label_data, target_label
in_db.close()
print 'train labels are done!'




你可能感兴趣的:(caffe)