目标检测——yoloV3案例

目录

  • 数据获取
  • TFrecord文件
    • 什么是TFrecord文件
    • 将数据转换成TFrecord文件
    • 读取TFrecord文件
    • 数据处理
  • 模型构建
  • 模型训练
    • 损失函数的计算
    • 正负样本的设定
    • 模型训练
      • 获取数据集
      • 加载模型
      • 模型训练
  • 模型预测

数据获取

目标检测——yoloV3案例_第1张图片
labellmage使用方法
目标检测——yoloV3案例_第2张图片
目标检测——yoloV3案例_第3张图片

TFrecord文件

目标检测——yoloV3案例_第4张图片

什么是TFrecord文件

目标检测——yoloV3案例_第5张图片
目标检测——yoloV3案例_第6张图片

将数据转换成TFrecord文件

目标检测——yoloV3案例_第7张图片

from dataset.vocdata_tfrecord import load_labels,write_to_tfrecord
#1
datapath='./VOCdevkit/VOC2007/'
#2
all_xml=load_labels(datapath,'train')
#3
tfrecord_path='./yolov3/dataset/voc_train.tfrecords'
#4
img_path=os.path.join(datapath,'JPEGImages')
#5
write_to_tfrecord(all_xml,tfrecord_path,img_path)

读取TFrecord文件

from dataset.get_tfdata import getdata
dataset=getdata('./dataset/voc_val.tfrecords')
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

#数据展示
#1获取类别信息
from utils.config_utils import read_class_names
classes=read_class_names('config/classname')
#2创建画布
plt.figure(15,10)
#3获取数据遍历
i=0
for image,width,height,boxes,boxes_category in datasets.take(3):
    #4划分不同的坐标轴
    plt.subplot(1,3,i+1)
    #5显示图像:plt.imshow()
    plt.imshow(image)
    #6显示box,遍历所有的bbox,rectange进行绘制
    ax=plt.gca()
    for j in range(boxes.shape[0]):
        rect=Rectangle((boxes[j,0],boxes[j,1]),boxes[j,2] -boxes[j,0],boxes[j,3]-boxes[j,1],color='r',fill=False)
        ax.add_patch(rect)
        #7显示类别
        label_id=boxes_category[j]
        label=classes.get(label_id.numpy())
        ax.text(boxes[j,0],boxes[j,1]+8,label,color='w',size=11,backgroundcolor='none')
    i+=1
plt.show()

目标检测——yoloV3案例_第8张图片

数据处理

目标检测——yoloV3案例_第9张图片

from dataset.preprocess import preprocess
#2创建画布
plt.figure(15,10)
#3获取数据遍历
i=0
for image,width,height,boxes,boxes_category in datasets.take(3):
    #进行数据处理
    image,boxes=preprocess(image,boxes)
    #4划分不同的坐标轴
    plt.subplot(1,3,i+1)
    #5显示图像:plt.imshow()
    plt.imshow(image)
    #6显示box,遍历所有的bbox,rectange进行绘制
    ax=plt.gca()
    for j in range(boxes.shape[0]):
        rect=Rectangle((boxes[j,0],boxes[j,1]),boxes[j,2] -boxes[j,0],boxes[j,3]-boxes[j,1],color='r',fill=False)
        ax.add_patch(rect)
        #7显示类别
        label_id=boxes_category[j]
        label=classes.get(label_id.numpy())
        ax.text(boxes[j,0],boxes[j,1]+8,label,color='w',size=11,backgroundcolor='none')
    i+=1
plt.show()

目标检测——yoloV3案例_第10张图片

模型构建

目标检测——yoloV3案例_第11张图片

模型训练

from model.yolov3 import YOLOv3 
yolov3=YOLOv3((416,416,3),80)
yolov3.summary()

损失函数的计算

目标检测——yoloV3案例_第12张图片

from core.loss import Loss
yolov3_loss=Loss((416,416,3),80)

正负样本的设定


目标检测——yoloV3案例_第13张图片

from core.bbox_target import bbox_to_target
#获取数据进行目标值设置
for image,width,height,boxes,labels in dataset.take[1]:
    #获取目标值
    label1,label2,label3=bbox_to_target(boxes,label,num_classes=20)
import tensorflow as tf
#获取正样本索引
tf.where(tf.equal(label[...,4],1))
#坐标值
label1[12,12,0,0:4]
label1[12,12,0,5:]

模型训练

目标检测——yoloV3案例_第14张图片

获取数据集

from dataset.preprocess import dataset
batch_size=1
trainset=dataset('dataset/voc_train.tfrecords',batch_size)

加载模型

from model.yoloV3 import YOLOv3
yolov3=YOLOv3((416,416,3),20)

from core.loss import Loss
yoloV3_loss=Loss((416,416,3),20)

模型训练

目标检测——yoloV3案例_第15张图片
目标检测——yoloV3案例_第16张图片

#1定义优化器
optimizer=tf.keras.optimizers.SGD(learning_rate=0.1,momentum=0.9)
#2设置epoch
for epoch in range(2):
    for (batch,inputs) in enumerate(trainset):
        images,labels=inputs
        #3计算损失函数进行参数更新
        #3.1定义上下文缓解
        with tf.GradientTape() as Tape:
            #3.2将图像送入网络中
            outputs=yoloV3(image)
            #3.3计算损失函数
            loss=yoloV3_loss([*outputs,*labels])
            #3.4计算梯度
            grads=Tape.gradient(loss,yolov3.trainable_variables)
            #3.5梯度更新
            optimizer.apply_gradients(zip(grads,yolov3.trainable_variables))
            print(loss)
yolov3.save('yolov3.h5')

模型预测

目标检测——yoloV3案例_第17张图片

#1
img=cv2.imread('image.jpg')
#2
predicter=Predict(class_num=80,yolov3='weights/yolov3.h5')
#3
boundings=predicter(img)
#4
plt.imshow(img[:,:,::-1])

目标检测——yoloV3案例_第18张图片
目标检测——yoloV3案例_第19张图片

你可能感兴趣的:(tensorflow解决cv,目标检测,YOLO,人工智能)