目录
1. 加载必要的包
2. 预处理函数
3. 加载队列
4. 可视化
5. 测试
文件名:get_dataset.py
本文件只是整个YOLO项目中的一个模块,需要导入同项目下的xml_parse.py文件(提供paras_annotation函数)
# -*- coding: utf-8 -*-
from xml_parse import paras_annotation
import os
import matplotlib.pyplot as plt
from matplotlib import patches
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
需要TensorFlow 2.0.0及以上版本
预处理函数:根据图片路径读取图片文件并解码,返回三通道的图片数据(以及原样传回的标注框)
def preprocess(img, img_boxes):
    """Load one image from disk and return it together with its boxes.

    :param img: path of the image file (a string / string tensor).
    :param img_boxes: bounding boxes for that image; passed through
        unchanged. Presumably shaped [40, 5] -- confirm against
        paras_annotation.
    :return: (image, img_boxes), where image is a 3-channel float32 tensor.
    """
    raw_bytes = tf.io.read_file(img)
    pixels = tf.image.decode_png(raw_bytes, channels=3)
    # Converting the decoded integer image to float32 rescales it to [0, 1].
    pixels = tf.image.convert_image_dtype(pixels, tf.float32)
    return pixels, img_boxes
def get_datasets(img_dir, ann_dir, label, batch_size=1):
    """Build the tf.data input pipeline for a directory of annotated images.

    :param img_dir: image directory, forwarded to paras_annotation.
    :param ann_dir: annotation directory, forwarded to paras_annotation.
    :param label: class names, forwarded to paras_annotation.
    :param batch_size: number of samples per batch (default 1).
    :return: a shuffled, preprocessed, batched dataset that repeats
        indefinitely.
    """
    image_paths, box_table = paras_annotation(img_dir, ann_dir, label)
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, box_table))
    # Shuffle the (path, boxes) pairs before decoding the images.
    dataset = dataset.shuffle(1000)
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch_size=batch_size)
    # repeat() without a count never ends, so callers must bound iteration.
    return dataset.repeat()
def db_visualize(db):
    """Show the first image of the next batch with its boxes drawn on top.

    :param db: dataset yielding (imgs, imgs_boxes) batches; presumably
        imgs is [b, 512, 512, 3] and imgs_boxes is [b, 40, 5] -- confirm
        against the dataset builder.
    :return: None; opens a matplotlib window.
    """
    batch_images, batch_boxes = next(iter(db))
    # Only the first sample of the batch is displayed.
    picture = batch_images[0]
    picture_boxes = batch_boxes[0]

    _, axis = plt.subplots(1)
    axis.imshow(picture)

    # Each row is (x1, y1, x2, y2, label).
    for left, top, right, bottom, label in picture_boxes:
        left = float(left)
        top = float(top)
        right = float(right)
        bottom = float(bottom)
        box_width = right - left
        box_height = bottom - top

        edge = None
        if label == 1:
            edge = (0, 1, 0)  # green for sugarbeet
        elif label == 2:
            edge = (1, 0, 0)  # red for weed, (R, G, B)
        if edge is None:
            # The first row with any other label ends the drawing loop.
            # NOTE(review): this assumes invalid rows only appear after
            # all valid boxes -- confirm.
            break

        outline = patches.Rectangle((left, top), box_width, box_height,
                                    linewidth=2, edgecolor=edge,
                                    facecolor='none')
        axis.add_patch(outline)

    plt.show()
# 测试代码
# if __name__ == "__main__":
# img_path = "data\\val\\image"
# annotation_path = "data\\val\\annotation"
# label = ("sugarbeet", "weed")
# train_db = get_datasets(img_dir=img_path, ann_dir=annotation_path,
# label=label)
# train_db = augmentation_generator(train_db)
# db_visualize(train_db)
未完待续……