1. 训练一个检测器(BING),需要先将 AFLW 的数据转换成以下格式的正样本标注文件。
AFLW 的标注本来是 sqlite3 数据库格式,之前已经将其导出成了 .txt 文件(face_rect.txt)。
利用该.txt文件生成标注文件
import os


def annotation_path(path, out_dir='Annotations/'):
    """Return the .yml annotation path for the image at relative *path*.

    Only the final extension is stripped, so names that contain extra dots
    keep their full stem.
    """
    return out_dir + os.path.splitext(path)[0] + '.yml'


def clamp_box(xmin, ymin, width, height, img_width, img_height):
    """Convert an (x, y, w, h) face rectangle into a box clamped to the image.

    Coordinates may be given as strings or ints.  Returns
    ``((xmin, ymin, xmax, ymax), notes)`` where *notes* lists one diagnostic
    string per coordinate that had to be clamped.
    """
    xmin = int(xmin)
    ymin = int(ymin)
    # The far corner is computed from the unclamped origin.
    xmax = xmin + int(width)
    ymax = ymin + int(height)

    notes = []
    if xmin <= 0:
        notes.append('xmin %s' % xmin)
        xmin = 1
    if ymin <= 0:
        notes.append('ymin %s' % ymin)
        ymin = 1
    if xmax >= img_width:
        notes.append('xmax %s img_width %s img_height %s'
                     % (xmax, img_width, img_height))
        xmax = img_width - 1
    if ymax >= img_height:
        notes.append('ymax %s' % ymax)
        ymax = img_height - 1
    return (xmin, ymin, xmax, ymax), notes


def format_annotation_header(path, imageid, img_width, img_height):
    """Return the per-image YAML header, ending with the ``object:`` key."""
    lines = [
        '%YAML:1.0',
        '',
        'annotation:',
        '  folder: ALFW',
        '  filename: "%s"' % path,
        '  source: {id: %s}' % imageid,
        '  owner: {name: zhuqian}',
        "  size: {width: '%s', height: '%s', depth: '3'}"
        % (img_width, img_height),
        "  segmented: '0'",
        '  object:',
    ]
    return '\n'.join(lines) + '\n'


def format_object_entry(xmin, ymin, xmax, ymax):
    """Return the YAML sequence item describing one face bounding box.

    The attribute keys are indented deeper than the leading ``-`` so that
    they belong to the same sequence item as ``bndbox``.
    """
    lines = [
        "    - bndbox: {xmin: '%s', ymin: '%s', xmax: '%s', ymax: '%s'}"
        % (xmin, ymin, xmax, ymax),
        '      name: face',
        '      pose: Left',
        "      truncated: '1'",
        "      difficult: '0'",
    ]
    return '\n'.join(lines) + '\n'


def generate_annotations(list_path='face_rect.txt',
                         data_base='./JPEGImages/',
                         out_dir='Annotations/'):
    """Write one YAML annotation file per image listed in *list_path*.

    Each row after the header line of *list_path* is expected to hold
    ``imageid path xmin ymin width height``.  All faces that share an image
    end up in the same .yml file.  Rows with too few fields and images that
    OpenCV cannot read are skipped.
    """
    # Imported here so the pure helpers above stay importable without OpenCV.
    import cv2

    # Annotation files already started during this run.  The first face of an
    # image truncates any stale file from a previous run instead of appending
    # duplicate objects to it.
    started = set()

    with open(list_path) as face_file:
        next(face_file, None)  # skip the header row
        for line in face_file:
            fields = line.split()
            if len(fields) < 6:
                continue
            imageid, path = fields[0], fields[1]

            img = cv2.imread(data_base + path)
            if img is None:
                # imread signals failure by returning None, not by raising.
                print('%s could not be read, skipped' % path)
                continue
            img_height, img_width = img.shape[0], img.shape[1]

            box, notes = clamp_box(fields[2], fields[3], fields[4], fields[5],
                                   img_width, img_height)
            for note in notes:
                print('%s %s' % (path, note))

            yml_path = annotation_path(path, out_dir)
            yml_dir = os.path.dirname(yml_path)
            if yml_dir and not os.path.isdir(yml_dir):
                os.makedirs(yml_dir)

            first_face = yml_path not in started
            started.add(yml_path)
            with open(yml_path, 'w' if first_face else 'a') as out:
                if first_face:
                    out.write(format_annotation_header(
                        path, imageid, img_width, img_height))
                out.write(format_object_entry(*box))


if __name__ == '__main__':
    generate_annotations()
import os
import random


def read_image_names(lines):
    """Return the unique image names (path without extension) found in *lines*.

    Each line is expected to carry the image path in its second
    whitespace-separated field.  The list has one row per face, so an image
    with several faces occurs several times; it is kept only once, in order
    of first appearance, so that it cannot land in both the train and the
    test set.  Lines with fewer than two fields are ignored.
    """
    names = []
    seen = set()
    for line in lines:
        fields = line.split()
        if len(fields) < 2:
            continue
        name = os.path.splitext(fields[1])[0]
        if name not in seen:
            seen.add(name)
            names.append(name)
    return names


def partition_images(images):
    """Split *images* into ``(train, test)``: first quarter vs. the rest.

    The two slices are contiguous, so every image is assigned to exactly one
    of the two sets.
    """
    cut = len(images) // 4
    return images[:cut], images[cut:]


def write_lines(path, items):
    """Write each element of *items* to *path* on its own line."""
    with open(path, 'w') as out:
        out.writelines('%s\n' % item for item in items)


def write_image_sets(list_path='face_rect.txt', out_dir='ImageSets/Main'):
    """Create train.txt, test.txt and class.txt under *out_dir*.

    Image names are read from *list_path* (its first line is a header and is
    skipped), shuffled, and split by ``partition_images``.
    """
    with open(list_path, 'r') as face_file:
        next(face_file, None)  # skip the header row
        images = read_image_names(face_file)

    random.shuffle(images)
    train, test = partition_images(images)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    write_lines(os.path.join(out_dir, 'train.txt'), train)
    write_lines(os.path.join(out_dir, 'test.txt'), test)
    write_lines(os.path.join(out_dir, 'class.txt'), ['face'])


if __name__ == '__main__':
    write_image_sets()
3. 将不是jpg格式的图片转成jpg。
import os


def jpg_target(path):
    """Return the ``.jpg`` path *path* should be converted to, or None.

    None means the path already ends in ``.jpg`` and needs no conversion.
    Only the final extension is considered, so extra dots in the name are
    preserved.
    """
    stem, ext = os.path.splitext(path)
    if ext == '.jpg':
        return None
    return stem + '.jpg'


def convert_to_jpg(list_path='face_rect.txt', data_base='./JPEGImages/'):
    """Re-encode every non-jpg image listed in *list_path* as a .jpg file.

    The image path is read from the second field of every row after the
    header line and resolved relative to *data_base*.  The converted file is
    written next to the source image; the source file is left in place.
    """
    # Imported here so jpg_target stays importable without OpenCV.
    import cv2

    # The list has one row per face; convert each image only once.
    handled = set()

    with open(list_path) as face_file:
        next(face_file, None)  # skip the header row
        for line in face_file:
            fields = line.split()
            if len(fields) < 2:
                continue
            path = fields[1]
            target = jpg_target(path)
            if target is None or path in handled:
                continue
            handled.add(path)

            print(path)
            img = cv2.imread(data_base + path)
            if img is None:
                # imread signals failure by returning None, not by raising.
                print('%s could not be read, skipped' % path)
                continue
            cv2.imwrite(data_base + target, img)


if __name__ == '__main__':
    convert_to_jpg()