(202301)pytorch图像分类全流程实战Task4:新图片新视频预测

Task4:新图片新视频预测

(202301)pytorch图像分类全流程实战Task4:新图片新视频预测_第1张图片

对B站up同济子豪兄的图像分类系列的学习(大佬的完整代码在GitHub开源)  

安装配置环境

本次任务中用到的环境如下

pip install numpy pandas matplotlib requests tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

不过我是在本地进行测试,无cpu,因此pytorch的安装命令需要改变。(事实上早先已经配置了)

关于pytorch的安装命令:PyTorch

此外,我们还需要安装mmcv-full库 ,mmcv是openmmlab算法体系的重要组成部分。在本次任务中我们使用mmcv-full对视频进行读取与处理。

mmcv-full的安装

安装

MMCV 有两个版本:

  • mmcv-full: 完整版,包含所有的特性以及丰富的开箱即用的 CPU 和 CUDA 算子。
  • mmcv: 精简版,不包含 CPU 和 CUDA 算子但包含其余所有特性和功能,类似 MMCV 1.0 之前的版本。如果你不需要使用算子的话,精简版可以作为一个考虑选项。

注意: 请不要在同一个环境中安装两个版本,否则可能会遇到类似 ModuleNotFound 的错误。在安装一个版本之前,需要先卸载另一个。如果 CUDA 可用,强烈推荐安装 mmcv-full

在安装 mmcv-full 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。

pip install -U openmim
mim install mmcv-full

使用mim安装mmcv-full的好处是可以自动匹配自身环境,若不愿安装openmim可以前往openmmlab的GitHub仓库文档查找对应pytorch和cuda版本的pip命令。

 准备需要用到的文件

由于我的模型是在云平台上完成的,因此需要将云平台上的文件传输到本地。我使用的云平台(恒源云)支持了OSS、FileZilla、Xftp等传输工具,同时提供了绑定网盘的接口,支持百度云盘和阿里云盘的上传下载,实例中的数据可以免费保存24小时。

由于需要下载的文件都比较小,我直接采用了jupyternotebook自带的上传下载功能。

(202301)pytorch图像分类全流程实战Task4:新图片新视频预测_第2张图片

映射字典是npy格式,如果像我一样在云平台上训练模型的朋友记得保存。

模型权重是准确率达到0.96的权重,这是由上次的训练得到的。

测试图片是采用任务一中用到的数据收集代码在百度图片中爬取的。

预测新图像

首先导入需要的包

import torch
import torchvision
import torch.nn.functional as F

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

然后使用PIL库读取图片,并进行预处理。预处理的四步是缩放、裁剪、转 Tensor、归一化,此处不再赘述。此外还需要读取映射字典。

注意:在gpu中训练得到的权重文件不能直接用于cpu,需要通过加入参数map_location='cpu'进行映射,如右下图所示。

from PIL import Image, ImageFont, ImageDraw
# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 22)

#载入字典映射
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

 随后执行前向预测,得到所有类别的 logit 预测分数。logits是Odds取对数,本身没有上下限,对建模方便,但并不是一个概率,因此需要使用softmax层将其转变为“概率的模样”,进行归一化。

# 执行前向预测,得到所有类别的 logit 预测分数
pred_logits = model(input_img) 

# 对 logit 分数做 softmax 运算
pred_softmax = F.softmax(pred_logits, dim=1) 

然后像之前做过的一样,取置信度最大的十个结果绘制在图上。

(202301)pytorch图像分类全流程实战Task4:新图片新视频预测_第3张图片

尽管这个图片的分类效果还算不错,但由于本身的数据集并不含有许多零件混合的情况,因此在分类方面这个模型依然有些力不从心。或许面对这种任务更应当使用语义分割的方法。

预测视频文件

与单张图像预测的区别在于需要按帧读取、写入并输出结果,因此多使用了shutil、tempfile、mmcv三个库。

除此之外大部分代码与预测单张图片相同。

# 创建临时文件夹,存放每帧结果
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))


# 读入待预测视频
imgs = mmcv.VideoReader(input_video)

prog_bar = mmcv.ProgressBar(len(imgs))

# 对视频逐帧处理
for frame_id, img in enumerate(imgs):
    
    ## 处理单帧画面
    img, pred_softmax = pred_single_frame(img, n=5)

    # 将处理后的该帧画面图像文件,保存至 /tmp 目录下
    cv2.imwrite(f'{temp_out_dir}/{frame_id:06d}.jpg', img)
    
    prog_bar.update() # 更新进度条

# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'output/output_pred.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)

pred_single_frame预测单帧图像的函数与之前相同。

同时我们也可以将置信度的变化柱状图加入到生成的视频中,也就是将这样的图片连成视频,操作上并无不同,同济子豪兄大佬的开源代码中(在本文开头)已经有详细的解释,我就不再赘述。

 预测摄像头实时画面

与之前大同小异,加入读取摄像头每一帧的部分即可。opencv中带有这样的函数。

随后处理并得到写入预测结果的数组,通过cv2.imshow()即可展示。

# 调用摄像头逐帧实时处理模板
# 不需修改任何代码,只需修改process_frame函数即可
# 同济子豪兄 2021-7-8

# 导入opencv-python
import cv2
import time


# 处理帧函数
def process_frame(img):
    
    # 记录该帧开始处理的时间
    start_time = time.time()
    
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
    img_pil = Image.fromarray(img_rgb) # array 转 PIL
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测,得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
    
    top_n = torch.topk(pred_softmax, 5) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析预测类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析置信度
    
    # 使用PIL绘制中文
    draw = ImageDraw.Draw(img_pil) 
    # 在图像上写字
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
        # 文字坐标,中文字符串,字体,bgra颜色
        draw.text((50, 100 + 50 * i),  text, font=font, fill=(255, 0, 0, 1))
    img = np.array(img_pil) # PIL 转 array
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # RGB转BGR
    
    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1/(end_time - start_time)  
    # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
    img = cv2.putText(img, 'FPS  '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)
    return img


# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(0)

# 打开cap
cap.open(0)

# 无限循环,直到break被触发
while cap.isOpened():
    # 获取画面
    success, frame = cap.read()
    if not success:
        print('Error')
        break
    
    ## !!!处理帧函数
    frame = process_frame(frame)
    
    # 展示处理后的三通道图像
    cv2.imshow('my_window',frame)

    if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)
        break
    
# 关闭摄像头
cap.release()

# 关闭图像窗口
cv2.destroyAllWindows()

你可能感兴趣的:(参与dw开源学习,pytorch,分类,深度学习)