MTCNN-Caffe(二)生成训练集、验证集的list,混合

上篇博客讲述了训练caffe模型时会生成3个txt文件;把它们分成训练集和验证集之后,还需要合并成整体的train.txt/val.txt。下面这个py文件按给定比例把各类样本列表混合成一个list文件。比例通过sample_percents参数显式传入,格式为pos:neg:part:landmark,常用取值为1:3:1:0。

#!/usr/bin/env python
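"""
Mix the positive, negative, part and landmark sample lists prepared for
MTCNN training into a single list file.

Command-line arguments, in order:
    pos_file neg_file part_file landmark_file sample_percents output_file

sample_percents is a colon-separated pos:neg:part:landmark ratio such as
1:3:1:0. For every positive sample the script writes the requested number
of negative, part and landmark samples after it.
"""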
import os
import sys
import argparse
import glob
import time
import random

def view_bar(num, total):
    """Redraw a 100-column text progress bar on the current stdout line.

    Args:
        num: Number of items handled so far.
        total: Total number of items.

    The percentage shown is ``int(num / total * 100) + 1``, clamped to the
    range 0..100 so the bar never overflows its 100 columns. A non-positive
    ``total`` is drawn as 100% instead of raising ``ZeroDivisionError``.
    """
    if total > 0:
        # The +1 is kept from the original formula so existing in-range
        # output is unchanged.
        percent = int(float(num) / total * 100) + 1
    else:
        # Nothing to process: report the work as complete.
        percent = 100
    percent = max(0, min(100, percent))

    # Leading '\r' returns to column 0 so each call overwrites the last bar.
    bar = '\r[%s%s]%d%%' % ("#" * percent, " " * (100 - percent), percent)
    sys.stdout.write(bar)
    sys.stdout.flush()

def _read_shuffled(path):
    """Return the lines of the file at ``path`` in random order."""
    with open(path) as f:
        lines = f.readlines()
    random.shuffle(lines)
    return lines


def _write_cycled(out, pool, cursor, count):
    """Write ``count`` entries of ``pool`` to ``out`` and return the new cursor.

    Reading starts at index ``cursor``. Whenever the end of ``pool`` is
    reached the pool is reshuffled in place and reading restarts at 0, so a
    short list is reused in a fresh order. ``pool`` must be non-empty when
    ``count`` is positive.
    """
    for _ in range(count):
        out.write(pool[cursor].strip() + "\n")
        cursor += 1
        if cursor == len(pool):
            random.shuffle(pool)
            cursor = 0
    return cursor


def main(argv):
    """Mix four sample lists into one output list file.

    For every line of the positive list, one positive entry is written,
    followed by the requested number of negative, part and landmark
    entries. The three secondary lists are cycled and reshuffled when they
    run out.

    Args:
        argv: Full command line including the program name, normally
            ``sys.argv``. Expected arguments after the program name:
            ``pos_file neg_file part_file landmark_file sample_percents
            output_file``. When ``argv`` is empty or ``None`` the process
            command line is parsed instead.

    Raises:
        SystemExit: With a non-zero status when the arguments are invalid:
            wrong argument count, a ratio that is not four integers, or an
            empty list that a positive ratio would have to draw from.
    """
    parser = argparse.ArgumentParser()
    # Required arguments: input and output files.
    parser.add_argument(
        "pos_file",
        type=str,
        help="positive sample list"
    )
    parser.add_argument(
        "neg_file",
        type=str,
        help="negative sample list"
    )
    parser.add_argument(
        "part_file",
        type=str,
        help="partial sample list"
    )
    parser.add_argument(
        "landmark_file",
        type=str,
        help="landmark sample list"
    )
    # Required positional, so no default is declared: argparse would never
    # apply one here.
    parser.add_argument(
        "sample_percents",
        type=str,
        help="pos:neg:part:landmark sampling ratio, e.g. 1:3:1:2"
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="output list"
    )

    # Honour the argv handed in by the caller; fall back to the process
    # command line only when nothing usable was passed.
    args = parser.parse_args(argv[1:] if argv else None)

    try:
        sample_percents = [int(s) for s in args.sample_percents.split(':')]
    except ValueError:
        parser.error("sample percents must be integers separated by ':'")
    if len(sample_percents) != 4:
        # parser.error exits with a non-zero status so callers and shell
        # scripts can detect the failure.
        parser.error("sample percents must have 4 numbers")

    # NOTE: sample_percents[0] is parsed but not used; exactly one positive
    # entry is written per round regardless of its value.
    neg_count, part_count, landmark_count = sample_percents[1:]

    pos_list = _read_shuffled(args.pos_file)
    neg_list = _read_shuffled(args.neg_file)
    part_list = _read_shuffled(args.part_file)
    landmark_list = _read_shuffled(args.landmark_file)

    # Fail before touching the output file if a list that must be sampled
    # is empty. With no positives nothing is drawn, so an empty secondary
    # list is harmless in that case.
    if pos_list:
        for label, pool, count in (
            ("negative", neg_list, neg_count),
            ("partial", part_list, part_count),
            ("landmark", landmark_list, landmark_count),
        ):
            if count > 0 and not pool:
                parser.error(
                    "%s sample list is empty but its ratio is %d"
                    % (label, count)
                )

    neg_idx = 0
    part_idx = 0
    landmark_idx = 0
    total_num = len(pos_list)

    # The context manager closes the output even if writing fails midway.
    with open(args.output_file, 'w') as out:
        for pos_idx, pos in enumerate(pos_list):
            view_bar(pos_idx, total_num)

            out.write(pos.strip() + "\n")
            neg_idx = _write_cycled(out, neg_list, neg_idx, neg_count)
            part_idx = _write_cycled(out, part_list, part_idx, part_count)
            landmark_idx = _write_cycled(
                out, landmark_list, landmark_idx, landmark_count
            )

# Run the list mixer only when executed as a script, not when imported.
if __name__ == "__main__":
    main(sys.argv)

 

你可能感兴趣的:(MTCNN-Caffe(二)生成训练集、验证集的list,混合)