PyTorch Implementation of MTCNN Explained — Model Training Stage 1

Index
  • 1.1 PyTorch Implementation of MTCNN Explained — Principles and Architecture Design 1 (basic principles of MTCNN)
  • 1.2 PyTorch Implementation of MTCNN Explained — Principles and Architecture Design 2 (design and implementation of the PNet network)
  • 2.1 PyTorch Implementation of MTCNN Explained — Model Training Stage 1 (model training stage — preprocessing data)
  • 2.2 PyTorch Implementation of MTCNN Explained — Model Training Stage 2 (model training stage — preparing the PNet data)

2.1 Model Training Stage — Preprocessing Data

Related links
  1. Paper download
  2. My project repository (bugs fixed)
  3. Training data download
Preprocessing the Training Data

Before training MTCNN we need to gather a dataset; WIDER FACE provides the training data.

⚠️ Credit to the original author: the following code is taken from Sierkinhane's GitHub repository, with some modifications to adapt it to the newer WIDER data.

import os
from scipy.io import loadmat

class DATA:
    def __init__(self, image_name,bboxes):
        self.image_name = image_name
        #self.facenumber = 
        self.bboxes = bboxes

# parse the WIDER FACE .mat annotation file
class WIDER(object):
    def __init__(self, file_to_label, path_to_image=None):
        self.file_to_label = file_to_label
        self.path_to_image = path_to_image

        self.f = loadmat(file_to_label)
        self.event_list = self.f['event_list']
        self.file_list = self.f['file_list']
        self.face_bbx_list = self.f['face_bbx_list']

    def next(self):
        for event_idx, event in enumerate(self.event_list):
            # normalise the event name to a plain str (avoids mixing bytes and str when building the path)
            e = str(event[0][0].encode('utf-8'))[2:-1]
            for file, bbx in zip(self.file_list[event_idx][0],
                                 self.face_bbx_list[event_idx][0]):
                f = file[0][0].encode('utf-8')
                # normalise the file name to a plain str as well
                f = str(f)[2:-1]
                # path_of_image = os.path.join(self.path_to_image, str(e), str(f)) + ".jpg"
                path_of_image = self.path_to_image + '/' + e + '/' + f + ".jpg"
                # print(path_of_image)
                bboxes = []
                bbx0 = bbx[0]
                for i in range(bbx0.shape[0]):
                    # note: the .mat stores each box as (x, y, w, h); see the remark
                    # at the end of this post before treating these as corner coordinates
                    xmin, ymin, xmax, ymax = bbx0[i]
                    bboxes.append((int(xmin), int(ymin), int(xmax), int(ymax)))
                yield DATA(path_of_image, bboxes)
                    
WIDER_FACE MAT Data Format

The MAT format cannot be inspected directly; after loading it, I found that the .mat file mainly contains the following:

  1. face_bbx_list: the positions of the face bounding boxes in each image (coordinates are relative to the image's top-left corner as the origin)
  2. event_list: the sub-folders under images, named after the scene/event
  3. file_list: the file names, usually joined with the event folder to build the path used to load each image
# inspect the keys stored in the .mat file
import scipy.io as scio

path = '../image/wider_annotation/wider_face_train.mat'
reftracker = scio.loadmat(path)

print(list(reftracker.keys()))
['__header__', '__version__', '__globals__', 'blur_label_list', 'event_list', 'expression_label_list', 'face_bbx_list', 'file_list', 'illumination_label_list', 'invalid_label_list', 'occlusion_label_list', 'pose_label_list']
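
scipy.io.loadmat returns MATLAB cell arrays as nested numpy object arrays, which is why the generator above keeps indexing with [0]. Below is a small exploratory sketch of that nesting, using the same path as above; the comments describe the structure I observed and may vary slightly with the dataset version.

import scipy.io as scio

path = '../image/wider_annotation/wider_face_train.mat'
mat = scio.loadmat(path)

event_list = mat['event_list']
file_list = mat['file_list']
face_bbx_list = mat['face_bbx_list']

# one row per event (scene) folder
print(event_list.shape)
# name of the first event folder, used as a sub-directory under images/
print(event_list[0][0][0])

# cell arrays with the file names and box matrices of the first event
first_files = file_list[0][0]
first_boxes = face_bbx_list[0][0]
# file name of the first image in that event (no .jpg extension)
print(first_files[0][0][0])
# each box matrix should be (num_faces, 4)
print(first_boxes[0][0].shape)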
Preparing the Training Data
import os
import sys
sys.path.append(os.getcwd())
import cv2
import time

"""
 convert the .mat annotations into a plain .txt file 
"""

#wider face original images path
path_to_image = '../image/wider_train/images'

#matlab file path
file_to_label = '../image/wider_face/wider_face_train.mat'

#target file path
target_file = '../image/anno_train.txt'

wider = WIDER(file_to_label,path_to_image)

line_count = 0
box_count = 0

print('start transforming....')
t = time.time()

with open(target_file, 'w+') as f:
    # press ctrl-C to stop the process
    for data in wider.next():
        line = []
        line.append(str(data.image_name))
        line_count += 1
        for i,box in enumerate(data.bboxes):
            box_count += 1
            for j,bvalue in enumerate(box):
                line.append(str(bvalue))

        line.append('\n')

        line_str = ' '.join(line)
        f.write(line_str)

st = time.time()-t
print('end transforming')

print('spend time:%d'%st)
print('total line(images):%d'%line_count)
print('total boxes(faces):%d'%box_count)
start transforming....
end transforming
spend time:0
total line(images):12880
total boxes(faces):159424
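
Each line of the resulting anno_train.txt is therefore the image path followed by groups of four numbers, one group per face (plus a trailing space before the newline from the join above). Below is a minimal sketch for reading the file back in; the helper name load_annotations is mine, and what the four values actually represent is discussed in the note at the end of this post.

def load_annotations(anno_file):
    """Parse anno_train.txt: image path followed by groups of 4 box values."""
    samples = []
    with open(anno_file, 'r') as f:
        for line in f:
            parts = line.strip().split(' ')
            if not parts or parts[0] == '':
                continue
            image_path = parts[0]
            values = list(map(int, parts[1:]))
            # regroup the flat list into one 4-tuple per face
            boxes = [tuple(values[i:i + 4]) for i in range(0, len(values), 4)]
            samples.append((image_path, boxes))
    return samples

samples = load_annotations('../image/anno_train.txt')
print(len(samples))   # should match the "total line(images)" count above
print(samples[0])     # first image path and its boxes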
Wider_face Data Description

As shown below:

Attached the mappings between attribute names and label values.
# image blur level: clear / normal / heavy
blur:
  clear->0
  normal blur->1
  heavy blur->2

expression:
  typical expression->0
  exaggerate expression->1

illumination:
  normal illumination->0
  extreme illumination->1

occlusion:
  no occlusion->0
  partial occlusion->1
  heavy occlusion->2

pose:
  typical pose->0
  atypical pose->1

invalid:
  false->0(valid image)
  true->1(invalid image)

The format of txt ground truth.
File name
Number of bounding box
x1, y1, w, h, blur, expression, illumination, invalid, occlusion, pose

Note:
The bounding box contains 4 values: the top-left corner of the box (x1, y1) and its width and height (w, h). In Sierkinhane's code, however, the four values are treated as the top-left corner (x1, y1) and the bottom-right corner (x2, y2), so using that code directly introduces a bug; this is called out here specifically.
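
A small sketch of the corresponding fix, assuming you want corner coordinates: convert the WIDER (x, y, w, h) values before appending them to bboxes inside WIDER.next() (the helper name is mine).

def wider_box_to_corners(x, y, w, h):
    """Convert a WIDER FACE box (top-left x, y plus width and height)
    into corner format (x1, y1, x2, y2)."""
    x1, y1 = int(x), int(y)
    x2, y2 = int(x + w), int(y + h)
    return x1, y1, x2, y2

# inside WIDER.next(), instead of unpacking corners directly:
# x, y, w, h = bbx0[i]
# bboxes.append(wider_box_to_corners(x, y, w, h))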
