目标检测YOLOV2-训练集数据预处理(二)

目录

1. 加载必要的包

2. 预处理函数

3. 加载队列

4. 可视化

5. 测试


1. 加载必要的包

文件名:get_dataset.py

本文件只是整个YOLO项目中的一个,要加载一个xml_parase.py文件

# -*- coding: utf-8 -*-
from xml_parse import paras_annotation
import os
import matplotlib.pyplot as plt
from matplotlib import patches
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

 TensorFlow版本2.0.0及以上

2. 预处理函数

主要是加载训练集路径后读取图片,解码,返回图片三通道数据信息

def preprocess(img, img_boxes):
    # img: string
    # img_boxes: [40,5]
    x = tf.io.read_file(img)
    x = tf.image.decode_png(x, channels=3)
    x = tf.image.convert_image_dtype(x, tf.float32) # 将数据转化为 =>[0~ 1]

    return x, img_boxes

3. 加载队列

def get_datasets(img_dir, ann_dir,label,batch_size=1):
    imgs, boxes = paras_annotation(img_dir, ann_dir, label)
    db = tf.data.Dataset.from_tensor_slices((imgs, boxes))
    db = db.shuffle(1000).map(preprocess).batch(batch_size=batch_size).repeat()
    return db

4. 可视化

def db_visualize(db):
    """
    可视化
    :param db:
    :return:
    """
    # imgs:[b, 512, 512, 3]
    # imgs_boxes: [b, 40, 5]
    imgs, imgs_boxes = next(iter(db))
    img, img_boxes = imgs[0], imgs_boxes[0]

    f,ax1 = plt.subplots(1)
    # display the image, [512,512,3]
    ax1.imshow(img)
    for x1,y1,x2,y2,l in img_boxes: # [40,5]
        x1,y1,x2,y2 = float(x1), float(y1), float(x2), float(y2)
        w = x2 - x1
        h = y2 - y1

        if l==1: # green for sugarweet
            color = (0,1,0)
        elif l==2: # red for weed
            color = (1,0,0) # (R,G,B)
        else: # ignore invalid boxes
            break

        rect = patches.Rectangle((x1,y1), w, h, linewidth=2,
        edgecolor=color, facecolor='none')
        ax1.add_patch(rect)
    plt.show()

5. 测试

# 测试代码
# if __name__ == "__main__":
#     img_path = "data\\val\\image"
#     annotation_path = "data\\val\\annotation"
#     label = ("sugarbeet", "weed")
#     train_db = get_datasets(img_dir=img_path, ann_dir=annotation_path,
#                  label=label)
#     train_db = augmentation_generator(train_db)
#     db_visualize(train_db)

未完待续。。。。。。

你可能感兴趣的:(深度学习,tensorflow)