Rethinking on Multi-Stage Networks for Human Pose Estimation 源码分析

论文地址:

https://arxiv.org/pdf/1901.00148.pdf

代码地址:

GitHub - megvii-research/MSPN: Multi-Stage Pose Network — https://github.com/megvii-research/MSPN

从下面的模型结构图可以看出,MSPN 的单阶段模块其实和 CPN 的 GlobalNet 结构很像。不过 CPN 的 GlobalNet 通道数被固定在 256,会在下采样时丢失信息;MSPN 则在下采样时增加通道数,尽可能减少信息丢失。此外,MSPN 还增加了跨 stage 的特征融合。

(图 1:MSPN 模型结构图)

 

首先看下模型代码

class ResNet_downsample_module(nn.Module):
    """Down-sampling (encoder) half of a single MSPN stage.

    Stacks four ResNet stages, ``layer1`` .. ``layer4``, with base widths
    64/128/256/512 (times ``block.expansion``). Stages 2-4 use stride 2, so
    each one halves the spatial resolution while the channel count grows.
    When ``has_skip`` is true, the ``skip1``/``skip2`` features coming from
    the previous stage are added after every layer (cross-stage feature
    aggregation).
    """

    def __init__(self, block, layers, has_skip=False, efficient=False,
            zero_init_residual=False):
        super(ResNet_downsample_module, self).__init__()
        self.has_skip = has_skip
        # Channel count of the tensor entering layer1; updated by
        # _make_layer as stages are built.
        self.in_planes = 64
        # (base width, stride) per stage. Stages are created in order so
        # the registered names (layer1..layer4) and the parameter
        # creation order stay fixed.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[idx],
                    stride=stride, efficient=efficient)
            setattr(self, 'layer%d' % (idx + 1), stage)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                        nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Must run after the loop above, which would otherwise reset
        # these BN weights to 1.
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, efficient=False):
        """Build one ResNet stage made of ``blocks`` residual blocks.

        The first block receives a 1x1 conv+BN projection shortcut whenever
        the stride or the channel count changes. ``self.in_planes`` is
        updated to the stage's output width.
        """
        out_planes = planes * block.expansion
        shortcut = None
        if stride != 1 or self.in_planes != out_planes:
            shortcut = conv_bn_relu(self.in_planes, out_planes,
                    kernel_size=1, stride=stride, padding=0, has_bn=True,
                    has_relu=False, efficient=efficient)

        stage = [block(self.in_planes, planes, stride, shortcut,
                efficient=efficient)]
        self.in_planes = out_planes
        stage.extend(block(self.in_planes, planes, efficient=efficient)
                for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x, skip1, skip2):
        """Run the four stages and return their outputs deepest-first.

        Args:
            x: input feature map.
            skip1, skip2: per-level features from the previous stage,
                ordered shallow-to-deep (index 0 matches layer1). Only
                read when ``has_skip`` is true.

        Returns:
            Tuple ``(x4, x3, x2, x1)`` of the layer4 .. layer1 outputs.
        """
        # Example shapes from the original notes (batch 2):
        #   layer1 -> (2, 256, 64, 48)   layer2 -> (2, 512, 32, 24)
        #   layer3 -> (2, 1024, 16, 12)  layer4 -> (2, 2048, 8, 6)
        # Per the original author: similar to CPN's GlobalNet, but channels
        # grow while down-sampling to limit information loss.
        features = []
        out = x
        for level, stage in enumerate(
                (self.layer1, self.layer2, self.layer3, self.layer4)):
            out = stage(out)
            if self.has_skip:
                # Cross-stage feature aggregation.
                out = out + skip1[level] + skip2[level]
            features.append(out)
        return tuple(reversed(features))


class Upsample_unit(nn.Module): 
    """One level of the up-sampling (decoder) path of an MSPN stage.

    ``forward(x, up_x)`` fuses this level's encoder feature ``x`` with the
    output ``up_x`` of the next-coarser unit. It returns
    ``(out, res, skip1, skip2, cross_conv)``:

    * ``out``: fused ``chl_num``-channel feature. Upsample_module passes it
      to the next (finer) unit as that unit's ``up_x``.
    * ``res``: ``output_chl_num``-channel prediction resized to
      ``output_shape``. Every level of every stage emits one.
    * ``skip1`` / ``skip2``: only when ``gen_skip`` is set, otherwise
      ``None``. Both are projected to ``in_planes`` channels. MSPN.forward
      hands them to the next stage for cross-stage feature aggregation.
    * ``cross_conv``: only when ``ind == 3`` and ``gen_cross_conv`` is set,
      otherwise ``None``. A 64-channel map that MSPN.forward feeds to the
      next stage as its input.
    """

    def __init__(self, ind, in_planes, up_size, output_chl_num, output_shape,
            chl_num=256, gen_skip=False, gen_cross_conv=False, efficient=False):
        """
        Args:
            ind: level index in the decoder. 0 has no coarser input
                (``up_x`` is ignored). 3 is the only level that may emit
                ``cross_conv``.
            in_planes: channel count of ``x``.
            up_size: (h, w) that ``up_x`` is resized to. Only used when
                ``ind > 0``.
            output_chl_num: channels of ``res``. MSPN passes the keypoint
                count.
            output_shape: (h, w) of ``res``.
            chl_num: width of the fused feature ``out``.
            gen_skip: build the ``skip1``/``skip2`` projections.
            gen_cross_conv: build ``cross_conv`` (effective only when
                ``ind == 3``).
            efficient: forwarded to every ``conv_bn_relu``.
        """
        super(Upsample_unit, self).__init__()
        self.output_shape = output_shape

        # 1x1 conv + BN without ReLU; the ReLU is applied after the fusion
        # with up_x in forward().
        self.u_skip = conv_bn_relu(in_planes, chl_num, kernel_size=1, stride=1,
                padding=0, has_bn=True, has_relu=False, efficient=efficient)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_size = up_size
            self.up_conv = conv_bn_relu(chl_num, chl_num, kernel_size=1,
                    stride=1, padding=0, has_bn=True, has_relu=False,
                    efficient=efficient)

        self.gen_skip = gen_skip
        if self.gen_skip:
            # skip1 re-projects the encoder input x.
            self.skip1 = conv_bn_relu(in_planes, in_planes, kernel_size=1,
                    stride=1, padding=0, has_bn=True, has_relu=True,
                    efficient=efficient)
            # skip2 maps the fused decoder feature back to in_planes channels.
            self.skip2 = conv_bn_relu(chl_num, in_planes, kernel_size=1,
                    stride=1, padding=0, has_bn=True, has_relu=True,
                    efficient=efficient)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == 3 and self.gen_cross_conv:
            self.cross_conv = conv_bn_relu(chl_num, 64, kernel_size=1,
                    stride=1, padding=0, has_bn=True, has_relu=True,
                    efficient=efficient)

        # Prediction head: 1x1 conv + ReLU, then 3x3 conv to output_chl_num.
        self.res_conv1 = conv_bn_relu(chl_num, chl_num, kernel_size=1,
                stride=1, padding=0, has_bn=True, has_relu=True,
                efficient=efficient)
        self.res_conv2 = conv_bn_relu(chl_num, output_chl_num, kernel_size=3,
                stride=1, padding=1, has_bn=True, has_relu=False,
                efficient=efficient)

    def forward(self, x, up_x):
        """Fuse ``x`` with ``up_x``; see the class docstring for outputs."""
        out = self.u_skip(x)

        if self.ind > 0:
            # Bilinearly resize the coarser unit's output to this level's
            # size, project it, then add it (in place) to the skip branch.
            up_x = F.interpolate(up_x, size=self.up_size, mode='bilinear',
                    align_corners=True)
            up_x = self.up_conv(up_x)
            out += up_x 
        out = self.relu(out)

        res = self.res_conv1(out)
        res = self.res_conv2(res)
        # Every level of every stage produces its own prediction, resized to
        # the common output resolution.
        res = F.interpolate(res, size=self.output_shape, mode='bilinear',
                align_corners=True)

        skip1 = None
        skip2 = None
        if self.gen_skip:
            # Features handed to the next stage (cross-stage aggregation).
            skip1 = self.skip1(x)
            skip2 = self.skip2(out)

        cross_conv = None
        if self.ind == 3 and self.gen_cross_conv:
            cross_conv = self.cross_conv(out)

        return out, res, skip1, skip2, cross_conv


class Upsample_module(nn.Module):
    """Up-sampling (decoder) half of a single MSPN stage.

    Chains four Upsample_unit levels from coarsest (``up1``, fed ``x4``) to
    finest (``up4``, fed ``x1``). Each unit receives the previous unit's
    fused output.
    """

    def __init__(self, output_chl_num, output_shape, chl_num=256,
            gen_skip=False, gen_cross_conv=False, efficient=False):
        super(Upsample_module, self).__init__()
        # Channel counts of x4, x3, x2, x1 (encoder outputs, deepest first).
        self.in_planes = [2048, 1024, 512, 256] 
        h, w = output_shape
        # Spatial size of each level. Unit i resizes the coarser output to
        # up_sizes[i]; unit 0 has no coarser input, so up_sizes[0] is unused.
        self.up_sizes = [
                (h // 8, w // 8), (h // 4, w // 4), (h // 2, w // 2), (h, w)]
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv

        self.up1 = Upsample_unit(0, self.in_planes[0], self.up_sizes[0],
                output_chl_num=output_chl_num, output_shape=output_shape,
                chl_num=chl_num, gen_skip=self.gen_skip,
                gen_cross_conv=self.gen_cross_conv, efficient=efficient)
        self.up2 = Upsample_unit(1, self.in_planes[1], self.up_sizes[1],
                output_chl_num=output_chl_num, output_shape=output_shape,
                chl_num=chl_num, gen_skip=self.gen_skip,
                gen_cross_conv=self.gen_cross_conv, efficient=efficient)
        self.up3 = Upsample_unit(2, self.in_planes[2], self.up_sizes[2],
                output_chl_num=output_chl_num, output_shape=output_shape,
                chl_num=chl_num, gen_skip=self.gen_skip,
                gen_cross_conv=self.gen_cross_conv, efficient=efficient)
        self.up4 = Upsample_unit(3, self.in_planes[3], self.up_sizes[3],
                output_chl_num=output_chl_num, output_shape=output_shape,
                chl_num=chl_num, gen_skip=self.gen_skip,
                gen_cross_conv=self.gen_cross_conv, efficient=efficient)

    def forward(self, x4, x3, x2, x1):
        """Decode the encoder features.

        Returns:
            res: list of the four predictions, coarsest unit first.
            skip1, skip2: lists ordered finest-first, so that index 0 lines
                up with layer1 in ResNet_downsample_module.forward. Entries
                are ``None`` when ``gen_skip`` is false.
            cross_conv: output of ``up4`` (``None`` unless
                ``gen_cross_conv``).
        """
        # Example shapes (batch 2, output_shape (64, 48), chl_num = 256):
        # x4: (2, 2048, 8, 6)    x3: (2, 1024, 16, 12)
        # x2: (2, 512, 32, 24)   x1: (2, 256, 64, 48)
        # out1: (2, 256, 8, 6) -- u_skip projects x4 to chl_num channels
        out1, res1, skip1_1, skip2_1, _ = self.up1(x4, None)
        # out2: (2, 256, 16, 12); out1 is bilinearly up-sampled first
        out2, res2, skip1_2, skip2_2, _ = self.up2(x3, out1)
        # out3: (2, 256, 32, 24)
        out3, res3, skip1_3, skip2_3, _ = self.up3(x2, out2)
        # out4: (2, 256, 64, 48)
        out4, res4, skip1_4, skip2_4, cross_conv = self.up4(x1, out3)

        # 'res' starts from small size
        res = [res1, res2, res3, res4]
        skip1 = [skip1_4, skip1_3, skip1_2, skip1_1]
        skip2 = [skip2_4, skip2_3, skip2_2, skip2_1]

        return res, skip1, skip2, cross_conv


class Single_stage_module(nn.Module):
    """One complete MSPN stage: a ResNet encoder followed by a decoder.

    The encoder uses Bottleneck blocks in a [3, 4, 6, 3] layout. The
    decoder is an Upsample_module producing one prediction per level plus
    the optional cross-stage features.
    """

    def __init__(self, output_chl_num, output_shape, has_skip=False,
            gen_skip=False, gen_cross_conv=False, chl_num=256, efficient=False,
            zero_init_residual=False,):
        super(Single_stage_module, self).__init__()
        self.has_skip, self.gen_skip = has_skip, gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.chl_num, self.zero_init_residual = chl_num, zero_init_residual
        self.layers = [3, 4, 6, 3]
        # Encoder is registered before the decoder; keep this order so the
        # parameter names and their creation order do not change.
        self.downsample = ResNet_downsample_module(
                Bottleneck, self.layers,
                has_skip=self.has_skip,
                efficient=efficient,
                zero_init_residual=self.zero_init_residual)
        self.upsample = Upsample_module(
                output_chl_num, output_shape,
                chl_num=self.chl_num,
                gen_skip=self.gen_skip,
                gen_cross_conv=self.gen_cross_conv,
                efficient=efficient)

    def forward(self, x, skip1, skip2):
        """Return ``(res, skip1, skip2, cross_conv)`` for this stage."""
        # The encoder yields (x4, x3, x2, x1), deepest first, which is
        # exactly the positional order the decoder expects.
        encoded = self.downsample(x, skip1, skip2)
        return self.upsample(*encoded)


class MSPN(nn.Module):
    """Multi-Stage Pose Network.

    A stem (``ResNet_top``) is followed by ``cfg.MODEL.STAGE_NUM`` single
    stages:

    * Every stage except the first consumes the previous stage's
      ``skip1``/``skip2`` (cross-stage feature aggregation).
    * Every stage except the last produces them, together with
      ``cross_conv``, which becomes the next stage's input.
    """

    def __init__(self, cfg, run_efficient=False, **kwargs):
        super(MSPN, self).__init__()
        self.top = ResNet_top()
        self.stage_num = cfg.MODEL.STAGE_NUM
        self.output_chl_num = cfg.DATASET.KEYPOINT.NUM
        self.output_shape = cfg.OUTPUT_SHAPE
        self.upsample_chl_num = cfg.MODEL.UPSAMPLE_CHANNEL_NUM
        self.ohkm = cfg.LOSS.OHKM
        self.topk = cfg.LOSS.TOPK
        self.ctf = cfg.LOSS.COARSE_TO_FINE
        # Plain list, kept for backward compatibility. The stages are
        # registered as submodules through the ``stage%d`` attributes below,
        # which also fixes their state_dict key names.
        self.mspn_modules = list()
        for i in range(self.stage_num):
            is_first = i == 0
            is_last = i == self.stage_num - 1
            self.mspn_modules.append(
                    Single_stage_module(
                        self.output_chl_num, self.output_shape,
                        has_skip=not is_first, gen_skip=not is_last,
                        gen_cross_conv=not is_last,
                        chl_num=self.upsample_chl_num,
                        efficient=run_efficient,
                        **kwargs
                        )
                    )
            setattr(self, 'stage%d' % i, self.mspn_modules[i])

    def forward(self, imgs, valids=None, labels=None):
        """Run all stages.

        Returns:
            The last stage's last-level prediction when neither ``valids``
            nor ``labels`` is given (inference). Otherwise, the result of
            ``self._calculate_loss(outputs, valids, labels)``.
        """
        x = self.top(imgs)
        skip1 = None
        skip2 = None
        outputs = list()
        for i in range(self.stage_num):
            # Look the registered stage up directly instead of eval()-ing a
            # source string.
            stage = getattr(self, 'stage%d' % i)
            # skip1/skip2 carry cross-stage features. x becomes this stage's
            # cross_conv, the next stage's input (None after the last stage).
            res, skip1, skip2, x = stage(x, skip1, skip2)
            outputs.append(res)

        if valids is None and labels is None:
            return outputs[-1][-1]
        else:
            return self._calculate_loss(outputs, valids, labels)

分析下损失函数

    def _calculate_loss(self, outputs, valids, labels):
        """Sum the L2 heatmap losses over every level of every stage.

        Args:
            outputs: per-stage lists of four predictions (stage 1 first;
                within a stage, the decoder level order of ``res``).
            valids: per-joint validity. Per the original notes its shape is
                (n, 17, 1).
            labels: heatmap targets, one set per Gaussian kernel. Per the
                original notes its shape is (n, 5, 17, h, w).

        Levels 0-2 are weighted 1/4. Level 3 uses the OHKM loss when
        ``self.ohkm`` is set. With coarse-to-fine (``self.ctf``), the last
        stage is supervised with labels[:, 1:5] (smaller kernels); all other
        stages use labels[:, 0:4].

        Returns:
            ``dict(total_loss=...)``.
        """
        plain_criterion = JointsL2Loss()
        hard_criterion = (JointsL2Loss(has_ohkm=self.ohkm, topk=self.topk)
                          if self.ohkm else None)

        last_stage = self.stage_num - 1
        total = 0
        for stage_idx in range(self.stage_num):
            label_offset = 1 if (stage_idx == last_stage and self.ctf) else 0
            stage_outputs = outputs[stage_idx]
            for level in range(4):
                target = labels[:, level + label_offset, :, :, :]
                prediction = stage_outputs[level]

                # Online hard keypoint mining on the last level, as in CPN's
                # RefineNet.
                if level == 3 and self.ohkm:
                    level_loss = hard_criterion(prediction, valids, target)
                else:
                    level_loss = plain_criterion(prediction, valids, target)

                if level != 3:
                    level_loss = level_loss / 4

                total += level_loss

        return dict(total_loss=total)

接下来看 label 的生成。配置中 `TRAIN.GAUSSIAN_KERNELS = [(15, 15), (11, 11), (9, 9), (7, 7), (5, 5)]`。

这里用 5 个高斯核分别生成 5 组 label。可以看出,在开启 coarse-to-fine(`LOSS.COARSE_TO_FINE`)时,最后一个 stage(两阶段时即 stage2)取后 4 个高斯核生成的 label,比第一个 stage 所用的前 4 个高斯核要小一些,监督也就更精细。inference 时直接使用最后一个 stage 最后一层的输出。

class JointsDataset(Dataset):
    """Keypoint dataset (excerpt).

    NOTE(review): ``__init__`` and the start of ``__getitem__`` are stubbed
    with ``pass``. ``__getitem__`` also uses names it never defines:
    ``img``, ``joints``, ``joints_vis``, ``trans``, ``score``, ``center``,
    ``scale`` and ``img_id``. The loading/augmentation code appears to have
    been elided from this listing, so the class is not runnable as shown.
    """

    def __init__(self, DATASET, stage, transform=None):
        # Elided in this excerpt.
        pass

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        """Return one sample.

        Train stage returns ``(img, valid, labels)``:

        * ``valid``: float tensor of per-joint visibility. Joints that the
          affine transform moves outside the input are marked invisible.
        * ``labels``: float tensor of shape
          ``(len(self.gaussian_kernels), self.keypoint_num,
          *self.output_shape)``, one heatmap set per Gaussian kernel.

        Other stages return ``(img, score, center, scale, img_id)``.
        """
        # Elided in this excerpt.
        pass
        if self.stage == 'train':
            # Map each visible joint through the affine transform; joints
            # that land outside the input crop become invisible.
            for i in range(self.keypoint_num):
                if joints_vis[i, 0] > 0:
                    joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
                    if joints[i, 0] < 0 \
                            or joints[i, 0] > self.input_shape[1] - 1 \
                            or joints[i, 1] < 0 \
                            or joints[i, 1] > self.input_shape[0] - 1:
                        joints_vis[i, 0] = 0
            valid = torch.from_numpy(joints_vis).float()
            # Render one heatmap set per Gaussian kernel. Per the original
            # notes: TRAIN.GAUSSIAN_KERNELS = [(15, 15), (11, 11), (9, 9),
            # (7, 7), (5, 5)].
            labels_num = len(self.gaussian_kernels)
            labels = np.zeros(
                    (labels_num, self.keypoint_num, *self.output_shape))
            for i in range(labels_num):
                labels[i] = self.generate_heatmap(
                        joints, valid, kernel=self.gaussian_kernels[i])
            labels = torch.from_numpy(labels).float()

            return img, valid, labels
        else:
            return img, score, center, scale, img_id

    def generate_heatmap(self, joints, valid, kernel=(7, 7)):
        """Render one heatmap per keypoint.

        A 1 is placed at the joint position, scaled from input to output
        resolution and truncated to int. It is then blurred with
        ``cv2.GaussianBlur(kernel, sigmaX=0)``, which lets OpenCV derive
        sigma from the kernel size. Finally the map is rescaled so its peak
        equals 255. Joints with ``valid[i] < 1`` are left all-zero.

        Returns:
            float32 array of shape ``(self.keypoint_num, *self.output_shape)``.
        """
        heatmaps = np.zeros(
                (self.keypoint_num, *self.output_shape), dtype='float32')

        for i in range(self.keypoint_num):
            if valid[i] < 1:
                continue
            target_y = joints[i, 1] * self.output_shape[0] \
                    / self.input_shape[0]
            target_x = joints[i, 0] * self.output_shape[1] \
                    / self.input_shape[1]
            heatmaps[i, int(target_y), int(target_x)] = 1
            # Spread the single hot pixel into a Gaussian with cv2's blur.
            heatmaps[i] = cv2.GaussianBlur(heatmaps[i], kernel, 0)
            maxi = np.amax(heatmaps[i])
            # Skip normalisation if the blur left (almost) nothing, to avoid
            # dividing by ~0.
            if maxi <= 1e-8:
                continue
            heatmaps[i] /= maxi / 255

        return heatmaps 

最后分析下 inference代码

def compute_on_dataset(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and collect keypoint results.

    Args:
        model: network that returns heatmaps when called with images only.
            MSPN.forward returns the last stage's last-level prediction.
        data_loader: yields ``(imgs, scores, centers, scales, img_ids)``
            batches.
        device: device the images are moved to before inference.

    Returns:
        List of dicts with keys:

        * ``image_id``
        * ``category_id``: always 1
        * ``keypoints``: flattened ``[x, y, confidence] * num_joints``
        * ``score``: dataset score times the mean keypoint confidence

    Reads flip-test and post-processing settings from the module-level
    ``cfg``.
    """
    model.eval()
    results = list()
    cpu_device = torch.device("cpu")

    data = tqdm(data_loader) if is_main_process() else data_loader
    for batch in data:
        # imgs: network input images.
        # scores: per-box confidence from the dataset (0 ~ 1).
        # centers: centre of each detected person box.
        # scales: box size divided by pixel_std (default 200), i.e.
        #     np.array([w / pixel_std, h / pixel_std], dtype=np.float32)
        imgs, scores, centers, scales, img_ids = batch

        imgs = imgs.to(device)
        with torch.no_grad():
            # Heatmaps from the last stage's last level.
            outputs = model(imgs)
            outputs = outputs.to(cpu_device).numpy()

            # Flip test (test-time augmentation): also predict on the
            # horizontally flipped images, map the heatmaps back with
            # flip_back (presumably mirroring them and swapping the
            # FLIP_PAIRS channels), then average both predictions.
            if cfg.TEST.FLIP:
                # Flip along the width axis on the device; no numpy round-trip.
                imgs_flipped = torch.flip(imgs, dims=[3])
                outputs_flipped = model(imgs_flipped)
                outputs_flipped = outputs_flipped.to(cpu_device).numpy()
                outputs_flipped = flip_back(
                        outputs_flipped, cfg.DATASET.KEYPOINT.FLIP_PAIRS)

                outputs = (outputs + outputs_flipped) * 0.5

        centers = np.array(centers)
        scales = np.array(scales)
        # Decode heatmaps into original-image coordinates.
        # preds: (B, num_joints, 2), maxvals: (B, num_joints, 1)
        preds, maxvals = get_results(outputs, centers, scales,
                cfg.TEST.GAUSSIAN_KERNEL, cfg.TEST.SHIFT_RATIOS)

        # Drop only the trailing singleton axis. A bare ``squeeze()`` would
        # also drop the batch axis for a single-image batch, and
        # ``mean(axis=1)`` would then raise.
        kp_scores = maxvals[:, :, 0].mean(axis=1)
        # (B, num_joints, 3): x, y, confidence
        preds = np.concatenate((preds, maxvals), axis=2)

        for i, person in enumerate(preds):
            results.append(dict(image_id=img_ids[i],
                                category_id=1,
                                keypoints=person.reshape(-1).tolist(),
                                score=scores[i] * kp_scores[i]))

    return results


def get_results(outputs, centers, scales, kernel=11, shifts=(0.25,)):
    """Decode heatmaps into keypoint coordinates in the original image.

    Args:
        outputs: heatmaps of shape (B, num_joints, out_h, out_w), e.g.
            (B, 17, 64, 48).
        centers: (B, 2) person-box centres in original-image pixels.
        scales: (B, 2) box size divided by 200, ordered (w, h) as built by
            the dataset. Not modified.
        kernel: Gaussian blur kernel size used to smooth each heatmap before
            the peak search.
        shifts: sub-pixel refinement steps. For each entry, the peak is moved
            that many heatmap pixels towards the next-highest response. An
            empty sequence yields the plain arg-max.

    Returns:
        preds: (B, num_joints, 2) x, y in original-image pixels.
        maxvals: (B, num_joints, 1) confidence, read from
            ``heatmap / 255 + 0.5`` at the rounded peak.

    Reads ``cfg.DATASET.KEYPOINT.NUM``, ``cfg.OUTPUT_SHAPE`` and
    ``cfg.INPUT_SHAPE`` from the module-level config.
    """
    num_joints = cfg.DATASET.KEYPOINT.NUM
    out_h, out_w = cfg.OUTPUT_SHAPE[0], cfg.OUTPUT_SHAPE[1]
    in_h, in_w = cfg.INPUT_SHAPE[0], cfg.INPUT_SHAPE[1]
    # Heatmap-to-input stride (4 for a 256x192 input with 64x48 heatmaps).
    stride_x = in_w / out_w
    stride_y = in_h / out_h
    # Box size in pixels, (w, h). Computed out of place so the caller's
    # array is not rescaled as a side effect.
    box_sizes = np.asarray(scales) * 200
    border = 10

    nr_img = outputs.shape[0]
    preds = np.zeros((nr_img, num_joints, 2))
    maxvals = np.zeros((nr_img, num_joints, 1))
    for i in range(nr_img):
        score_map = outputs[i] / 255 + 0.5
        kps = np.zeros((num_joints, 2))
        scores = np.zeros((num_joints, 1))
        # Zero-pad by `border` pixels before blurring and peak search.
        dr = np.zeros((num_joints, out_h + 2 * border, out_w + 2 * border))
        dr[:, border: -border, border: -border] = outputs[i]
        for w in range(num_joints):
            # Smooth the heatmap before locating its peak.
            dr[w] = cv2.GaussianBlur(dr[w], (kernel, kernel), 0)
        for w in range(num_joints):
            x, y = _refine_peak(dr[w], border, shifts)
            x = max(0, min(x, out_w - 1))
            y = max(0, min(y, out_h - 1))
            # Heatmap pixel centre -> input pixel: (x + 0.5) * stride.
            kps[w] = np.array([x * stride_x + stride_x / 2,
                               y * stride_y + stride_y / 2])
            scores[w, 0] = score_map[w, int(round(y)), int(round(x))]
        # Input crop -> original image. The INPUT_SHAPE crop covers a box of
        # size box_sizes[i] centred at centers[i].
        kps[:, 0] = kps[:, 0] / in_w * box_sizes[i][0] + \
                    centers[i][0] - box_sizes[i][0] * 0.5
        kps[:, 1] = kps[:, 1] / in_h * box_sizes[i][1] + \
                    centers[i][1] - box_sizes[i][1] * 0.5
        preds[i] = kps
        maxvals[i] = scores

    return preds, maxvals


def _refine_peak(heatmap, border, shifts):
    """Return the refined (x, y) peak of a padded heatmap, in unpadded coords.

    Takes the arg-max of ``heatmap``. Then, for each entry of ``shifts``, it
    finds the next-highest response and moves the estimate ``shift`` pixels
    along the unit vector pointing towards it. This is a cheap sub-pixel
    correction of the integer arg-max. Every consumed maximum is zeroed in
    ``heatmap`` (in place) so the next ``argmax`` finds a new location.
    """
    y, x = np.unravel_index(heatmap.argmax(), heatmap.shape)
    heatmap[y, x] = 0
    x -= border
    y -= border
    for shift in shifts:
        py, px = np.unravel_index(heatmap.argmax(), heatmap.shape)
        heatmap[py, px] = 0
        # Offset from the current estimate to the next-highest response.
        px -= border + x
        py -= border + y
        ln = (px ** 2 + py ** 2) ** 0.5
        if ln > 1e-3:
            x += shift * px / ln
            y += shift * py / ln
    return x, y

最后,结果的可视化可以参考 Dataset 类中的可视化函数 visualize:

    def visualize(self, img, joints, score=None):
        """Draw keypoints, skeleton limbs and an optional score onto ``img``.

        Args:
            img: image array, drawn on in place and returned.
            joints: (keypoint_num, >=2) array. Columns 0 and 1 are x, y in
                ``img`` pixels. Further columns, such as the confidence in
                ``compute_on_dataset`` results, are ignored. A joint with a
                non-positive coordinate is treated as missing.
            score: optional text drawn at (50, 50).

        Returns:
            ``img``.
        """
        # 1-based keypoint index pairs; they match the COCO person skeleton.
        pairs = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
                [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
                [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
        color = np.random.randint(0, 256, (self.keypoint_num, 3)).tolist()

        def to_point(p):
            # OpenCV drawing functions need an integer (x, y) pair. Joint
            # rows may be float and may carry extra columns.
            return (int(round(float(p[0]))), int(round(float(p[1]))))

        for i in range(self.keypoint_num):
            if joints[i, 0] > 0 and joints[i, 1] > 0:
                cv2.circle(img, to_point(joints[i]), 2, tuple(color[i]), 2)
        if score:
            cv2.putText(img, score, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.2,
                    (128, 255, 0), 2)

        def draw_line(img, p1, p2):
            c = (0, 0, 255)
            if p1[0] > 0 and p1[1] > 0 and p2[0] > 0 and p2[1] > 0:
                cv2.line(img, to_point(p1), to_point(p2), c, 2)

        for pair in pairs:
            draw_line(img, joints[pair[0] - 1], joints[pair[1] - 1])

        return img

到此,MSPN主要代码分析完。

你可能感兴趣的:(源码分析,论文解析,计算机视觉)