pytorch fasterrcnn-resnet50-fpn 神经网络 目标识别 应用 —— 推理识别代码讲解(开源)

pytorch fasterrcnn-resnet50-fpn 神经网络 目标识别 应用 —— 推理识别代码讲解(开源)

  • 项目地址
  • 二、推理识别代码讲解
      • 1、加载模型
        • 1)加载网络结构
        • 2)加载权重文件
        • 3)model状态配置
      • 2、图片推理
        • 推理——最最最关键的环节到了!
          • boxes:
          • labels:
          • scores:
        • boxes labels scores 是按照顺序对应的
      • 3、推理结果转换
        • 完整代码

项目地址

完整代码放在文末
https://gitee.com/laomaogu/fasterrcnn_resnet50_fpn_study/tree/master


pytorch fasterrcnn-resnet50-fpn 神经网络 目标识别 应用 —— 推理识别代码讲解(开源)_第1张图片


二、推理识别代码讲解

网上好多案例都是从训练开始讲
说实话,训练的那部分是真枯燥,一个头两个大
_
脑子一抽,跑去搞推理部分的代码,
跑成功的那一刻,成就感瞬间来了
同时这也让我觉得这玩意真tm有意思

所以说,学习是一件好东西,但如果能开心的学习,简直不要太棒!


先回顾上一张的推理流程:

输出
加载模型结构
神经网络
加载权重文件
预 测
待推理数据
推理结果

本章中,我们简化一下

依次序,分为三个流程:

加载模型
图片推理
推理结果转换

以上流程都在 predict() 方法中实现

1、加载模型

首先,来看一下怎么创建模型:

 # 获取当前设备是 CPU 还是 GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # 获取标签信息
    dict_classes, num_classes, category_index = loadtxt(args.classfile_textpath)

    # 创建模型
    print("Creating model")
    model = torchvision.models.detection.__dict__["fasterrcnn_resnet50_fpn"](pretrained=False,
                                                                             num_classes=num_classes + 1)
    # 加载权重文件
    model_state_dict = torch.load(args.weights, map_location=device)["model"]
    model.load_state_dict(model_state_dict)

    # 设备转换到 device 上运行
    model.to(device)
    print("using {} device.".format(device))
    model.eval()  # 进入验证模式

以上是加载模型的全部过程,接下来我们一步一步拆分


1)加载网络结构

所有的准备工作,pytorch 都给我们做好了,只需要轻轻地调用一个接口,就能创建模型结构
如果你像用yptorch提供的其他网络结构,也可以将[“fasterrcnn_resnet50_fpn”] 替换掉。

torchvision.models.detection.__dict__["fasterrcnn_resnet50_fpn"](pretrained=False,num_classes=num_classes + 1)

注意:如果第一次运行,请将 pretrained 参数设为 True, 为什么要这样做? 先不管他

num_classes 参数又是啥?
num_classes :是你要预测的目标的类别数量
假设老板要你做一个项目,检测识别办公室里的猫和狗,类别只有两个:猫’ 狗。 那这个参数就是2

num_classes为什么要 +1 呢?
因为这里还有一个隐藏的类别,background,即背景类,在训练时很重要, 现在只需要记得,num_classes要+1


2)加载权重文件

网络结构是有了,还需要把判断标准 加进去,通过 model.load_state_dict():

 model_state_dict = torch.load(args.weights, map_location=device)["model"]
 model.load_state_dict(model_state_dict)

其中,
args.weights 指的时权重文件的路径
device :设备类型(本次代码运行基础为CPU)

3)model状态配置

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # 获取设备类型
model.to(device) # 模型转换到设备上
 model.eval()  # 进入验证模式

OK! 到这里,模型创建就全部完成!
是不是非常简单
接下来把图片数据放进去就能预测了!距离成功就差一小步了

2、图片推理

主要内容按住不表,先聊聊图片

因为我这个项目是针对图片预测

啥是图片?

举个例子,现在有一个2*2大小的图片,

pytorch fasterrcnn-resnet50-fpn 神经网络 目标识别 应用 —— 推理识别代码讲解(开源)_第2张图片
在神经网络算法中,一切都是数字

四个方框代表四个像素点

每个像素由三个通道组成 , 即R G B(红,绿,蓝);

神经网络不认识**.Jpeg** 数据格式,

所以要将图片转换成能识别的 torch.Tensor 数据格式 :

# 加载图片,并转换为 torch.Tensor 类型
original_img = Image.open(args.img)
convert = torchvision.transforms.ToTensor()
img = convert(original_img)
# 添加一个维度,匹配算法入口
img = torch.unsqueeze(img, dim=0)

同样,pytorch的torchvision包提供了格式转换接口torchvision.transforms.ToTensor()
最后得到的 img ,就是算法能认识的数据,


推理——最最最关键的环节到了!

 with torch.no_grad():  
        print("predict start")
        predictions = model(img.to(device))

首先 torch.no_grad() 表示禁用梯度计算,涉及到训练部分的理论知识,本章不深究, 只需要知道,在预测之前需要调用

就两行关键代码!
是不是非常简单!

呐! predictions

这就是我们的推理结果,打印出来看看是啥:

print(f"result:{predictions}")

[{'boxes': tensor([[1.1527e+03, 4.7917e+02, 1.2506e+03, 5.5152e+02],
        [1.4364e+03, 1.4129e+02, 1.5374e+03, 2.1493e+02],
        [1.4452e+02, 8.6485e+02, 2.9079e+02, 9.6182e+02],
        [1.4117e+02, 1.1672e+03, 2.9326e+02, 1.2621e+03],
        [1.3704e+03, 1.8597e+03, 1.4486e+03, 1.9170e+03],
        [1.4331e+02, 1.3160e+03, 2.8870e+02, 1.4112e+03],
        [1.4920e+03, 1.5612e+03, 1.5750e+03, 1.6150e+03],
        [4.2725e+02, 8.7433e+02, 5.3222e+02, 9.4892e+02],
        [1.4313e+03, 2.5913e+02, 1.5079e+03, 3.1408e+02],
        [1.2542e+02, 5.9658e+02, 2.0788e+02, 6.4756e+02],
        [6.5477e+02, 1.1741e+03, 7.5365e+02, 1.2500e+03],
        [1.1523e+03, 2.2229e+02, 1.2506e+03, 2.9588e+02],
        [4.2621e+02, 1.1732e+03, 5.3022e+02, 1.2498e+03],
        [6.0383e+02, 1.8894e+03, 6.8434e+02, 1.9433e+03],
        [6.5404e+02, 8.7371e+02, 7.5516e+02, 9.4796e+02],
        [9.0214e+02, 8.8163e+02, 9.8600e+02, 9.3581e+02],
        [9.0179e+02, 7.4204e+02, 9.8496e+02, 7.9480e+02],
        [1.9014e+03, 6.7014e+02, 1.9801e+03, 7.2421e+02],
        [6.1216e+02, 1.4772e+03, 6.9884e+02, 1.5292e+03],
        [6.5601e+02, 7.1957e+02, 7.5319e+02, 7.9393e+02],
        [1.7512e+02, 7.2836e+02, 2.5711e+02, 7.8923e+02],
        [4.2725e+02, 1.3287e+03, 5.3333e+02, 1.4029e+03],
        [9.0085e+02, 1.1862e+03, 9.8763e+02, 1.2390e+03],
        [1.4346e+03, 4.5968e+02, 1.5336e+03, 5.3184e+02],
        [6.1542e+02, 1.5719e+03, 6.9611e+02, 1.6299e+03],
        [1.4287e+03, 3.6229e+02, 1.5090e+03, 4.1909e+02],
        [4.2589e+02, 7.1913e+02, 5.2999e+02, 7.9609e+02],
        [6.5392e+02, 1.3289e+03, 7.5261e+02, 1.4011e+03],
        [1.2835e+02, 4.9626e+02, 2.1096e+02, 5.5453e+02],
        [1.9012e+03, 8.5548e+02, 1.9801e+03, 9.1464e+02],
        [1.2152e+03, 7.2554e+02, 1.2945e+03, 9.1673e+02],
        [1.2492e+02, 2.2695e+02, 2.0045e+02, 2.8335e+02],
        [1.6239e+03, 3.6352e+02, 1.7052e+03, 4.1547e+02],
        [1.5204e+03, 8.2876e+02, 1.5964e+03, 1.0081e+03],
        [1.5196e+03, 1.1347e+03, 1.5992e+03, 1.3151e+03],
        [1.2325e+03, 1.4950e+03, 1.3150e+03, 1.6722e+03],
        [1.6248e+03, 2.6157e+02, 1.7052e+03, 3.1244e+02],
        [8.1357e+02, 1.5016e+03, 8.8935e+02, 1.5568e+03],
        [9.0093e+02, 1.3396e+03, 9.8453e+02, 1.3908e+03],
        [1.8980e+03, 7.6460e+02, 1.9767e+03, 8.2006e+02],
        [1.3715e+03, 1.7651e+03, 1.4437e+03, 1.8205e+03],
        [1.0047e+03, 1.5011e+03, 1.0883e+03, 1.5532e+03],
        [1.2173e+03, 1.1976e+03, 1.2944e+03, 1.3846e+03],
        [1.9086e+03, 4.4421e+02, 1.9875e+03, 6.2201e+02],
        [1.9088e+03, 2.1667e+02, 1.9845e+03, 3.9899e+02],
        [1.0413e+03, 1.7980e-01, 1.0904e+03, 4.5562e+01]]), 
        'labels': tensor([16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 16, 16]), 
        'scores': tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9999, 0.9998,
        0.9991])}] 

看到里面的boxes labels scores ,
啥意思? 谜底就在谜面上!

boxes:

以左上角为原点,检测对象的box框的位置参数,
举例子,[xmin,ymin, xmax, ymax],表示其中一个box框

pytorch fasterrcnn-resnet50-fpn 神经网络 目标识别 应用 —— 推理识别代码讲解(开源)_第3张图片

labels:

[16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]

这就是推理结果中的类别结果
数字对应着每个box框的类别推理结果
这咋全是数字啊!? 我的类别标签呢??
别着急,其实这就标签,只不过不是我们理解的格式, 对应着 classes.txt 内容,就是我们能理解的格式
接着往下看

scores:

‘scores’: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9999, 0.9998,
0.9991]

score 对应box框的预测识别概率,上面的数据看,每个box框的识别概率都接近100%


boxes labels scores 是按照顺序对应的

比如他们的第一个参数,

对应的是同一个box框的box位置,预测类别标签,预测识别率,

第二个参数,也是对应同一个box框


3、推理结果转换

predictions 这么一大串数字,难道就没有直接一点的展现方式吗 当然有!

大部分情况下我们都会写一点下工具,

将推理结果以其他方式展现,就像本章开头的图片一样

接下来,就是如何做出上述图片的效果:

# 保存为标记图像
resultconvert = result_convert(category_index, args.img, args.savedir, predictions)
# createxml 是否生成 xml文件 createmarkpicture:是否生成标记文件 showing:是否显示标记文件
resultconvert.convert_start(createxml=False, createmarkpicture=True,showimg=True)

完整代码

项目地址:https://gitee.com/laomaogu/fasterrcnn_resnet50_fpn_study/tree/master


Pytorch:1.8.1 + cpu
vision:0.9.1


author: 老毛鸪
email:  maogu123@126.com
gitee:  https://gitee.com/laomaogu
CSDN:   https://blog.csdn.net/qq_42239488?spm=1018.2226.3001.5343
'''


import os
import argparse
import torch
import torchvision
from PIL import Image
from util.pre_util import result_convert


def predict(args):
    # 获取当前设备是 CPU 还是 GPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # 获取标签信息
    dict_classes, num_classes, category_index = loadtxt(args.classfile_textpath)

    # 创建模型
    print("Creating model")
    model = torchvision.models.detection.__dict__["fasterrcnn_resnet50_fpn"](pretrained=False,
                                                                             num_classes=num_classes + 1)
    # 加载权重文件
    model_state_dict = torch.load(args.weights, map_location=device)["model"]
    model.load_state_dict(model_state_dict)

    # 设备转换到 device 上运行
    model.to(device)
    print("using {} device.".format(device))
    model.eval()  # 进入验证模式

    # 加载图片,并转换为 torch.Tensor 类型
    original_img = Image.open(args.img)
    convert = torchvision.transforms.ToTensor()
    img = convert(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    with torch.no_grad():
        print("predict start")
        predictions = model(img.to(device))[0]
        print("predict end")
        print(f"result:{predictions}")
        
        # 保存为标记图像
        resultconvert = result_convert(category_index, args.img, args.savedir, predictions)
        # createxml 是否生成 xml文件 createmarkpicture:是否生成标记文件 showing:是否显示标记文件
        resultconvert.convert_start(createxml=False, createmarkpicture=True,showimg=True)

def loadtxt(classpath):
    assert os.path.exists(classpath), ".classpath: {0} file does not exist...".format(classpath)
    class_num = 0
    class_dict = dict()
    with open(classpath, 'r') as f:
        classlist = f.readlines()
        class_num = len(classlist)
        for cls in classlist:
            class_dict[cls.strip('\n')] = classlist.index(cls) + 1
        category_index = {v: k for k, v in class_dict.items()}
    return class_dict, class_num, category_index


def get_args_parser(add_help=True):

    parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

    # 权重文件路径
    parser.add_argument("--weights", default="./weight_model/model_20220810_998996_981979.pth", type=str,
                        help="the weights enum name to load")
    # 标签文件地址
    parser.add_argument("--classfile-textpath", default="./DATA/classes.txt",
                        type=str, help="class text file path")

    return parser

if __name__ == "__main__":
    args = get_args_parser().parse_args()
    predict(args)

你可能感兴趣的:(记录,pytorch,神经网络,开源)