FaceNet-pytorch(fixing data imbalance-CASIA)

# ------------------------------------------------#
#   Run this script before training to generate cls_train.txt
# ------------------------------------------------#
import os
import random


def txt_annotation(datasets_path, select_num):
    """Write ``cls_train.txt`` with an equal-sized random sample per class folder.

    Every sub-folder of ``datasets_path`` is treated as one class. Entries are
    visited in sorted name order and the class id is the entry's index in that
    sorted listing, so the ids of skipped entries are simply left unused.
    For each folder that holds at least ``select_num`` entries, ``select_num``
    of them are drawn without replacement and written one per line as
    ``<cls_id>;<absolute path>``.

    :param datasets_path: dataset root directory, one sub-folder per class.
    :param select_num: number of files to sample from each class folder.
    :return: None. The output goes to ``cls_train.txt`` in the current working
        directory, overwriting any existing file.
    """
    # List the root before opening the output so that a bad dataset path
    # fails without creating or truncating cls_train.txt.
    types_name = sorted(os.listdir(datasets_path))
    root = os.path.abspath(datasets_path)

    # The context manager closes the file even if sampling or writing raises.
    with open('cls_train.txt', 'w') as list_file:
        for cls_id, type_name in enumerate(types_name):
            photos_path = os.path.join(datasets_path, type_name)

            # Plain files in the dataset root are skipped here; listing them
            # as if they were folders would raise NotADirectoryError.
            if not os.path.isdir(photos_path):
                print('pass {} : not a folder'.format(type_name))
                continue

            photos_name = os.listdir(photos_path)
            if select_num > len(photos_name):
                # Too few files to draw select_num without replacement.
                print(
                    'pass folder {} with {} numbers , required {} '.format(
                        type_name, len(photos_name), select_num))
                continue

            for photo_name in random.sample(photos_name, select_num):
                list_file.write('{};{}\n'.format(
                    cls_id, os.path.join(root, type_name, photo_name)))


if __name__ == "__main__":
    # Script entry point: sample a fixed number of images from every class
    # folder under the dataset root and write the annotation file.
    dataset_path, select_num = r'test_dataset', 32
    txt_annotation(datasets_path=dataset_path, select_num=select_num)

效果如下:
FaceNet-pytorch(fixing data imbalance-CASIA)_第1张图片
FaceNet-pytorch(fixing data imbalance-CASIA)_第2张图片

你可能感兴趣的:(FaceNet-pytorch(fixing data imbalance-CASIA))