目标检测——可视化拟合真实框

可视化拟合真实框

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# Bounding-box regression loss
def bbox_loss(y_true, y_pred):
    """Return the mean squared error between two box tensors.

    Both tensors are handed to ``F.mse_loss`` unchanged, so the result is
    the element-wise squared difference averaged over all elements.

    Args:
        y_true: Tensor holding the reference box values.
        y_pred: Tensor holding the predicted box values.

    Returns:
        A scalar tensor with the mean squared error.
    """
    return F.mse_loss(input=y_true, target=y_pred, reduction="mean")
def draw_img(bbox_true, bbox_pred, save_path, title):
    """Draw two boxes on a white 200x200 canvas and save the figure.

    Args:
        bbox_true: Sequence of four numbers ``(x1, y1, x2, y2)`` for the
            first box; drawn with colour ``(0, 255, 0)``.
        bbox_pred: Sequence of four numbers ``(x1, y1, x2, y2)`` for the
            second box; drawn with colour ``(0, 0, 255)``.
        save_path: Destination path handed to ``plt.savefig``. Its parent
            directory is created if it does not exist yet.
        title: Text used as the figure title.
    """
    img = np.ones((200, 200, 3), dtype=np.uint8) * 255

    # cv2.rectangle expects points made of plain Python ints; NumPy scalar
    # types are not accepted by every OpenCV build, so convert explicitly.
    true_box = [int(v) for v in bbox_true]
    pred_box = [int(v) for v in bbox_pred]

    # The canvas is displayed by matplotlib (RGB order), so (0, 255, 0)
    # shows up green and (0, 0, 255) shows up blue.
    cv2.rectangle(img, tuple(true_box[:2]), tuple(true_box[2:]), (0, 255, 0), 1)
    cv2.rectangle(img, tuple(pred_box[:2]), tuple(pred_box[2:]), (0, 0, 255), 1)

    # Clear the current figure so frames from earlier calls do not pile up.
    plt.clf()
    plt.imshow(img)
    plt.title(str(title))

    # savefig fails with FileNotFoundError if the target directory is missing.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path)
# Fixed ground-truth box and initial predicted box, both (x1, y1, x2, y2)
bbox_true = np.array([50, 50, 100, 100])  # ground truth (top-left, bottom-right)
bbox_pred = np.array([20, 20, 10, 80], dtype=np.float32)  # initial guess, float32 so it can be optimised
images = []
# Turn the predicted box into a trainable parameter
bbox_pred = torch.tensor(bbox_pred, requires_grad=True)

# Optimiser that updates the four box coordinates directly
optimizer = torch.optim.SGD([bbox_pred], lr=0.02)

# The target never changes, so build its tensor once instead of every epoch
bbox_true_tensor = torch.tensor(bbox_true).float()

# Frames are written to img/; make sure the directory exists before saving
os.makedirs("img", exist_ok=True)

# Training loop
for epoch in range(900):
    # Loss between the fixed target and the current prediction
    loss = bbox_loss(bbox_true_tensor, bbox_pred)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()

    # Parameter update
    optimizer.step()

    # Report the loss (computed before this step) and the updated box
    print(f'Epoch {epoch+1}: Loss = {loss.item():.4f}, BBox_pred = {bbox_pred}')
    if epoch % 10 == 0:
        filename = fr"img/{epoch}.png"
        # Round to the nearest pixel using the ground-truth box's integer
        # dtype; casting to uint8 would wrap values outside 0..255.
        pred_pixels = np.rint(bbox_pred.detach().numpy()).astype(bbox_true.dtype)
        draw_img(bbox_true, pred_pixels, filename, f"epoch:{epoch},loss:{loss.item():.4f}")
        # Copy the frame into memory and close the file right away so the
        # script does not hold one open file handle per saved frame.
        with Image.open(filename) as im:
            images.append(im.copy())

# Assemble all saved frames into a looping GIF (100 ms per frame)
images[0].save('animation.gif', save_all=True, append_images=images[1:], duration=100, loop=0)

如下图所示,绿色框为真实框,蓝色框为预测框(图像由 matplotlib 按 RGB 顺序显示,因此 (0, 0, 255) 呈现为蓝色)。这里采用 MSE 损失,通过 SGD 不断更新预测框的四个坐标参数,使预测框逐步逼近真实框。

目标检测——可视化拟合真实框_第1张图片

你可能感兴趣的:(目标检测,目标检测,深度学习,python)