# Image cropping script (original title: 随机裁剪图像, "random image cropping")
# -*- coding: utf-8 -*-
import os
import tensorlayer as tl
import xml.etree.cElementTree as ET
from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString
from PIL import Image
# Convert a VOC XML annotation into two parallel lists: class ids and
# normalized [center_x, center_y, width, height] boxes.
def trans_img(img, img_xml, classes_dict):
    """Parse a Pascal VOC style XML file into class ids and normalized boxes.

    Args:
        img: Path of the image the annotation belongs to. Not used by this
            function; kept so existing callers keep working.
        img_xml: Path of the XML annotation file to parse.
        classes_dict: Mapping from class name to class id.

    Returns:
        A tuple ``(class_list, ann_list)`` of parallel lists. ``class_list``
        holds the id of every kept object; ``ann_list`` holds its box as
        ``[center_x, center_y, width, height]``, with x values divided by the
        ``<size>/<width>`` and y values by the ``<size>/<height>`` of the XML.

    Objects whose ``<name>`` is not a key of ``classes_dict`` are skipped.
    """
    root = ET.parse(img_xml).getroot()

    # Image width/height as recorded in the annotation itself.
    # float() rather than int(): accepts values written as decimals
    # (e.g. "640.0") and guarantees true division below even on
    # interpreters where int / int truncates.
    size = root.find('size')
    xml_width = float(size.find('width').text)
    xml_height = float(size.find('height').text)

    class_list = []
    ann_list = []
    for obj in root.iter('object'):
        xml_name = str(obj.find('name').text)
        if xml_name not in classes_dict:
            # Unknown class: ignore this object entirely.
            continue

        xml_box = obj.find('bndbox')
        _xmin = float(xml_box.find('xmin').text)
        _xmax = float(xml_box.find('xmax').text)
        _ymin = float(xml_box.find('ymin').text)
        _ymax = float(xml_box.find('ymax').text)

        # Corner box -> centroid box, scaled by the image dimensions.
        c_x = ((_xmin + _xmax) / 2.0) / xml_width
        c_y = ((_ymin + _ymax) / 2.0) / xml_height
        o_w = (_xmax - _xmin) / xml_width
        o_h = (_ymax - _ymin) / xml_height

        class_list.append(classes_dict[xml_name])
        ann_list.append([c_x, c_y, o_w, o_h])
    return class_list, ann_list
# Build the full class list and the name -> id dictionary from a text file.
def get_classdict(file_path):
    """Load class names from a text file and number them starting at 1.

    Args:
        file_path: Path of a text file with one class name per line.

    Returns:
        A tuple ``(classes, classes_dict)``. ``classes`` is the list of lines
        in file order with the newline removed; ``classes_dict`` maps each
        name to its 1-based position in that list. If a name occurs more
        than once, the last position wins.
    """
    # `with` guarantees the file is closed; the bare open() in a for-loop
    # left that to the garbage collector.
    with open(file_path) as class_file:
        classes = [line.strip("\n") for line in class_file]

    classes_dict = {}
    for index, name in enumerate(classes, 1):
        classes_dict[name] = index
    return classes, classes_dict
# Crop an image and save it together with a matching XML annotation;
# w and h are the width and height of the crop.
def crop_img(rootpath, savepath, classes, classes_dict, img, img_xml, w, h):
    """Crop one annotated image and write the crop plus a new VOC XML file.

    The crop is saved to ``savepath`` as ``"<w>_<h>_crop_<img>"`` and the
    annotation as ``"<w>_<h>_crop_<img_xml>"``.

    Args:
        rootpath: Directory containing the source image and its XML file.
        savepath: Directory the cropped image and new XML are written to.
        classes: List of class names; a class id ``k`` maps to
            ``classes[k - 1]``.
        classes_dict: Mapping from class name to class id, passed on to
            ``trans_img``.
        img: File name of the source image inside ``rootpath``.
        img_xml: File name of the source annotation inside ``rootpath``.
        w: Crop width passed to TensorLayer as ``wrg``; also written as the
            ``<width>`` of the new annotation.
        h: Crop height passed to TensorLayer as ``hrg``; also written as the
            ``<height>`` of the new annotation.
    """
    src_img = os.path.join(rootpath, img)
    src_xml = os.path.join(rootpath, img_xml)

    # Class ids and normalized centroid boxes of the objects in the image.
    cla, ann = trans_img(src_img, src_xml, classes_dict)
    image = tl.vis.read_image(src_img)

    # is_random can be changed; per the original author's note the settings
    # below crop the middle of the image.
    im_crop, clas, coords = tl.prepro.obj_box_crop(
        image, cla, ann, wrg=w, hrg=h,
        is_rescale=True, is_center=True, is_random=False)

    # One shared prefix so the image, the XML file and the <filename>
    # element inside the XML all agree on the output name.
    prefix = str(w) + "_" + str(h) + "_crop_"
    crop_name = prefix + img
    tl.vis.save_image(im_crop, os.path.join(savepath, crop_name))

    # Convert every returned box to [xmin, ymin, xmax, ymax] in pixel units
    # of the cropped image.
    imh, imw = im_crop.shape[0:2]
    anns = []
    for coord in coords:
        x, y, x2, y2 = tl.prepro.obj_box_coord_centroid_to_upleft_butright(coord)
        x, y, x2, y2 = tl.prepro.obj_box_coord_scale_to_pixelunit(
            [x, y, x2, y2], (imh, imw))
        anns.append([x, y, x2, y2])

    # ---- build the new annotation ----
    node_root = Element('annotation')
    SubElement(node_root, 'folder').text = '1'
    # Previously "crop_" + img, which did not match the file saved above.
    SubElement(node_root, 'filename').text = crop_name

    node_size = SubElement(node_root, 'size')
    SubElement(node_size, 'width').text = str(w)
    SubElement(node_size, 'height').text = str(h)
    # NOTE(review): depth is hard-coded; confirm all inputs are 3-channel.
    SubElement(node_size, 'depth').text = '3'

    for class_id, box in zip(clas, anns):
        node_object = SubElement(node_root, 'object')
        # Class ids are 1-based, the classes list is 0-based.
        SubElement(node_object, 'name').text = str(classes[class_id - 1])
        SubElement(node_object, 'difficult').text = '0'
        node_bndbox = SubElement(node_object, 'bndbox')
        SubElement(node_bndbox, 'xmin').text = str(box[0])
        SubElement(node_bndbox, 'ymin').text = str(box[1])
        SubElement(node_bndbox, 'xmax').text = str(box[2])
        SubElement(node_bndbox, 'ymax').text = str(box[3])

    # pretty_print inserts line breaks and indentation into the output.
    xml = tostring(node_root, pretty_print=True)
    img_newxml = os.path.join(savepath, prefix + img_xml)
    with open(img_newxml, 'wb') as file_object:
        file_object.write(xml)
if __name__ == "__main__":
    # Class-name file, input directory and output directory.
    file_path = r"E:\data_set\hd.txt"
    rootpath = r"E:\data_set\HD_data\delete_small_hd_train"
    savepath = r"E:\data_set\HD_data\480_640"
    # Crop width and height in pixels.
    w = 480
    h = 640

    classes, classes_dict = get_classdict(file_path)
    for file_name in os.listdir(rootpath):
        path = os.path.join(rootpath, file_name)
        if not os.path.isfile(path):
            continue
        if not file_name.endswith(("jpg", "JPG")):
            continue

        img = file_name
        # splitext keeps any dots inside the stem ("a.b.jpg" -> "a.b.xml");
        # split(".")[0] would have produced "a.xml".
        img_xml = os.path.splitext(file_name)[0] + ".xml"

        # Only the dimensions are needed, so close the file right away.
        with Image.open(path) as im:
            im_width, im_height = im.size
        print(img)

        # Only images strictly larger than the crop in both dimensions
        # are cropped.
        if im_width > w and im_height > h:
            crop_img(rootpath, savepath, classes, classes_dict, img, img_xml, w, h)