对B站up同济子豪兄的图像分类系列的学习(大佬的完整代码在GitHub开源)
pip install numpy pandas matplotlib requests tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
不过我是在本地进行测试,无cpu,因此pytorch的安装命令需要改变。(事实上早先已经配置了)
关于pytorch的安装命令:PyTorch
此外,我们还需要安装mmcv-full库 ,mmcv是openmmlab算法体系的重要组成部分。在本次任务中我们使用mmcv-full对视频进行读取与处理。
安装
MMCV 有两个版本:
- mmcv-full: 完整版,包含所有的特性以及丰富的开箱即用的 CPU 和 CUDA 算子。
- mmcv: 精简版,不包含 CPU 和 CUDA 算子但包含其余所有特性和功能,类似 MMCV 1.0 之前的版本。如果你不需要使用算子的话,精简版可以作为一个考虑选项。
注意: 请不要在同一个环境中安装两个版本,否则可能会遇到类似
ModuleNotFound
的错误。在安装一个版本之前,需要先卸载另一个。如果 CUDA 可用,强烈推荐安装 mmcv-full
。
在安装 mmcv-full 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch 官方安装文档。
pip install -U openmim
mim install mmcv-full
使用mim安装mmcv-full的好处是可以自动匹配自身环境,若不愿安装openmim可以前往openmmlab的GitHub仓库文档查找对应pytorch和cuda版本的pip命令。
由于我的模型是在云平台上完成的,因此需要将云平台上的文件传输到本地。我使用的云平台(恒源云)支持了OSS、FileZilla、Xftp等传输工具,同时提供了绑定网盘的接口,支持百度云盘和阿里云盘的上传下载,实例中的数据可以免费保存24小时。
由于需要下载的文件都比较小,我直接采用了jupyternotebook自带的上传下载功能。
映射字典是npy格式,如果像我一样在云平台上训练模型的朋友记得保存。
模型权重是准确率达到0.96的权重,这是由上次的训练得到的。
测试图片是采用任务一中用到的数据收集代码在百度图片中爬取的。
首先导入需要的包
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
然后使用PIL库读取图片,并进行预处理。预处理的四步是缩放、裁剪、转 Tensor、归一化,此处不再赘述。此外还需要读取映射字典。
注意:在gpu中训练得到的权重文件不能直接用于cpu,需要通过加入参数map_location='cpu'进行映射,如右下图所示。
from PIL import Image, ImageFont, ImageDraw
# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 22)
#载入字典映射
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
随后执行前向预测,得到所有类别的 logit 预测分数。logits是Odds取对数,本身没有上下限,对建模方便,但并不是一个概率,因此需要使用softmax层将其转变为“概率的模样”,进行归一化。
# 执行前向预测,得到所有类别的 logit 预测分数
pred_logits = model(input_img)
# 对 logit 分数做 softmax 运算
pred_softmax = F.softmax(pred_logits, dim=1)
然后像之前做过的一样,取置信度最大的十个结果绘制在图上。
尽管这个图片的分类效果还算不错,但由于本身的数据集并不含有许多零件混合的情况,因此在分类方面这个模型依然有些力不从心。或许面对这种任务更应当使用语义分割的方法。
与单张图像预测的区别在于需要按帧读取、写入并输出结果,因此多使用了shutil、tempfile、mmcv三个库。
除此之外大部分代码与预测单张图片相同。
# 创建临时文件夹,存放每帧结果
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))
# 读入待预测视频
imgs = mmcv.VideoReader(input_video)
prog_bar = mmcv.ProgressBar(len(imgs))
# 对视频逐帧处理
for frame_id, img in enumerate(imgs):
## 处理单帧画面
img, pred_softmax = pred_single_frame(img, n=5)
# 将处理后的该帧画面图像文件,保存至 /tmp 目录下
cv2.imwrite(f'{temp_out_dir}/{frame_id:06d}.jpg', img)
prog_bar.update() # 更新进度条
# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'output/output_pred.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)
pred_single_frame预测单帧图像的函数与之前相同。
同时我们也可以将置信度的变化柱状图加入到生成的视频中,也就是将这样的图片连成视频,操作上并无不同,同济子豪兄大佬的开源代码中(在本文开头)已经有详细的解释,我就不再赘述。
与之前大同小异,加入读取摄像头每一帧的部分即可。opencv中带有这样的函数。
随后处理并得到写入预测结果的数组,通过cv2.imshow()即可展示。
# 调用摄像头逐帧实时处理模板
# 不需修改任何代码,只需修改process_frame函数即可
# 同济子豪兄 2021-7-8
# 导入opencv-python
import cv2
import time
# 处理帧函数
def process_frame(img):
# 记录该帧开始处理的时间
start_time = time.time()
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
img_pil = Image.fromarray(img_rgb) # array 转 PIL
input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
pred_logits = model(input_img) # 执行前向预测,得到所有类别的 logit 预测分数
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
top_n = torch.topk(pred_softmax, 5) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析预测类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析置信度
# 使用PIL绘制中文
draw = ImageDraw.Draw(img_pil)
# 在图像上写字
for i in range(len(confs)):
pred_class = idx_to_labels[pred_ids[i]]
text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
# 文字坐标,中文字符串,字体,bgra颜色
draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
img = np.array(img_pil) # PIL 转 array
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # RGB转BGR
# 记录该帧处理完毕的时间
end_time = time.time()
# 计算每秒处理图像帧数FPS
FPS = 1/(end_time - start_time)
# 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
img = cv2.putText(img, 'FPS '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)
return img
# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(0)
# 打开cap
cap.open(0)
# 无限循环,直到break被触发
while cap.isOpened():
# 获取画面
success, frame = cap.read()
if not success:
print('Error')
break
## !!!处理帧函数
frame = process_frame(frame)
# 展示处理后的三通道图像
cv2.imshow('my_window',frame)
if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)
break
# 关闭摄像头
cap.release()
# 关闭图像窗口
cv2.destroyAllWindows()