用TensorLayer随机裁剪图片并修改对应的xml

裁剪图像(默认居中裁剪;修改 is_random 参数可改为随机裁剪)

# -*- coding: utf-8 -*-
import os
import tensorlayer as tl
import xml.etree.cElementTree as ET
from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString
from PIL import Image
# Convert one annotation XML into two parallel lists:
# class ids and normalised [centre_x, centre_y, width, height] boxes.
def trans_img(img, img_xml, classes_dict):
    """Read an annotation XML and return its boxes in normalised centroid form.

    The XML is expected to contain ``size/width``, ``size/height`` and any
    number of ``object`` elements, each with a ``name`` and a ``bndbox``
    holding ``xmin``/``ymin``/``xmax``/``ymax``.

    Args:
        img: Path of the image. Not used here; kept so existing callers that
            pass it keep working.
        img_xml: Path of the annotation XML file to parse.
        classes_dict: Mapping from class name to class id. Objects whose name
            is not a key of this mapping are skipped.

    Returns:
        A tuple ``(class_list, ann_list)``. ``class_list[i]`` is the class id
        of the i-th kept object and ``ann_list[i]`` is
        ``[c_x, c_y, o_w, o_h]``: box centre and box size, each divided by the
        image width (x values) or image height (y values) read from the XML.

    Raises:
        ZeroDivisionError: if the XML declares a zero width or height and at
            least one object has to be normalised.
    """
    ann_list = []
    class_list = []
    root = ET.parse(img_xml).getroot()

    # Image width and height as declared in the XML. Parsed as float so that
    # (a) values written as "640.0" are accepted and (b) every division below
    # is a true division even on interpreters where int / int truncates.
    size = root.find('size')
    xml_width = float(size.find('width').text)
    xml_height = float(size.find('height').text)

    for obj in root.iter('object'):
        # Class information.
        xml_name = str(obj.find('name').text)
        if xml_name not in classes_dict:
            continue

        xml_box = obj.find('bndbox')
        _xmin = float(xml_box.find('xmin').text)
        _xmax = float(xml_box.find('xmax').text)
        _ymin = float(xml_box.find('ymin').text)
        _ymax = float(xml_box.find('ymax').text)

        c_x = ((_xmin + _xmax) / 2.0) / xml_width
        c_y = ((_ymin + _ymax) / 2.0) / xml_height
        o_w = (_xmax - _xmin) / xml_width
        o_h = (_ymax - _ymin) / xml_height

        class_list.append(classes_dict[xml_name])
        ann_list.append([c_x, c_y, o_w, o_h])
    return class_list, ann_list

# Build the full class list and the class-name -> id dictionary.
def get_classdict(file_path):
    """Load class names from a text file, one name per line.

    Args:
        file_path: Path of the text file listing the class names.

    Returns:
        A tuple ``(classes, classes_dict)``. ``classes`` holds the names in
        file order; ``classes_dict`` maps each name to its 1-based position in
        that list (if a name is repeated, the last position wins).
    """
    # The context manager closes the file even if reading fails part-way.
    with open(file_path) as class_file:
        # Strip only line terminators ("\r" covers files with Windows line
        # endings); any other whitespace inside a name is kept.
        classes = [line.strip("\r\n") for line in class_file]
    classes_dict = {name: index for index, name in enumerate(classes, 1)}
    return classes, classes_dict

# Append a child element with the given tag and text; returns the new element.
def _add_text_node(parent, tag, text):
    node = SubElement(parent, tag)
    node.text = text
    return node


# Crop one image and save it together with a matching annotation XML;
# w and h are the crop width and crop height.
def crop_img(rootpath, savepath, classes, classes_dict, img, img_xml, w, h):
    """Crop ``img`` to ``w`` x ``h`` and write the crop plus a new XML.

    Both outputs are written into ``savepath`` and share the prefix
    ``"<w>_<h>_crop_"`` followed by the original file name.

    Args:
        rootpath: Directory containing the source image and its XML.
        savepath: Directory the cropped image and new XML are written to.
        classes: Class names; ``classes[class_id - 1]`` is the name written
            for an object with id ``class_id``.
        classes_dict: Mapping from class name to class id, passed on to
            ``trans_img``.
        img: File name of the source image inside ``rootpath``.
        img_xml: File name of the source annotation XML inside ``rootpath``.
        w: Crop width passed to ``obj_box_crop`` as ``wrg``.
        h: Crop height passed to ``obj_box_crop`` as ``hrg``.
    """
    src_img = os.path.join(rootpath, img)
    src_xml = os.path.join(rootpath, img_xml)

    # Class ids and normalised centroid boxes of the objects in the image.
    cla, ann = trans_img(src_img, src_xml, classes_dict)
    image = tl.vis.read_image(src_img)

    # Crop the image. The centre is cropped by default; is_random can be
    # changed to crop at a different position.
    im_crop, clas, coords = tl.prepro.obj_box_crop(image, cla,
                                                   ann, wrg=w, hrg=h, is_rescale=True, is_center=True,
                                                   is_random=False)

    # One name prefix for both outputs so the image and its XML stay paired.
    prefix = str(w) + "_" + str(h) + "_crop_"
    crop_name = prefix + img
    tl.vis.save_image(im_crop, os.path.join(savepath, crop_name))
    imh, imw = im_crop.shape[0:2]

    # clas: classes kept after the crop; anns: their boxes converted to
    # [xmin, ymin, xmax, ymax] in pixel units of the cropped image.
    # NOTE(review): assumes obj_box_crop returns ratio coordinates when
    # is_rescale=True - confirm against the installed TensorLayer version.
    anns = []
    for coord in coords:
        x, y, x2, y2 = tl.prepro.obj_box_coord_centroid_to_upleft_butright(coord)
        x, y, x2, y2 = tl.prepro.obj_box_coord_scale_to_pixelunit([x, y, x2, y2], (imh, imw))
        anns.append([x, y, x2, y2])

    # ---- build the new XML ----
    node_root = Element('annotation')
    _add_text_node(node_root, 'folder', '1')
    # Must name the image file that was actually saved above; the previous
    # value "crop_" + img did not match the file on disk.
    _add_text_node(node_root, 'filename', crop_name)
    node_size = SubElement(node_root, 'size')
    # Taken from the saved crop itself so the XML cannot disagree with it.
    _add_text_node(node_size, 'width', str(imw))
    _add_text_node(node_size, 'height', str(imh))
    # NOTE(review): depth is hard-coded; assumes 3-channel images.
    _add_text_node(node_size, 'depth', '3')

    for class_id, box in zip(clas, anns):
        node_object = SubElement(node_root, 'object')
        _add_text_node(node_object, 'name', str(classes[class_id - 1]))
        _add_text_node(node_object, 'difficult', '0')
        node_bndbox = SubElement(node_object, 'bndbox')
        _add_text_node(node_bndbox, 'xmin', str(box[0]))
        _add_text_node(node_bndbox, 'ymin', str(box[1]))
        _add_text_node(node_bndbox, 'xmax', str(box[2]))
        _add_text_node(node_bndbox, 'ymax', str(box[3]))

    # pretty_print gives an indented, line-broken document.
    xml = tostring(node_root, pretty_print=True)
    img_newxml = os.path.join(savepath, prefix + img_xml)
    # The context manager closes the file even if the write fails.
    with open(img_newxml, 'wb') as file_object:
        file_object.write(xml)



if __name__ == "__main__":
    # Batch driver: crop every sufficiently large JPEG in rootpath to w x h
    # and write the crops plus their rewritten XML files into savepath.
    file_path = r"E:\data_set\hd.txt"
    rootpath = r"E:\data_set\HD_data\delete_small_hd_train"
    savepath = r"E:\data_set\HD_data\480_640"
    w = 480
    h = 640
    classes, classes_dict = get_classdict(file_path)

    # Create the output directory up front so the first save cannot fail on
    # a missing folder.
    if not os.path.isdir(savepath):
        os.makedirs(savepath)

    for file_name in os.listdir(rootpath):
        path = os.path.join(rootpath, file_name)
        if not os.path.isfile(path):
            continue
        if not file_name.endswith(("jpg", "JPG")):
            continue

        img = file_name
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b.xml");
        # split(".")[0] would have produced "a.xml".
        img_xml = os.path.splitext(file_name)[0] + ".xml"
        print(img)

        # An image without an annotation file is skipped instead of aborting
        # the whole batch.
        if not os.path.isfile(os.path.join(rootpath, img_xml)):
            print("skip " + img + ": annotation file " + img_xml + " not found")
            continue

        # Only the size is needed; the context manager releases the file
        # handle right away.
        with Image.open(path) as im:
            im_width, im_height = im.size

        # Only images strictly larger than the crop in both dimensions are
        # processed.
        if im_width > w and im_height > h:
            crop_img(rootpath, savepath, classes, classes_dict, img, img_xml, w, h)






你可能感兴趣的:(用TensorLayer随机裁剪图片并修改对应的xml)