Python 解决数据不平衡（data imbalance）问题（以 CASIA-WebFace 为例）

Method for Data Imbalance

本方法适用于类别分布不均的数据集（例如 CASIA-WebFace）：从每个类别文件夹中随机抽取固定数量的图片，从而避免长尾分布（long-tail distribution）对训练的影响。
python 解决data imbalance问题(以casia-webface为例)_第1张图片

import os
import random


def list_of_groups(init_list, children_list_len):
    """Split a sequence into consecutive chunks of a fixed length.

    Every chunk has exactly ``children_list_len`` items except possibly the
    last one, which holds the remainder.

    :param init_list: (sequence) the list to split; must support ``len`` and slicing
    :param children_list_len: (int) the number of items per chunk (a chunk size,
        not a number of chunks)
    :return: (list) a list of lists, one per chunk, in the original order
    :raises ValueError: if ``children_list_len`` is not a positive integer
    """
    if children_list_len <= 0:
        raise ValueError(
            'children_list_len must be positive, got {}'.format(children_list_len))
    return [
        list(init_list[start:start + children_list_len])
        for start in range(0, len(init_list), children_list_len)
    ]


def dataset_split(dataset_path, batch_size, select_num):
    """Randomly sample a fixed number of images from every class folder.

    Capping every class at ``select_num`` images balances datasets whose
    class sizes follow a long-tail distribution (e.g. CASIA-WebFace).
    Folders with fewer than ``select_num`` entries are skipped and reported
    on stdout. Non-directory entries in ``dataset_path`` are ignored.

    :param dataset_path: (str) directory whose sub-directories are the classes
    :param batch_size: (int) the training batch size; here it is only checked to
        be at least ``select_num`` so a batch can hold one class's sample
    :param select_num: (int) the number of images to draw from each folder
    :return: (list) a list of lists; one inner list per qualifying folder,
        holding ``select_num`` full file paths, in sorted folder-name order
    :raises ValueError: if ``select_num`` is not positive or exceeds ``batch_size``
    """
    if select_num <= 0 or batch_size < select_num:
        raise ValueError(
            'select_num must be positive and no larger than batch_size '
            '(got batch_size={}, select_num={})'.format(batch_size, select_num))

    train_path_list = []
    # Sort listings: os.listdir order is arbitrary, and a stable order makes
    # the sampling reproducible under random.seed().
    for folder_name in sorted(os.listdir(dataset_path)):
        img_folder_path = os.path.join(dataset_path, folder_name)
        # Skip stray files in the dataset root; os.listdir on them would raise.
        if not os.path.isdir(img_folder_path):
            continue
        img_folder_imgs = sorted(os.listdir(img_folder_path))
        # A folder with exactly select_num images can still supply a full sample.
        if len(img_folder_imgs) >= select_num:
            select_img = random.sample(img_folder_imgs, select_num)
            train_path_list.append(
                [os.path.join(img_folder_path, name) for name in select_img])
        else:
            print('Folder {} failed to fetch'.format(folder_name))

    return train_path_list


if __name__ == '__main__':
    import sys

    # The dataset root may be given as the first CLI argument; otherwise fall
    # back to the original hard-coded local path.
    casia_folder = (sys.argv[1] if len(sys.argv) > 1
                    else r'E:/FaceNet-pytorch/facenet-pytorch--main/datasets/')

    train_path = dataset_split(dataset_path=casia_folder,
                               batch_size=32,
                               select_num=8)

    # Guard against an empty result (no folder had enough images) instead of
    # crashing with IndexError.
    if train_path:
        print(train_path[0])
    else:
        print('No folder had enough images to sample from')

python 解决data imbalance问题(以casia-webface为例)_第2张图片

你可能感兴趣的:(python,人工智能,深度学习)