PyTorch图像分类实战(Datawhale)Task2:预训练模型预测

预训练模型预测

参考资料:

  1. 同济子豪兄教学视频:https://space.bilibili.com/1900783/channel/collectiondetail?sid=606800(P2)
  2. 项目代码:https://github.com/TommyZihao/Train_Custom_Dataset

预训练模型预测

  • **预训练模型预测**
    • 1. 图像检测
      • 1.1. 图像预处理
      • 1.2. 图像类别预测
      • 1.3. 预测结果可视化
    • 2. 视频检测
      • 2.1. 检测预处理
      • 2.2. 检测处理方案
    • 3. 摄像头实时画面检测
      • 3.1. 获取摄像头的一帧画面
      • 3.2. 调用摄像头获取每帧(模板)

本章节采用Torchvision中models内包含的已经在ImageNet 1000数据集训练好的模型文件进行图像分类操作,具体包括图像、视频以及摄像头的实时操作。

1. 图像检测

采用resnet18模型进行图片分类预测,包含图像预处理、前向预测、函数激活以及预测结果可视化几个步骤。

1.1. 图像预处理

1. 采用pillow进行图像读取:

# 用 pillow 载入图像
img_path = 'test_img/kangaroo.jpg'
img_pil = Image.open(img_path) 

2. 图像预处理-RCTN:
缩放裁剪、转 Tensor、归一化
Torch是对于Tensor(张量)进行处理,所以需要将图像文件array转为tensor(张量)。

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

1.2. 图像类别预测

对预处理的图像进行前向预测,得到对应类别置信度后进行softmax激活至0-1.

# 图像预处理
input_img = test_transform(img_pil).unsqueeze(0).to(device) 
# 执行前向预测,得到所有类别的 logit 预测分数
pred_logits = model(input_img) 
# 对 logit 分数做 softmax 运算
pred_softmax = F.softmax(pred_logits, dim=1) 

1.3. 预测结果可视化

预测结果是对于1000个类别的置信度分数,需要通过筛选和查询确定预测类别,并进行显示。
结果可视化涉及到中英文字符问题,Opencv无法显示中文字符,而matplotlib显示中文字符需要进行以下设置:

# # windows操作系统
plt.rcParams['font.sans-serif']=['SimHei']  # 用来正常显示中文标签 
plt.rcParams['axes.unicode_minus']=False  # 用来正常显示负号


# Mac操作系统,参考 https://www.ngui.cc/51cto/show-727683.html
# 下载 simhei.ttf 字体文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf
#

# Linux操作系统,例如 云GPU平台:https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1
# 如果遇到 SSL 相关报错,重新运行本代码块即可
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf -O /environment/miniconda3/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/SimHei.ttf
!rm -rf /home/featurize/.cache/matplotlib



import matplotlib
matplotlib.rc("font",family='SimHei') # 中文字体
  • 显示各类别置信度柱状图

采用matplotlib绘制类别置信度柱状图:

#显示各类别置信度柱状图*
plt.figure(figsize=(8,4))

x = range(1000)
y = pred_softmax.cpu().detach().numpy()[0]

ax = plt.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
plt.ylim([0, 1.0]) # y轴取值范围
# plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度数值

plt.title(img_path, fontsize=30)
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.tick_params(labelsize=16) # 坐标文字大小

plt.show()

PyTorch图像分类实战(Datawhale)Task2:预训练模型预测_第1张图片

  • 筛选置信度最大的 n 个结果
    根据需要筛选需要的预选目标数量,进行显示:
n = 10
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
  • 预测结果表格输出
    将预测结果通过Pands写入表格进行存储;
pred_df = pd.DataFrame() # 预测结果表格
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
    label_idx = int(pred_ids[i]) # 获取类别号
    wordnet = idx_to_labels[pred_ids[i]][0] # 获取 WordNet
    confidence = confs[i] * 100 # 获取置信度
    pred_df = pred_df.append({'Class':class_name, 'Class_ID':label_idx, 'Confidence(%)':confidence, 'WordNet':wordnet}, ignore_index=True) # 预测结果表格添加一行
display(pred_df) # 展示预测结果表格

PyTorch图像分类实战(Datawhale)Task2:预训练模型预测_第2张图片

  • 预测结果写入图片
    英文文本可以通过opencv载入写入:
# 用 opencv 载入原图
img_bgr = cv2.imread(img_path)

for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
    confidence = confs[i] * 100 # 获取置信度
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)
    
    # !图片,添加的文字,左上角坐标,字体,字号,bgr颜色,线宽
    img_bgr = cv2.putText(img_bgr, text, (25, 50 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 0, 255), 3)


中文文本可以通过pillow载入写入:

draw = ImageDraw.Draw(img_pil)

for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
    confidence = confs[i] * 100 # 获取置信度
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)
    
    # 文字坐标,中文字符串,字体,rgba颜色
    draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))

2. 视频检测

2.1. 检测预处理

视频检测即针对于视频中的每一帧进行图像检测操作,完成操作后将每一帧图像整合为一个新的视频文件。
在每一帧检测过程中,后端绘图,不显示,只保存。
在进行图像检测前,需要采用mmcv进行视频帧读取操作,并进行文件格式转化;
视频格式转化:BGR 转 RGB
图片格式转化:array 转 pil

  	'''
    输入摄像头画面bgr-array,输出前n个图像分类预测结果的图像bgr-array
    '''
    img_bgr = img
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 转 RGB
    img_pil = Image.fromarray(img_rgb) # array 转 pil
    
    ## 图像检测
    # 图像预处理
    input_img = test_transform(img_pil).unsqueeze(0).to(device) 
    # 执行前向预测,得到所有类别的 logit 预测分数
    pred_logits = model(input_img) 
    # 对 logit 分数做 softmax 运算
    pred_softmax = F.softmax(pred_logits, dim=1) 
    
    
    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度

	# 在图像上写字
    for i in range(n):
        class_name = idx_to_labels[pred_ids[i]][1] # 获取类别名称
        confidence = confs[i] * 100 # 获取置信度
        text = '{:<15} {:>.4f}'.format(class_name, confidence)

        # !图片,添加的文字,左上角坐标,字体,字号,bgr颜色,线宽
        img_bgr = cv2.putText(img_bgr, text, (25, 50 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 0, 255), 3)

2.2. 检测处理方案

检测输出视频包括原始视频图像以及预测结果的文字信息;

1. 创建临时文件夹,存放每帧结果;

# 创建临时文件夹,存放每帧结果
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))

2. 读入待预测视频,并逐帧处理
采用mmcv进行视频操作:

# 读入待预测视频
imgs = mmcv.VideoReader(input_video)

prog_bar = mmcv.ProgressBar(len(imgs))

# 对视频逐帧处理
for frame_id, img in enumerate(imgs):
    
    ## 处理单帧画面
    img, pred_softmax = pred_single_frame(img, n=5)
    # 将处理后的该帧画面图像文件,保存至 /tmp 目录下
    cv2.imwrite(f'{temp_out_dir}/{frame_id:06d}.jpg', img)
    prog_bar.update() # 更新进度条

3. 处理单帧文件
通过mmcv把每一帧串成视频文件,并删除每帧暂存文件。

# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'output/output_pred.mp4', fps=imgs.fps, fourcc='mp4v')

# 删除存放每帧画面的临时文件夹
shutil.rmtree(temp_out_dir) 
print('删除临时文件夹', temp_out_dir)

3. 摄像头实时画面检测

摄像头实时画面检测核心内容仍然是图像检测,附加处理为文件读取以及格式类型转化。

3.1. 获取摄像头的一帧画面

## 获取摄像头的一帧画面

# 导入opencv-python
import cv2
import time

# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)

# 打开cap
cap.open(0)
time.sleep(1)
success, img_bgr = cap.read()
    
# 关闭摄像头
cap.release()

# 关闭图像窗口
cv2.destroyAllWindows()

图像信息预览

img_bgr.shape
# (720, 1280, 3)

# BGR转RGB
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# 图像预览(pillow)
img_pil = Image.fromarray(img_rgb)
img_pil

3.2. 调用摄像头获取每帧(模板)

通过成熟模板化函数代码,调用摄像头进行检测;

  • 调用摄像头逐帧实时处理模板
# 调用摄像头逐帧实时处理模板
# 不需修改任何代码,只需修改process_frame函数即可
# 同济子豪兄 2021-7-8

# 导入opencv-python
import cv2
import time

# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)

# 打开cap
cap.open(0)

# 无限循环,直到break被触发
while cap.isOpened():
    # 获取画面
    success, frame = cap.read()
    if not success:
        print('Error')
        break
    
    ## !!!处理帧函数
    frame = process_frame(frame)
    
    # 展示处理后的三通道图像
    cv2.imshow('my_window',frame)

    if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)
        break
    
# 关闭摄像头
cap.release()

# 关闭图像窗口
cv2.destroyAllWindows()
  • 核心处理帧函数
# 处理帧函数
def process_frame(img):
    
    '''
    输入摄像头拍摄画面bgr-array,输出图像分类预测结果bgr-array
    '''
    
    # 记录该帧开始处理的时间
    start_time = time.time()
    
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
    img_pil = Image.fromarray(img_rgb) # array 转 PIL
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测,得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
    
    top_n = torch.topk(pred_softmax, 5) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析预测类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析置信度
    
    # 在图像上写字
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])

        # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
        img = cv2.putText(img, text, (50, 160 + 80 * i), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4, cv2.LINE_AA)
    
    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1/(end_time - start_time)  
    # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
    img = cv2.putText(img, 'FPS  '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)

    return img

你可能感兴趣的:(PyTorch图像分类,pytorch,分类,python)