# 分类网络位置测试图片生成 — generate position-test images for a classification network

import os
import random

import cv2
import numpy as np

# Pixel value of each tile colour, keyed by the label written to the CSV.
# Both label columns use the same code: 1 = white tile, 2 = grey tile;
# col_b additionally uses 0 for "no second tile".
LABEL_TO_PIXEL = {1: 255, 2: 127}
TILE_SHAPE = (64, 64)
CSV_HEADERS = ['img_path', 'col_a', 'col_b', 'val/train']


def make_sample(rng=np.random):
    """Generate one synthetic test image and its labels.

    The image is a 64x64 first tile (grey or white), optionally joined with a
    second 64x64 tile (grey or white) either to its right or below it.

    Args:
        rng: Object providing ``randint(low, high)`` with NumPy semantics
            (``high`` exclusive), e.g. the ``np.random`` module or a
            ``np.random.RandomState``. Defaults to NumPy's global RNG.

    Returns:
        ``(image, col_a, col_b, split)`` where ``image`` is a 2-D ``uint8``
        array of shape (64, 64), (64, 128) or (128, 64); ``col_a`` is the
        first tile's label (1 or 2); ``col_b`` is the second tile's label
        (0 when there is no second tile); ``split`` is a random 0/1 flag
        written to the ``val/train`` column (which value means which split
        is not defined here — confirm with the consumer of label.csv).
    """
    # Random draws happen in a fixed order (first colour, second label,
    # orientation only if a second tile exists, split) so a seeded rng
    # reproduces the same dataset.
    col_a = 2 if rng.randint(0, 2) else 1
    image = np.full(TILE_SHAPE, LABEL_TO_PIXEL[col_a], dtype=np.uint8)

    col_b = rng.randint(0, 3)
    if col_b:
        second = np.full(TILE_SHAPE, LABEL_TO_PIXEL[col_b], dtype=np.uint8)
        # Place the second tile to the right (hstack) or below (vstack).
        join = np.hstack if rng.randint(0, 2) else np.vstack
        image = join((image, second))

    split = rng.randint(0, 2)
    return image, col_a, col_b, split


def main(out_dir="imgs", label_path="label.csv", count=16 * 32, rng=np.random):
    """Write ``count`` test images into ``out_dir`` and their labels to ``label_path``.

    Each CSV row is ``["/<out_dir>/<i>.jpg", col_a, col_b, split]`` under the
    header ``CSV_HEADERS``; see ``make_sample`` for the label meanings.

    Args:
        out_dir: Directory the JPEG images are written to (created if missing).
        label_path: Path of the CSV label file (overwritten).
        count: Number of images to generate.
        rng: Random source passed to ``make_sample``.

    Raises:
        OSError: If OpenCV fails to write an image.
    """
    import csv

    os.makedirs(out_dir, exist_ok=True)
    rows = []
    for i in range(count):
        image, col_a, col_b, split = make_sample(rng)
        img_name = f"{out_dir}/{i}.jpg"
        # cv2.imwrite reports failure via its return value rather than an
        # exception; without this check the CSV would list missing images.
        if not cv2.imwrite(img_name, image):
            raise OSError(f"cv2.imwrite failed for {img_name!r}")
        # Paths are stored with a leading "/" as in the original script;
        # presumably resolved against a dataset root by the loader — confirm.
        rows.append(["/" + img_name, col_a, col_b, split])

    with open(label_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(CSV_HEADERS)
        writer.writerows(rows)


if __name__ == '__main__':
    main()

# 你可能感兴趣的 (You may also be interested in): python基础 (Python basics), 分类 (classification), python