编写可视化代码

#coding=utf-8
# !/usr/bin/env python
import h5py
from scipy.io import loadmat
import scipy.misc as scm
import matplotlib.pyplot as plt

# Joint name -> row index inside each predicted pose. The ordering appears to
# follow the MPII 16-joint convention (0 = right ankle ... 15 = left wrist).
_JOINT_NAMES = ['r_ankle', 'r_knee', 'r_hip',
                'l_hip', 'l_knee', 'l_ankle',
                'pelvis', 'thorax', 'neck', 'head',
                'r_wrist', 'r_elbow', 'r_shoulder',
                'l_shoulder', 'l_elbow', 'l_wrist']
JointsIndex = {name: idx for idx, name in enumerate(_JOINT_NAMES)}

# Skeleton edges to draw, each as [from_joint, to_joint].
JointPairs = [
    ['head', 'neck'], ['neck', 'thorax'],                  # head / neck
    ['thorax', 'r_shoulder'], ['thorax', 'l_shoulder'],    # shoulders
    ['r_shoulder', 'r_elbow'], ['r_elbow', 'r_wrist'],     # right arm
    ['l_shoulder', 'l_elbow'], ['l_elbow', 'l_wrist'],     # left arm
    ['pelvis', 'r_hip'], ['pelvis', 'l_hip'],              # hips
    ['r_hip', 'r_knee'], ['r_knee', 'r_ankle'],            # right leg
    ['l_hip', 'l_knee'], ['l_knee', 'l_ankle'],            # left leg
    ['thorax', 'pelvis'],                                  # torso
]

# Alternative 14-joint layout used with the LSP dataset, kept for switching:
# JointsIndex = {'r_ankle': 0, 'r_knee': 1, 'r_hip': 2, 'l_hip': 3,
#                'l_knee': 4, 'l_ankle': 5, 'r_wrist': 6, 'r_elbow': 7,
#                'r_shoulder': 8, 'l_shoulder': 9, 'l_elbow': 10,
#                'l_wrist': 11, 'neck': 12, 'head_top': 13}
# JointPairs = [['head_top', 'neck'],
#               ['neck', 'r_shoulder'], ['neck', 'l_shoulder'],
#               ['r_shoulder', 'r_elbow'], ['r_elbow', 'r_wrist'],
#               ['l_shoulder', 'l_elbow'], ['l_elbow', 'l_wrist'],
#               ['r_shoulder', 'r_hip'], ['r_hip', 'r_knee'],
#               ['l_shoulder', 'l_hip'], ['l_hip', 'l_knee'],
#               ['r_knee', 'r_ankle'], ['l_knee', 'l_ankle']]
#               # optionally also ['r_hip', 'l_hip']
# StickType = ['g','g','y','r','r','y','m','m','y','b','b','y','c','c']

# Matplotlib line format for each JointPairs entry (same index):
# red = head/neck/torso, green = right arm, blue = left arm,
# cyan = right leg, magenta = left leg.
StickType = ['r-', 'r-',    # head, neck
             'g-', 'b-',    # shoulders (right, left)
             'g-', 'g-',    # right arm
             'b-', 'b-',    # left arm
             'c-', 'm-',    # hips (right, left)
             'c-', 'c-',    # right leg
             'm-', 'm-',    # left leg
             'r-']          # torso


# Image list: one image filename per line, relative to images_path.
# LSP alternative kept for reference:
#   list file:   /media/z/CC/zxl/shujuji/LEEDS/lsp_dataset/lsp.txt
#   images_path: /media/z/CC/zxl/shujuji/LEEDS/lsp_dataset/images/
# Read inside a `with` block so the file handle is closed promptly.
# readlines() keeps each line's trailing newline (callers strip it).
with open('/media/z/CC/zxl/shujuji/mpii.txt', 'r') as list_file:
    imgs = list_file.readlines()
images_path = '/media/z/CC/zxl/convolutional-pose-machines-release/dataset/MPI/Tompson_valid/images/'

# Predictions are read from the 'preds' variable of a MATLAB .mat file.
# Alternative checkpoint kept for reference:
# model_name = 'hg2'
# predfile = '/media/z/CC/zxl/pytorch-pose/checkpoint/lsp_hg2_130/' + model_name + '/preds_valid.mat'
model_name = 'hg8'
predfile = '/media/z/CC/zxl/pytorch-pose/checkpoint8/mpii120/' + model_name + '/preds_valid.mat'
preds = loadmat(predfile)['preds']

# Alternative: load the predictions from an HDF5 file instead
# (open with mode 'r' to read; 'w' would truncate the file):
# with h5py.File('preds.h5', 'r') as f:
#     print(list(f.keys()))  # list the available datasets
#     preds = f['preds'][:]

# NOTE: imgs and preds are indexed in lockstep by the drawing loop, so
# entry k of the image list must correspond to preds[k].
# Overlay the predicted skeleton on each image, showing one window per image.
# imgs[k] (a filename line) and preds[k] (a pose) must describe the same image.
if len(imgs) < len(preds):
    raise ValueError('image list has %d entries but there are %d predictions'
                     % (len(imgs), len(preds)))

for img_name, pose in zip(imgs, preds):
    # strip() rather than [:-1]: the last line may have no trailing newline,
    # and CRLF files would otherwise leave a '\r' in the filename.
    filename = images_path + img_name.strip()
    # scipy.misc.imread was removed in SciPy 1.2; matplotlib's imread
    # (already imported as plt) reads the same image files.
    img = plt.imread(filename)

    plt.axis('off')
    plt.imshow(img)

    # One stick per skeleton edge; StickType[k] styles JointPairs[k].
    # pose is indexed as pose[joint][0] = x, pose[joint][1] = y
    # (presumably an (n_joints, 2) array from the .mat file -- confirm).
    for (name1, name2), style in zip(JointPairs, StickType):
        idx1 = JointsIndex[name1]
        idx2 = JointsIndex[name2]
        x1, y1 = pose[idx1][0], pose[idx1][1]
        x2, y2 = pose[idx2][0], pose[idx2][1]
        # Non-positive coordinates are treated as missing joints; skip the edge.
        if x1 > 0 and y1 > 0 and x2 > 0 and y2 > 0:
            plt.plot([x1, x2], [y1, y2], style, linewidth=3)
    plt.show()

# Function-call form works under both Python 2 and Python 3.
print('Done.')

你可能感兴趣的:(编写可视化代码)