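# SHHB_train数据集txt格式标记头部文本转换为coco格式的txt文本
# (Convert the head annotations of the SHHB_train dataset from their txt format into COCO-style txt label files.)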

import glob
import os
import numpy as np
import shutil
from PIL import Image

# Starting id for bounding boxes.
# NOTE(review): no live code visible here reads this constant -- confirm before removing.
START_BOUNDING_BOX_ID = 0

def _image_id(text):
    """Return the serial number between 'G_' and the next '.' in *text*.

    Works for both an image file name ('IMG_12.jpg' -> '12') and an
    annotation line that starts with such a file name.

    Raises:
        ValueError: if *text* does not contain 'G_'.
    """
    _, marker, tail = text.partition('G_')
    if not marker:
        raise ValueError('no "G_" marker found in %r' % (text,))
    return tail.split('.')[0]


def _write_split(images, image_dir, list_path, ann_path, lines_by_id):
    """Copy *images* into *image_dir* and write the two text files of a split.

    *list_path* receives one source image path per line; *ann_path* receives
    every annotation line whose id matches one of the images, grouped in the
    order of *images*.
    """
    with open(list_path, 'w') as listing, open(ann_path, 'w') as annotations:
        for img in images:
            img = str(img)  # entries come from a numpy string array
            listing.write(img + '\n')
            shutil.copy(img, image_dir)
            # Use the base name so a 'G_' inside a directory name cannot
            # be mistaken for the file-name marker.
            key = _image_id(os.path.basename(img))
            annotations.writelines(lines_by_id.get(key, ()))


def val_train(SHHB_img, SHHB_txt, pro_train, pro_valid,
              train_img='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\images\\SHHB_train\\',
              val_img='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\images\\SHHB_valid\\',
              train_ratio=0.9,
              train_list_file='SHHB_train.txt',
              val_list_file='SHHB_valid.txt'):
    """Split the images and their annotation lines into train and validation sets.

    The '*.jpg' files found in SHHB_img are sorted, shuffled with a fixed
    seed (100) and divided according to train_ratio.  Each image is copied
    into the folder of its split, its path is appended to the split's list
    file, and every line of SHHB_txt carrying the same image id is written
    to pro_train or pro_valid.

    Args:
        SHHB_img: folder containing the source '*.jpg' images.
        SHHB_txt: annotation file; each line starts with an 'IMG_<id>.jpg' name.
        pro_train: output file receiving the annotation lines of the training images.
        pro_valid: output file receiving the annotation lines of the validation images.
        train_img: folder receiving the training images (deleted and recreated).
        val_img: folder receiving the validation images (deleted and recreated).
        train_ratio: fraction of the images kept for training.
        train_list_file: output file listing the training image paths.
        val_list_file: output file listing the validation image paths.

    Raises:
        ValueError: if an image name or a non-blank annotation line has no 'G_'.
    """
    for folder in (train_img, val_img):
        # Always start from an empty folder so images of a previous split do not linger.
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)

    # Sort before the seeded shuffle so the split does not depend on glob's ordering.
    image_paths = np.sort(glob.glob(SHHB_img + '/*.jpg'))
    np.random.seed(100)
    np.random.shuffle(image_paths)
    train_num = int(len(image_paths) * train_ratio)

    # Index the annotation file once instead of re-reading it for every image.
    lines_by_id = {}
    with open(SHHB_txt, 'r') as source:
        for line in source:
            if not line.strip():
                continue
            if not line.endswith('\n'):
                # A final line without a newline must not be glued to the next record.
                line += '\n'
            lines_by_id.setdefault(_image_id(line), []).append(line)

    _write_split(image_paths[:train_num], train_img, train_list_file, pro_train, lines_by_id)
    _write_split(image_paths[train_num:], val_img, val_list_file, pro_valid, lines_by_id)

def _parse_record(record):
    """Split one annotation line into its image id and its list of boxes.

    The line is read as '..G_<id>.jpg <count> v v v v v ...': a head count
    followed by five integers per head.  Returns (id, boxes) where each box
    is a 5-tuple of ints in file order.

    Raises:
        IndexError: if the line has no 'G_' marker or no head count.
        ValueError: if a value is not an integer or fewer than count * 5
            values follow the head count.
    """
    tail = record.split('G_')[1]
    txt_id, _, rest = tail.partition('.jpg')
    # split() without an argument tolerates trailing spaces and CR/LF endings.
    numbers = [int(token) for token in rest.split()]
    count, values = numbers[0], numbers[1:]
    if len(values) < count * 5:
        raise ValueError('IMG_%s: expected %d values, found %d'
                         % (txt_id, count * 5, len(values)))
    boxes = [tuple(values[k * 5:k * 5 + 5]) for k in range(count)]
    return txt_id, boxes


def _convert_split(list_path, label_dir, image_dir, min_size):
    """Write one label file into *label_dir* for every record of *list_path*.

    Boxes whose 4th or 5th value (used as width and height below) is
    <= min_size are dropped.  Every kept box becomes
    '0 <cx> <cy> <w> <h> ' where cx and cy are the box centre divided by the
    image width and height.
    """
    with open(list_path, 'r') as records:
        for record in records:
            if not record.strip():
                continue
            txt_id, boxes = _parse_record(record)
            print(txt_id)
            label_path = os.path.join(label_dir, 'IMG_' + txt_id + '.txt')
            image_path = os.path.join(image_dir, 'IMG_' + txt_id + '.jpg')
            size = None
            last = len(boxes) - 1
            with open(label_path, 'w') as out:
                for k, (_, left, top, box_w, box_h) in enumerate(boxes):
                    if box_w <= min_size or box_h <= min_size:
                        continue
                    if size is None:
                        # Read the image size once per record, and only when
                        # at least one box is actually written.
                        with Image.open(image_path) as img:
                            size = img.size
                    img_w, img_h = size
                    # NOTE(review): width and height are both normalised by the
                    # mean of the image sides rather than by img_w and img_h
                    # separately; kept unchanged -- confirm this is intended.
                    mean_side = (img_w + img_h) / 2
                    fields = (
                        0,  # the first column is always written as 0
                        ((left * 2 + box_w) / 2) / img_w,
                        ((top * 2 + box_h) / 2) / img_h,
                        box_w / mean_side,
                        box_h / mean_side,
                    )
                    out.write(''.join(str(value) + ' ' for value in fields))
                    # Same layout as before: a kept box is followed by a line
                    # break unless it is the last box declared in the record.
                    if k != last:
                        out.write('\n')


def convert(pro_train, pro_valid,
            train_txt='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\labels\\SHHB_train\\',
            val_txt='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\labels\\SHHB_valid\\',
            train_img='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\images\\SHHB_train\\',
            val_img='D:\\project\\PyTorch-YOLOv3-master\\data\\custom\\images\\SHHB_valid\\',
            min_size=5):
    """Turn the split annotation files into one normalised label file per image.

    Each line of pro_train / pro_valid produces 'IMG_<id>.txt' in train_txt /
    val_txt.  The size of the matching 'IMG_<id>.jpg' in train_img / val_img
    is used to normalise the coordinates.

    Args:
        pro_train: annotation lines of the training images.
        pro_valid: annotation lines of the validation images.
        train_txt: folder receiving the training label files (deleted and recreated).
        val_txt: folder receiving the validation label files (deleted and recreated).
        train_img: folder holding the training images.
        val_img: folder holding the validation images.
        min_size: boxes with a width or height <= this value are discarded.

    Raises:
        IndexError, ValueError: if an annotation line is malformed.
    """
    for folder in (train_txt, val_txt):
        # Start from empty folders so labels of a previous run do not linger.
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)

    _convert_split(pro_train, train_txt, train_img, min_size)
    _convert_split(pro_valid, val_txt, val_img, min_size)
if __name__ == '__main__':
    # Backslashes are doubled on purpose: a bare "\t" (as in "\train_data")
    # would otherwise be read as a tab character.
    SHHB_img = 'D:\\project\\crowd couting\\part_B_final\\train_data\\images'  # Shanghai image dataset
    SHHB_txt = 'D:\\project\\crowd couting\\part_B_final\\train_data\\Part_B_train.txt'  # txt annotation file
    # Files that receive the annotation lines of the training / validation split.
    pro_train = 'D:\\project\\PyTorch-YOLOv3-master\\pro_train.txt'
    pro_valid = 'D:\\project\\PyTorch-YOLOv3-master\\pro_valid.txt'
    # Alternative output location:
    # pro_train = 'D:\\project\\zzzzzzz\\ag_image\\pro_train.txt'
    # pro_valid = 'D:\\project\\zzzzzzz\\ag_image\\pro_valid.txt'

    val_train(SHHB_img, SHHB_txt, pro_train, pro_valid)
    convert(pro_train, pro_valid)

# 你可能感兴趣的:(SHHB_train数据集txt格式标记头部文本转换为coco格式的txt文本)