论文地址:
https://arxiv.org/pdf/1901.00148.pdf
代码地址:
GitHub - megvii-research/MSPN: Multi-Stage Pose Network — https://github.com/megvii-research/MSPN
从下面的模型结构图可以看出,MSPN 的单个 stage 其实就是 CPN 的 GlobalNet 结构:下采样部分是 ResNet,通道数随下采样逐级增加(256→512→1024→2048);上采样部分的通道数固定为 256。论文指出,Hourglass 在下采样时通道数一直固定为 256,会丢失信息;MSPN 采用通道数逐级增加的下采样结构,尽可能减少信息丢失。此外,MSPN 还增加了跨 stage 的特征融合。
首先看下模型代码
class ResNet_downsample_module(nn.Module):
    """Encoder (down-sampling) half of one MSPN stage, built from residual blocks.

    Four residual layers use 64, 128, 256 and 512 base planes; the last three
    halve the resolution (stride 2). Channel width therefore grows as the
    resolution shrinks. When ``has_skip`` is True, the cross-stage features
    ``skip1``/``skip2`` from the previous stage are added to the output of
    each layer.

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(in_planes, planes, stride, downsample, efficient=...)``.
        layers: number of blocks in each of the four layers.
        has_skip: add the previous stage's skip features in ``forward``.
        efficient: forwarded to ``conv_bn_relu`` and to ``block``.
        zero_init_residual: zero-initialise ``bn3.weight`` of every
            ``Bottleneck`` after the standard initialisation.
    """

    def __init__(self, block, layers, has_skip=False, efficient=False,
                 zero_init_residual=False):
        super(ResNet_downsample_module, self).__init__()
        self.has_skip = has_skip
        self.in_planes = 64

        # Build layer1 .. layer4 in order. The order fixes the submodule
        # registration order and the order in which default inits draw from
        # the RNG.
        layer_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stride) in enumerate(layer_specs):
            setattr(self, 'layer%d' % (idx + 1),
                    self._make_layer(block, planes, layers[idx],
                                     stride=stride, efficient=efficient))

        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

        if zero_init_residual:
            for mod in self.modules():
                if isinstance(mod, Bottleneck):
                    nn.init.constant_(mod.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, efficient=False):
        """Stack ``blocks`` residual blocks into one ``nn.Sequential``.

        The first block gets a 1x1 conv + BN projection shortcut whenever it
        changes resolution (``stride != 1``) or width. ``self.in_planes`` is
        advanced to the layer's output width.
        """
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.in_planes != out_planes
        downsample = (conv_bn_relu(self.in_planes, out_planes,
                                   kernel_size=1, stride=stride, padding=0,
                                   has_bn=True, has_relu=False,
                                   efficient=efficient)
                      if needs_projection else None)

        stack = [block(self.in_planes, planes, stride, downsample,
                       efficient=efficient)]
        self.in_planes = out_planes
        stack.extend(block(self.in_planes, planes, efficient=efficient)
                     for _ in range(1, blocks))
        return nn.Sequential(*stack)

    def forward(self, x, skip1, skip2):
        """Run the four layers and return their outputs, deepest first.

        ``skip1``/``skip2`` are per-layer sequences indexed 0 (layer1) to 3
        (layer4). They are only read when ``has_skip`` is True.

        Returns:
            ``(x4, x3, x2, x1)``. Example shapes from the original notes
            (batch 2): x1 [2, 256, 64, 48], x2 [2, 512, 32, 24],
            x3 [2, 1024, 16, 12], x4 [2, 2048, 8, 6].
        """
        stage_layers = (self.layer1, self.layer2, self.layer3, self.layer4)
        feats = []
        h = x
        for idx, layer in enumerate(stage_layers):
            h = layer(h)
            if self.has_skip:
                # Cross-stage aggregation: add the previous stage's features.
                h = h + skip1[idx] + skip2[idx]
            feats.append(h)
        x1, x2, x3, x4 = feats
        return x4, x3, x2, x1
class Upsample_unit(nn.Module):
    """One decoder (up-sampling) unit of an MSPN stage.

    The unit projects the encoder feature ``x`` to ``chl_num`` channels. For
    ``ind > 0`` it adds the previous (coarser) unit's output, bilinearly
    resized to ``up_size``. It then produces:

    * ``out``: the fused map after ReLU, consumed by the next unit;
    * ``res``: a prediction with ``output_chl_num`` channels, resized to
      ``output_shape``;
    * ``skip1`` / ``skip2``: cross-stage features with ``in_planes``
      channels, only when ``gen_skip`` (otherwise ``None``);
    * ``cross_conv``: a 64-channel map, only for ``ind == 3`` with
      ``gen_cross_conv`` (otherwise ``None``).
    """

    def __init__(self, ind, in_planes, up_size, output_chl_num, output_shape,
                 chl_num=256, gen_skip=False, gen_cross_conv=False,
                 efficient=False):
        super(Upsample_unit, self).__init__()
        self.output_shape = output_shape

        def pointwise(cin, cout, relu):
            # 1x1 conv + BN (+ optional ReLU) used by most branches here.
            return conv_bn_relu(cin, cout, kernel_size=1, stride=1, padding=0,
                                has_bn=True, has_relu=relu,
                                efficient=efficient)

        # Submodules are created in a fixed order: u_skip, up_conv, skip1,
        # skip2, cross_conv, res_conv1, res_conv2.
        self.u_skip = pointwise(in_planes, chl_num, False)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if ind > 0:
            self.up_size = up_size
            self.up_conv = pointwise(chl_num, chl_num, False)

        self.gen_skip = gen_skip
        if gen_skip:
            self.skip1 = pointwise(in_planes, in_planes, True)
            self.skip2 = pointwise(chl_num, in_planes, True)

        self.gen_cross_conv = gen_cross_conv
        if ind == 3 and gen_cross_conv:
            self.cross_conv = pointwise(chl_num, 64, True)

        self.res_conv1 = pointwise(chl_num, chl_num, True)
        self.res_conv2 = conv_bn_relu(chl_num, output_chl_num, kernel_size=3,
                                      stride=1, padding=1, has_bn=True,
                                      has_relu=False, efficient=efficient)

    def forward(self, x, up_x):
        """Fuse ``x`` with the coarser unit's output ``up_x``.

        ``up_x`` is ignored when ``ind == 0``.

        Returns:
            ``(out, res, skip1, skip2, cross_conv)``.
        """
        fused = self.u_skip(x)
        if self.ind > 0:
            # Bilinearly resize the coarser map to this unit's resolution.
            coarse = F.interpolate(up_x, size=self.up_size, mode='bilinear',
                                   align_corners=True)
            fused += self.up_conv(coarse)
        fused = self.relu(fused)

        # Every unit emits its own prediction, resized to the output size.
        res = self.res_conv2(self.res_conv1(fused))
        res = F.interpolate(res, size=self.output_shape, mode='bilinear',
                            align_corners=True)

        # Cross-stage features handed to the next stage's encoder.
        skip1 = self.skip1(x) if self.gen_skip else None
        skip2 = self.skip2(fused) if self.gen_skip else None

        wants_cross = self.ind == 3 and self.gen_cross_conv
        cross_conv = self.cross_conv(fused) if wants_cross else None
        return fused, res, skip1, skip2, cross_conv
class Upsample_module(nn.Module):
    """Decoder half of one MSPN stage: four ``Upsample_unit``s, coarse to fine.

    Unit ``k`` (0-based) receives the encoder map with ``in_planes[k]``
    channels, i.e. 2048, 1024, 512 and 256 for x4, x3, x2 and x1. For
    ``k > 0`` it also receives the previous unit's output, resized to
    ``up_sizes[k]``.

    Args:
        output_chl_num: channels of every prediction (one per keypoint).
        output_shape: ``(h, w)`` of the predictions. ``up_sizes`` are
            ``(h, w)`` divided by 8, 4, 2 and 1.
        chl_num: channel width of the decoder path (default 256).
        gen_skip: make every unit emit cross-stage skip features.
        gen_cross_conv: make the finest unit emit the ``cross_conv`` map.
        efficient: forwarded to every unit.
    """

    def __init__(self, output_chl_num, output_shape, chl_num=256,
                 gen_skip=False, gen_cross_conv=False, efficient=False):
        super(Upsample_module, self).__init__()
        # Encoder channel widths of x4, x3, x2, x1 (deepest first).
        self.in_planes = [2048, 1024, 512, 256]
        h, w = output_shape
        # Spatial size each unit works at, coarsest first. Unit 0 never uses
        # its entry because it has no coarser input to resize.
        self.up_sizes = [
            (h // 8, w // 8), (h // 4, w // 4), (h // 2, w // 2), (h, w)]
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv

        self.up1 = Upsample_unit(0, self.in_planes[0], self.up_sizes[0],
                                 output_chl_num=output_chl_num,
                                 output_shape=output_shape,
                                 chl_num=chl_num, gen_skip=self.gen_skip,
                                 gen_cross_conv=self.gen_cross_conv,
                                 efficient=efficient)
        self.up2 = Upsample_unit(1, self.in_planes[1], self.up_sizes[1],
                                 output_chl_num=output_chl_num,
                                 output_shape=output_shape,
                                 chl_num=chl_num, gen_skip=self.gen_skip,
                                 gen_cross_conv=self.gen_cross_conv,
                                 efficient=efficient)
        self.up3 = Upsample_unit(2, self.in_planes[2], self.up_sizes[2],
                                 output_chl_num=output_chl_num,
                                 output_shape=output_shape,
                                 chl_num=chl_num, gen_skip=self.gen_skip,
                                 gen_cross_conv=self.gen_cross_conv,
                                 efficient=efficient)
        self.up4 = Upsample_unit(3, self.in_planes[3], self.up_sizes[3],
                                 output_chl_num=output_chl_num,
                                 output_shape=output_shape,
                                 chl_num=chl_num, gen_skip=self.gen_skip,
                                 gen_cross_conv=self.gen_cross_conv,
                                 efficient=efficient)

    def forward(self, x4, x3, x2, x1):
        """Decode the encoder maps coarse to fine.

        Example shapes (batch 2, output_shape (64, 48), chl_num 256). Each
        unit's ``out`` keeps its encoder input's resolution but has
        ``chl_num`` channels:

            x4 [2, 2048, 8, 6]   -> out1 [2, 256, 8, 6]
            x3 [2, 1024, 16, 12] -> out2 [2, 256, 16, 12]
            x2 [2, 512, 32, 24]  -> out3 [2, 256, 32, 24]
            x1 [2, 256, 64, 48]  -> out4 [2, 256, 64, 48]

        Every ``res_k`` is resized to ``output_shape``:
        [2, output_chl_num, 64, 48].

        Returns:
            res: the four predictions, coarsest unit first.
            skip1, skip2: cross-stage features ordered finest level first
                (x1 level at index 0). This is the order in which
                ``ResNet_downsample_module.forward`` indexes them. Entries
                are ``None`` unless ``gen_skip``.
            cross_conv: 64-channel map from the finest unit, or ``None``.
        """
        out1, res1, skip1_1, skip2_1, _ = self.up1(x4, None)
        # out1 is bilinearly up-sampled inside up2, and so on.
        out2, res2, skip1_2, skip2_2, _ = self.up2(x3, out1)
        out3, res3, skip1_3, skip2_3, _ = self.up3(x2, out2)
        out4, res4, skip1_4, skip2_4, cross_conv = self.up4(x1, out3)

        # 'res' starts from small size
        res = [res1, res2, res3, res4]
        skip1 = [skip1_4, skip1_3, skip1_2, skip1_1]
        skip2 = [skip2_4, skip2_3, skip2_2, skip2_1]
        return res, skip1, skip2, cross_conv
class Single_stage_module(nn.Module):
    """One complete MSPN stage: a Bottleneck encoder plus a 4-unit decoder.

    Args:
        output_chl_num: channels of every prediction (number of keypoints).
        output_shape: ``(h, w)`` of the predictions.
        has_skip: the encoder adds the previous stage's skip features.
        gen_skip: the decoder produces skip features for the next stage.
        gen_cross_conv: the decoder produces the next stage's input map.
        chl_num: decoder channel width.
        efficient: forwarded to encoder and decoder.
        zero_init_residual: forwarded to the encoder.
    """

    def __init__(self, output_chl_num, output_shape, has_skip=False,
                 gen_skip=False, gen_cross_conv=False, chl_num=256,
                 efficient=False, zero_init_residual=False,):
        super(Single_stage_module, self).__init__()
        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.chl_num = chl_num
        self.zero_init_residual = zero_init_residual
        # Blocks per residual layer: the ResNet-50 layout with Bottleneck.
        self.layers = [3, 4, 6, 3]

        self.downsample = ResNet_downsample_module(
            Bottleneck, self.layers,
            has_skip=self.has_skip,
            efficient=efficient,
            zero_init_residual=self.zero_init_residual)
        self.upsample = Upsample_module(
            output_chl_num, output_shape,
            chl_num=self.chl_num,
            gen_skip=self.gen_skip,
            gen_cross_conv=self.gen_cross_conv,
            efficient=efficient)

    def forward(self, x, skip1, skip2):
        """Encode ``x`` (fusing ``skip1``/``skip2`` when ``has_skip``), then decode.

        Returns:
            ``(res, skip1, skip2, cross_conv)`` as produced by the decoder.
        """
        encoded = self.downsample(x, skip1, skip2)  # (x4, x3, x2, x1)
        res, new_skip1, new_skip2, cross_conv = self.upsample(*encoded)
        return res, new_skip1, new_skip2, cross_conv
class MSPN(nn.Module):
    """Multi-Stage Pose Network: a shared stem followed by ``STAGE_NUM`` stages.

    Every stage except the first fuses the previous stage's skip features.
    Every stage except the last produces skip features and a ``cross_conv``
    map; that map becomes the next stage's input.

    Args:
        cfg: config object. Reads MODEL.STAGE_NUM, DATASET.KEYPOINT.NUM,
            OUTPUT_SHAPE, MODEL.UPSAMPLE_CHANNEL_NUM, LOSS.OHKM, LOSS.TOPK
            and LOSS.COARSE_TO_FINE.
        run_efficient: forwarded as ``efficient`` to every stage.
        **kwargs: extra keyword arguments for ``Single_stage_module``
            (e.g. ``zero_init_residual``).
    """

    def __init__(self, cfg, run_efficient=False, **kwargs):
        super(MSPN, self).__init__()
        self.top = ResNet_top()
        self.stage_num = cfg.MODEL.STAGE_NUM
        self.output_chl_num = cfg.DATASET.KEYPOINT.NUM
        self.output_shape = cfg.OUTPUT_SHAPE
        self.upsample_chl_num = cfg.MODEL.UPSAMPLE_CHANNEL_NUM
        self.ohkm = cfg.LOSS.OHKM
        self.topk = cfg.LOSS.TOPK
        self.ctf = cfg.LOSS.COARSE_TO_FINE

        # A plain Python list does not register submodules with nn.Module.
        # Each stage is registered by the setattr below as ``stage0``,
        # ``stage1``, ...; those names are the parameter keys, so keep them.
        self.mspn_modules = list()
        for i in range(self.stage_num):
            is_first = i == 0
            is_last = i == self.stage_num - 1
            self.mspn_modules.append(
                Single_stage_module(
                    self.output_chl_num, self.output_shape,
                    has_skip=not is_first,
                    gen_skip=not is_last,
                    gen_cross_conv=not is_last,
                    chl_num=self.upsample_chl_num,
                    efficient=run_efficient,
                    **kwargs
                )
            )
            setattr(self, 'stage%d' % i, self.mspn_modules[i])

    def forward(self, imgs, valids=None, labels=None):
        """Run all stages.

        Returns:
            When both ``valids`` and ``labels`` are None (inference): the
            last stage's finest prediction, ``outputs[-1][-1]``.
            Otherwise: the loss dict from ``self._calculate_loss``.
        """
        x = self.top(imgs)
        skip1 = None
        skip2 = None
        outputs = list()
        for i in range(self.stage_num):
            # Look the stage up by attribute name instead of eval()-ing a
            # string. ``x`` is rebound to the stage's cross_conv map, the next
            # stage's input. skip1/skip2 carry the cross-stage aggregation.
            stage = getattr(self, 'stage%d' % i)
            res, skip1, skip2, x = stage(x, skip1, skip2)
            outputs.append(res)

        if valids is None and labels is None:
            return outputs[-1][-1]
        else:
            return self._calculate_loss(outputs, valids, labels)
分析下损失函数
def _calculate_loss(self, outputs, valids, labels):
    """Sum the heatmap losses over every stage and every decoder unit.

    Args:
        outputs: per-stage lists of 4 predictions, coarsest unit first.
        valids: keypoint visibility weights. Per the original note this is
            (n, 17, 1); TODO confirm.
        labels: heatmap targets stacked over the Gaussian kernels. Per the
            original note this is (n, 5, 17, h, w), with index k built from
            the k-th entry of TRAIN.GAUSSIAN_KERNELS (widest first).

    Unit ``j`` of a stage is supervised with ``labels[:, j]``. With
    COARSE_TO_FINE enabled, the last stage uses ``labels[:, j + 1]`` instead,
    i.e. one step sharper kernel. The finest unit (``j == 3``) uses the OHKM
    loss when enabled and counts fully; the three coarser units are
    weighted by 1/4.

    Returns:
        ``dict(total_loss=...)``.
    """
    plain_criterion = JointsL2Loss()
    ohkm_criterion = (JointsL2Loss(has_ohkm=self.ohkm, topk=self.topk)
                      if self.ohkm else None)

    total = 0
    last_stage = self.stage_num - 1
    for stage_idx in range(self.stage_num):
        # Coarse-to-fine: the last stage is shifted to the sharper labels.
        offset = 1 if (stage_idx == last_stage and self.ctf) else 0
        for unit in range(4):
            target = labels[:, unit + offset, :, :, :]
            is_finest = unit == 3
            criterion = (ohkm_criterion if (is_finest and self.ohkm)
                         else plain_criterion)
            unit_loss = criterion(outputs[stage_idx][unit], valids, target)
            total += unit_loss if is_finest else unit_loss / 4
    return dict(total_loss=total)
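接下来看 label 的生成。配置为 TRAIN.GAUSSIAN_KERNELS = [(15, 15), (11, 11), (9, 9), (7, 7), (5, 5)]。
即采用 5 个高斯核分别生成 5 组 label,核越小,heatmap 越尖锐。结合上面的损失函数可以看出:
- 除最后一个 stage 外,各 stage 使用前 4 个高斯核。
- 开启 COARSE_TO_FINE 时,最后一个 stage(这里即 stage2)使用后 4 个高斯核,核更小,监督更精细。
- 同一 stage 内,特征分辨率越低的那一层(其输出会插值到相同大小)对应的高斯核越大,形成由粗到细的监督。

inference 时直接使用最后一个 stage 分辨率最高那一层的输出。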
class JointsDataset(Dataset):
    """Keypoint dataset (excerpt from the walkthrough; parts are elided).

    NOTE(review): the bodies of ``__init__`` and the first part of
    ``__getitem__`` were elided in the source (the ``pass`` placeholders).
    The names used below are assumed to be set up by that elided code;
    confirm against the full repository. These are ``img``, ``joints``,
    ``joints_vis``, ``trans``, ``score``, ``center``, ``scale``, ``img_id``,
    and the attributes ``self.stage``, ``self.data_num``,
    ``self.keypoint_num``, ``self.input_shape``, ``self.output_shape`` and
    ``self.gaussian_kernels``.
    """

    def __init__(self, DATASET, stage, transform=None):
        # Elided in the excerpt.
        pass

    def __len__(self):
        # Number of samples; ``self.data_num`` is set in the elided code.
        return self.data_num

    def __getitem__(self, idx):
        """Return one sample.

        Train stage: ``(img, valid, labels)``. ``labels`` stacks one heatmap
        set per entry of ``self.gaussian_kernels``, with shape
        (len(gaussian_kernels), keypoint_num, *output_shape).
        Other stages: ``(img, score, center, scale, img_id)``.
        """
        # Elided in the excerpt. It presumably loads/crops the image and
        # computes ``joints``, ``joints_vis`` and the affine transform
        # ``trans``.
        pass
        if self.stage == 'train':
            for i in range(self.keypoint_num):
                if joints_vis[i, 0] > 0:
                    # Map the joint into the cropped input-image frame.
                    joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
                    # A joint that lands outside the input image is marked
                    # invisible, so no heatmap is drawn for it.
                    if joints[i, 0] < 0 \
                            or joints[i, 0] > self.input_shape[1] - 1 \
                            or joints[i, 1] < 0 \
                            or joints[i, 1] > self.input_shape[0] - 1:
                        joints_vis[i, 0] = 0
            valid = torch.from_numpy(joints_vis).float()

            # Generate one heatmap set per Gaussian kernel. Per the config
            # note TRAIN.GAUSSIAN_KERNELS = [(15, 15), (11, 11), (9, 9),
            # (7, 7), (5, 5)], i.e. from widest to sharpest.
            labels_num = len(self.gaussian_kernels)
            labels = np.zeros(
                (labels_num, self.keypoint_num, *self.output_shape))
            for i in range(labels_num):
                labels[i] = self.generate_heatmap(
                    joints, valid, kernel=self.gaussian_kernels[i])
            labels = torch.from_numpy(labels).float()

            return img, valid, labels
        else:
            return img, score, center, scale, img_id

    def generate_heatmap(self, joints, valid, kernel=(7, 7)):
        """Render one Gaussian heatmap per keypoint.

        Args:
            joints: per-keypoint coordinates in input-image pixels
                (column 0 = x, column 1 = y).
            valid: per-keypoint visibility. Joints with ``valid[i] < 1`` get
                an all-zero map.
            kernel: Gaussian kernel size passed to ``cv2.GaussianBlur``.

        Returns:
            float32 array of shape (keypoint_num, *output_shape). Every
            non-empty map is rescaled so its peak value is 255.
        """
        heatmaps = np.zeros(
            (self.keypoint_num, *self.output_shape), dtype='float32')
        for i in range(self.keypoint_num):
            if valid[i] < 1:
                continue
            # Scale from input-image to heatmap coordinates.
            target_y = joints[i, 1] * self.output_shape[0] \
                / self.input_shape[0]
            target_x = joints[i, 0] * self.output_shape[1] \
                / self.input_shape[1]
            # Place an impulse at the (truncated) integer location ...
            heatmaps[i, int(target_y), int(target_x)] = 1
            # ... and blur it into a Gaussian with OpenCV. sigma 0 lets
            # OpenCV derive sigma from the kernel size, so a larger kernel
            # gives a wider peak.
            heatmaps[i] = cv2.GaussianBlur(heatmaps[i], kernel, 0)
            maxi = np.amax(heatmaps[i])
            # Guard against dividing by (near) zero.
            if maxi <= 1e-8:
                continue
            # Rescale so the peak value is 255.
            heatmaps[i] /= maxi / 255
        return heatmaps
最后分析下 inference代码
def compute_on_dataset(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and collect COCO-style keypoint results.

    Args:
        model: pose network. ``model(imgs)`` returns the final heatmaps; for
            MSPN this is the last stage's finest output.
        data_loader: yields batches ``(imgs, scores, centers, scales, img_ids)``.
        device: device the model runs on.

    Returns:
        A list of dicts with keys ``image_id``, ``category_id`` (always 1),
        ``keypoints`` (flattened ``[x, y, score]`` per keypoint) and
        ``score``. ``score`` is the box score times the mean keypoint score.
    """
    model.eval()
    results = list()
    cpu_device = torch.device("cpu")

    data = tqdm(data_loader) if is_main_process() else data_loader
    for batch in data:
        # imgs: network input. scores: per-box confidence from the dataset.
        # centers: box centres. scales: box size / pixel_std (200), i.e.
        # [w / 200, h / 200] per the dataset code.
        imgs, scores, centers, scales, img_ids = batch
        imgs = imgs.to(device)

        # Both forward passes must run under no_grad. Calling .numpy() on a
        # tensor that requires grad raises, and inference needs no graph.
        with torch.no_grad():
            outputs = model(imgs).to(cpu_device).numpy()

            if cfg.TEST.FLIP:
                # Flip-test augmentation: predict on the image mirrored along
                # axis 3 (presumably width for NCHW input). Map the result
                # back with flip_back, presumably also swapping the paired
                # left/right keypoint channels. Then average both.
                imgs_flipped = torch.flip(imgs, dims=[3])
                outputs_flipped = model(imgs_flipped).to(cpu_device).numpy()
                outputs_flipped = flip_back(
                    outputs_flipped, cfg.DATASET.KEYPOINT.FLIP_PAIRS)
                outputs = (outputs + outputs_flipped) * 0.5

        centers = np.array(centers)
        scales = np.array(scales)

        # preds: (B, K, 2) coordinates in the original image.
        # maxvals: (B, K, 1) per-keypoint scores.
        preds, maxvals = get_results(outputs, centers, scales,
                                     cfg.TEST.GAUSSIAN_KERNEL,
                                     cfg.TEST.SHIFT_RATIOS)

        # Mean keypoint score per instance. Averaging over axes (1, 2)
        # keeps the batch axis even when B == 1. squeeze() would drop it,
        # and mean(axis=1) would then raise.
        kp_scores = maxvals.mean(axis=(1, 2))

        # (B, K, 3): x, y, score for each keypoint.
        preds = np.concatenate((preds, maxvals), axis=2)

        for i in range(preds.shape[0]):
            results.append(dict(image_id=img_ids[i],
                                category_id=1,
                                keypoints=preds[i].reshape(-1).tolist(),
                                score=scores[i] * kp_scores[i]))
    return results
def get_results(outputs, centers, scales, kernel=11, shifts=(0.25,)):
    """Decode heatmaps into keypoint coordinates in original-image pixels.

    For every keypoint map:

    1. Zero-pad by ``border`` pixels and smooth with a ``kernel`` x ``kernel``
       Gaussian.
    2. Take the arg-max as the initial location.
    3. For each ratio in ``shifts``, clear the current maximum and move the
       estimate ``ratio`` heatmap pixels toward the next-highest response.
       This is the usual quarter-offset refinement that reduces arg-max
       quantization error.
    4. Clamp to the heatmap, map the cell centre to input-image pixels, then
       map into the original image using the box ``centers`` / ``scales``.

    Args:
        outputs: heatmaps of shape (B, K, H, W). (H, W) must equal
            ``cfg.OUTPUT_SHAPE`` and K must equal ``cfg.DATASET.KEYPOINT.NUM``.
        centers: per-image box centres, indexed ``centers[i][0]`` (x) and
            ``centers[i][1]`` (y).
        scales: per-image box size divided by the pixel_std of 200, indexed
            like ``centers``. It is not modified.
        kernel: Gaussian kernel size; cv2.GaussianBlur requires it to be odd.
        shifts: refinement ratios. An empty sequence disables refinement.

    Returns:
        preds: (B, K, 2) array of x, y coordinates.
        maxvals: (B, K, 1) array of ``heatmap / 255 + 0.5`` sampled at the
            rounded heatmap location.
    """
    num_kps = cfg.DATASET.KEYPOINT.NUM
    out_h, out_w = cfg.OUTPUT_SHAPE[0], cfg.OUTPUT_SHAPE[1]
    in_h, in_w = cfg.INPUT_SHAPE[0], cfg.INPUT_SHAPE[1]
    # Input pixels per heatmap cell. This is 4 when the heatmap is a quarter
    # of the input size; (x + 0.5) * stride maps a cell centre to input
    # pixels.
    stride_x = in_w / out_w
    stride_y = in_h / out_h
    border = 10
    # Scale a copy. An in-place ``scales *= 200`` would modify the caller's
    # array, and would repeat a Python list 200 times.
    box_sizes = np.asarray(scales) * 200

    nr_img = outputs.shape[0]
    preds = np.zeros((nr_img, num_kps, 2))
    maxvals = np.zeros((nr_img, num_kps, 1))
    for i in range(nr_img):
        # Training targets peak at 255; rescale the raw values for scoring.
        score_map = outputs[i] / 255 + 0.5
        kps = np.zeros((num_kps, 2))
        scores = np.zeros((num_kps, 1))

        # Zero-pad so the blur and the peak search treat everything outside
        # the map as zero.
        dr = np.zeros((num_kps, out_h + 2 * border, out_w + 2 * border))
        dr[:, border:-border, border:-border] = outputs[i]
        for w in range(num_kps):
            dr[w] = cv2.GaussianBlur(dr[w], (kernel, kernel), 0)

        for w in range(num_kps):
            hm = dr[w]  # a view: zeroing entries below also modifies dr
            # Peak of the smoothed map, in padded coordinates. Clear it so
            # the next argmax finds the runner-up.
            y, x = np.unravel_index(hm.argmax(), hm.shape)
            hm[y, x] = 0
            x -= border
            y -= border
            for ratio in shifts:
                py, px = np.unravel_index(hm.argmax(), hm.shape)
                hm[py, px] = 0
                # Offset of the runner-up from the current estimate.
                px -= border + x
                py -= border + y
                ln = (px ** 2 + py ** 2) ** 0.5
                if ln > 1e-3:
                    # Step ``ratio`` pixels along the unit direction.
                    x += ratio * px / ln
                    y += ratio * py / ln
            x = max(0, min(x, out_w - 1))
            y = max(0, min(y, out_h - 1))
            kps[w] = np.array([x * stride_x + stride_x / 2,
                               y * stride_y + stride_y / 2])
            scores[w, 0] = score_map[w, int(round(y) + 1e-9),
                                     int(round(x) + 1e-9)]

        # Input-image pixels -> original image: rescale by box size / input
        # size, then offset by the box's top-left corner (centre - size / 2).
        kps[:, 0] = kps[:, 0] / in_w * box_sizes[i][0] + \
            centers[i][0] - box_sizes[i][0] * 0.5
        kps[:, 1] = kps[:, 1] / in_h * box_sizes[i][1] + \
            centers[i][1] - box_sizes[i][1] * 0.5
        preds[i] = kps
        maxvals[i] = scores
    return preds, maxvals
最后,结果的可视化可以参考 Dataset 类中的可视化函数 visualize:
def visualize(self, img, joints, score=None):
    """Draw keypoints, the skeleton and an optional score onto ``img``.

    Args:
        img: OpenCV-style image array. It is drawn on in place and returned.
        joints: (keypoint_num, >=2) array; columns 0 and 1 are x and y. It
            may carry extra columns, e.g. the (x, y, score) rows produced at
            inference. A joint with a non-positive coordinate is treated as
            missing.
        score: optional value rendered as text at (50, 50); ``None`` skips it.

    Returns:
        ``img``.
    """
    # 1-based keypoint index pairs (the COCO person skeleton).
    pairs = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
             [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
             [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
    color = np.random.randint(0, 256, (self.keypoint_num, 3)).tolist()

    def to_point(p):
        # OpenCV drawing functions need a 2-tuple of ints. Joint rows may be
        # floats and may carry a third (score) column.
        return int(round(float(p[0]))), int(round(float(p[1])))

    for i in range(self.keypoint_num):
        if joints[i, 0] > 0 and joints[i, 1] > 0:
            cv2.circle(img, to_point(joints[i]), 2, tuple(color[i]), 2)
    if score is not None:
        # cv2.putText only accepts text; a legitimate 0 score is drawn too.
        cv2.putText(img, str(score), (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.2,
                    (128, 255, 0), 2)

    def draw_line(img, p1, p2):
        c = (0, 0, 255)
        if p1[0] > 0 and p1[1] > 0 and p2[0] > 0 and p2[1] > 0:
            cv2.line(img, to_point(p1), to_point(p2), c, 2)

    for pair in pairs:
        draw_line(img, joints[pair[0] - 1], joints[pair[1] - 1])
    return img
到此,MSPN主要代码分析完。