NTU-RGB+D数据处理(CTR-GCN ICCV2021)

文章目录

  • get_raw_skes_data.py
  • get_raw_denoised_data.py
  • seq_transformation.py

get_raw_skes_data.py

'''
	本文件主要生成四个文件
	raw_data
	  ├── frames_cnt.txt
	  │   		├── valid frame count per sequence (one int per line)
	  ├── frames_drop.log
	  │   		├── log of sequences that contain dropped (empty) frames
	  ├── frames_drop_skes.pkl
	  │   		├── frames_drop_skes(dict) sequence name -> indices of frames with no person
	  ├── raw_skes_data.pkl
	  │   		├── name(key): ske_name(value, str) 	 
	  │   		├── num_frames(key): num_frames - num_frames_drop(value, int) 
	  │   		├── data(key): bodies_data(value, dict) 
	  │   		│   		├── bodyID(key):body_data(value, dict)
	  │   		│   		│   		├── joints(key)     value(joints[num_body, 25, 3])
	  │   		│   		│   		├── colors(key)     value(colors[num_body, 25, 2])
	  │   		│   		│   		├── interval(key)    value(interval[int])
	  │   		│   		│   		│---- motion(key)     value(int)  # 这部分只有当双人时才会有;
'''
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os.path as osp
import os
import numpy as np
import pickle
import logging


def get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger):
    """
    Parse one .skeleton file and collect the raw per-body data.

    Each body's data is a dict that contains the following keys:
      - joints: raw 3D joints positions. Shape: (num_frames x 25, 3)
      - colors: raw 2D color locations. Shape: (num_frames, 25, 2)
      - interval: a list which stores the frame indices of this body.
      - motion: motion amount (only for the sequence with 2 or more bodyIDs).

    File layout (per frame): one line with the body count, then per body a
    body-info line (bodyID, clipedEdges, handLeftConfidence, handLeftState,
    handRightConfidence, handRightState, isResticted, leanX, leanY,
    trackingState), a joint-count line (usually 25), and one line per joint
    with 12 fields: x, y, z, depthX, depthY, colorX, colorY, orientationW,
    orientationX, orientationY, orientationZ, trackingState.

    Args:
      skes_path: directory containing the .skeleton files.
      ske_name: skeleton filename without the '.skeleton' extension.
      frames_drop_skes: dict updated in place with ske_name -> dropped frame indices.
      frames_drop_logger: logger that records sequences with dropped frames.

    Return:
      a dict for a skeleton sequence with 3 key-value pairs:
        - name: the skeleton filename.
        - data: a dict which stores raw data of each body (keyed by bodyID).
        - num_frames: the number of valid (non-empty) frames.
    """
    ske_file = osp.join(skes_path, ske_name + '.skeleton')
    assert osp.exists(ske_file), 'Error: Skeleton file %s not found' % ske_file
    # Read all data from the .skeleton file into a list of strings,
    # e.g. "S001C001P001R001A001.skeleton"
    print('Reading data from %s' % ske_file[-29:])
    with open(ske_file, 'r') as fr:
        str_data = fr.readlines()

    num_frames = int(str_data[0].strip('\r\n'))  # first line: total frame count
    frames_drop = []
    bodies_data = dict()
    valid_frames = -1  # 0-based index
    current_line = 1

    for f in range(num_frames):
        num_bodies = int(str_data[current_line].strip('\r\n'))  # bodies present in this frame
        current_line += 1

        if num_bodies == 0:  # no data in this frame, drop it
            frames_drop.append(f)  # 0-based index
            continue

        valid_frames += 1
        joints = np.zeros((num_bodies, 25, 3), dtype=np.float32)
        colors = np.zeros((num_bodies, 25, 2), dtype=np.float32)

        for b in range(num_bodies):
            bodyID = str_data[current_line].strip('\r\n').split()[0]  # first field of the body-info line
            current_line += 1
            num_joints = int(str_data[current_line].strip('\r\n'))  # usually 25
            current_line += 1
            for j in range(num_joints):
                temp_str = str_data[current_line].strip('\r\n').split()
                joints[b, j, :] = np.array(temp_str[:3], dtype=np.float32)   # x, y, z
                colors[b, j, :] = np.array(temp_str[5:7], dtype=np.float32)  # colorX, colorY
                current_line += 1

            if bodyID not in bodies_data:  # Add a new body's data
                body_data = dict()
                body_data['joints'] = joints[b]  # ndarray: (25, 3)
                body_data['colors'] = colors[b, np.newaxis]  # ndarray: (1, 25, 2)
                body_data['interval'] = [valid_frames]  # the index of the first frame
            else:  # Update an already existing body's data
                body_data = bodies_data[bodyID]
                # Stack this frame's data along the frame order.
                body_data['joints'] = np.vstack((body_data['joints'], joints[b]))
                body_data['colors'] = np.vstack((body_data['colors'], colors[b, np.newaxis]))
                pre_frame_idx = body_data['interval'][-1]
                # NOTE(review): assumes a body appears in consecutive valid frames;
                # if it re-appears after a gap the recorded interval drifts — kept as-is.
                body_data['interval'].append(pre_frame_idx + 1)

            bodies_data[bodyID] = body_data  # Update bodies_data

    num_frames_drop = len(frames_drop)
    assert num_frames_drop < num_frames, \
        'Error: All frames data (%d) of %s is missing or lost' % (num_frames, ske_name)
    if num_frames_drop > 0:
        # np.int was removed in NumPy 1.20+; the builtin int is the correct dtype.
        frames_drop_skes[ske_name] = np.array(frames_drop, dtype=int)
        frames_drop_logger.info('{}: {} frames missed: {}\n'.format(ske_name, num_frames_drop,
                                                                    frames_drop))

    # Calculate motion (only for sequences with 2 or more bodyIDs): total
    # variance of the stacked joint coordinates, used later to rank bodies.
    if len(bodies_data) > 1:
        for body_data in bodies_data.values():
            body_data['motion'] = np.sum(np.var(body_data['joints'], axis=0))

    return {'name': ske_name, 'data': bodies_data, 'num_frames': num_frames - num_frames_drop}


def get_raw_skes_data():
    """
    Parse every available .skeleton file and save the raw skeleton data.

    Reads the sequence names from the module-level ``skes_name_file``, parses
    each file with ``get_raw_bodies_data`` and writes:
      - ``save_data_pkl``: list of per-sequence dicts (name / data / num_frames)
      - ``raw_data/frames_cnt.txt``: valid frame count per sequence, one int per line
      - ``frames_drop_pkl``: dict of sequence name -> dropped frame indices

    Relies on module-level globals: skes_name_file, skes_path, frames_drop_skes,
    frames_drop_logger, save_data_pkl, frames_drop_pkl, save_path.
    """
    skes_name = np.loadtxt(skes_name_file, dtype=str)

    num_files = skes_name.size
    print('Found %d available skeleton files.' % num_files)

    raw_skes_data = []
    # One entry per sequence holding its number of valid frames.
    # np.int was removed in NumPy 1.20+; the builtin int is the correct dtype.
    frames_cnt = np.zeros(num_files, dtype=int)

    for (idx, ske_name) in enumerate(skes_name):
        # Returns {'name': ske_name, 'data': bodies_data, 'num_frames': ...}
        bodies_data = get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger)
        raw_skes_data.append(bodies_data)
        frames_cnt[idx] = bodies_data['num_frames']
        if (idx + 1) % 1000 == 0:  # progress report every 1000 files
            print('Processed: %.2f%% (%d / %d)' % \
                  (100.0 * (idx + 1) / num_files, idx + 1, num_files))

    with open(save_data_pkl, 'wb') as fw:
        # HIGHEST_PROTOCOL: use the most efficient pickle protocol available.
        pickle.dump(raw_skes_data, fw, pickle.HIGHEST_PROTOCOL)
    np.savetxt(osp.join(save_path, 'raw_data', 'frames_cnt.txt'), frames_cnt, fmt='%d')

    print('Saved raw bodies data into %s' % save_data_pkl)
    print('Total frames: %d' % np.sum(frames_cnt))

    with open(frames_drop_pkl, 'wb') as fw:
        pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL)  # frames with no person

if __name__ == '__main__':
    save_path = './'

    # Root directory holding the raw NTU-RGB+D *.skeleton files.
    skes_path = 'D:/project/data/raw/NTU-RGBD/nturgb+d_skeletons'
    stat_path = osp.join(save_path, 'statistics')
    if not osp.exists('./raw_data'):
        os.makedirs('./raw_data')

    skes_name_file = osp.join(stat_path, 'skes_available_name.txt')  # list of all available *.skeleton filenames
    save_data_pkl = osp.join(save_path, 'raw_data', 'raw_skes_data.pkl')    # output path for the raw skeleton data pickle
    frames_drop_pkl = osp.join(save_path, 'raw_data', 'frames_drop_skes.pkl')   # output path for the dropped-frames pickle

    frames_drop_logger = logging.getLogger('frames_drop')   # initialize the logger object
    frames_drop_logger.setLevel(logging.INFO)       # log at INFO level
    frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'raw_data', 'frames_drop.log'))) # FileHandler creates the .log file
    frames_drop_skes = dict()  # filled in by get_raw_bodies_data via get_raw_skes_data

    get_raw_skes_data()

    # with open(frames_drop_pkl, 'wb') as fw:
    #     pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL)
        

没有空帧情况
S001C001P001R001A001.skeleton
[image unavailable: the external image failed to transfer (img-3GeWtgAX-1639055585576); save it locally and re-upload]

双人情况

70   # 单个样本包含的总帧数 
2     #从第二行开始分别为每一帧的信息,对于其中每一帧,第一个数字为当前帧body数量(如1或2)
72057594037944738 0 1 1 1 1 0 0.01119184 -0.256052 2  
# body_info =[ 'bodyID', 'clipedEdges', 'handLeftConfidence',
#                  'handLeftState', 'handRightConfidence', 'handRightState',
#                 'isResticted', 'leanX', 'leanY', 'trackingState' ]*/
25    # 25个关节点 25*12  每个关节包含12个字段信息
# joint_info_key = [
#                     'x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY',
#                     'orientationW', 'orientationX', 'orientationY',
#                     'orientationZ', 'trackingState'
#                 ]
0.7201789 0.1647426 3.561895 336.5812 192.8021 1184.537 489.4118 -0.2218947 0.03421978 0.968875 -0.1042736 2
0.7260004 0.4483242 3.498769 338.6059 162.7653 1190.205 402.5223 -0.2523348 0.03749071 0.9613965 -0.1031427 2
0.7291254 0.7261137 3.425736 340.6862 131.9133 1195.987 313.4265 -0.2830125 0.04291317 0.9484802 -0.135822 2
0.7023438 0.8463342 3.396255 338.5144 118.1859 1189.627 273.8169 0 0 0 0 2
0.5827141 0.6567038 3.536651 322.9463 141.66 1144.507 341.585 0.2607923 0.6977255 -0.6554007 -0.124966 2
0.5192263 0.3914464 3.649472 314.598 170.4908 1120.394 424.959 -0.03581179 -0.6068473 0.2222077 0.7622845 2
0.4511961 0.1653868 3.66607 307.5214 193.2507 1100.169 490.9451 0.12472 0.7260504 -0.0826719 0.671164 2
0.4417633 0.1185798 3.647333 306.8019 197.8581 1098.219 504.3104 -0.06916566 0.7038938 -0.1973353 0.6788287 2
0.8548245 0.6050068 3.364435 355.8548 143.6936 1240.168 347.399 -0.1105671 0.7569718 0.5679373 -0.3036705 2
0.9127819 0.3429818 3.401966 361.0178 172.7352 1255.312 431.2415 0.05843762 0.9865138 0.1144138 0.1014157 2
0.8905535 0.1046963 3.41965 358.0569 198.5103 1247.062 505.7414 0.05872061 0.7232746 -0.008176846 0.6880109 2
0.87075 0.04412444 3.425586 355.7503 205.0164 1240.474 524.5703 0.1441216 0.6976146 -0.07766516 0.6975178 2
0.6528355 0.1655106 3.560884 329.6426 192.7268 1164.535 489.2483 -0.07249752 -0.6982451 0.7074214 -0.08217493 2
0.6853271 -0.132269 3.735053 329.6935 222.7012 1164.244 576.0424 -0.178642 -0.5079327 0.195941 0.8195722 2
0.7497366 -0.4176336 3.852802 333.831 249.4536 1175.846 653.3472 -0.184775 -0.4768907 0.115771 0.8514872 2
0.6873895 -0.4799173 3.752486 329.6509 256.5972 1164.204 674.0577 0 0 0 0 2
0.7747468 0.1611266 3.501125 343.617 192.8774 1205.077 489.5748 -0.2383845 0.6385497 0.6488895 -0.3381858 2
0.8171204 -0.1421985 3.566106 346.5092 224.3531 1213.413 580.5895 0.0535739 0.8461024 0.1127898 0.5181884 2
0.8511656 -0.4595799 3.664586 347.7344 255.7416 1216.645 671.2434 0.1045282 0.8594967 0.1184092 0.4861262 2
0.7877548 -0.521389 3.564552 343.6025 263.3947 1205.197 693.4037 0 0 0 0 2
0.728802 0.6575422 3.445887 340.162 139.7047 1194.523 335.9147 -0.2832572 0.0415111 0.9510412 -0.1164597 2
0.4375839 0.07206824 3.637431 306.4994 202.4994 1097.437 517.7725 0 0 0 0 2
0.4698927 0.112225 3.619 309.9973 198.4049 1107.594 505.8666 0 0 0 0 2
0.8659531 -0.02500954 3.4443 354.7251 212.4024 1237.508 545.9329 0 0 0 0 2
0.901369 0.03076283 3.417449 359.2839 206.438 1250.683 528.6425 0 0 0 0 2
72057594037944734 0 1 1 0 0 0 0.2471961 -0.2383252 2
25
-0.387201 -0.05403496 2.961099 214.7408 216.41 834.7666 559.0992 0.4557798 -0.05701343 0.8802808 -0.1188276 2
-0.384766 0.2418743 2.881034 213.6927 179.0352 831.9214 450.6993 0.4476836 -0.05592072 0.8844258 -0.1193457 2
-0.3798909 0.5306301 2.789107 212.6189 140.0148 829.2738 337.7772 0.4369697 -0.06553705 0.8833419 -0.1564273 2
-0.3326302 0.6499826 2.820972 219.2917 125.236 848.3054 294.9901 0 0 0 0 2
-0.4304172 0.3858192 2.693917 204.0139 157.2809 805.1432 387.8424 -0.2866399 0.7234381 -0.4137073 0.4725688 2
-0.4339336 0.1272796 2.714383 204.0563 192.5867 805.3087 490.1169 0.03946908 0.9411238 0.006875069 0.3356799 2
-0.36822 -0.07224102 2.786263 214.2289 219.2174 834.5075 567.2477 0.2250257 0.7364373 0.00601962 -0.6379557 2
-0.3700575 -0.127083 2.808011 214.3589 226.2849 834.8156 587.7428 0.1314776 0.7426878 -0.1347235 -0.6426337 2
-0.2981097 0.4353881 2.893557 224.8335 154.6743 863.9844 379.9771 0.2839388 0.7432098 0.5627881 0.2242489 1
-0.2671633 0.2557262 3.116553 231.2118 179.7559 881.2538 452.5558 0.2853988 0.6805304 0.3266041 0.5905555 1
-0.2581481 -0.01978063 3.185815 232.9456 212.0073 886.1906 546.1407 0.122517 0.9550178 -0.01796174 -0.2694587 1
-0.229458 -0.05952364 3.223031 236.5407 216.4835 896.474 559.0983 0.402256 0.8494924 0.1237116 -0.318195 1
-0.4174743 -0.05169832 2.87393 209.4239 216.3172 819.9359 558.8809 0.3790196 -0.5433965 0.6074547 -0.4382502 2
-0.497618 -0.373121 2.919716 200.112 256.5469 793.2654 675.4592 0.03238185 0.950422 -0.1355495 -0.2779852 2
-0.5771378 -0.6777439 2.980165 191.4323 293.2393 768.4659 781.3873 0.05410571 0.9433935 -0.1481987 -0.2917506 2
-0.5500029 -0.756217 2.962994 194.3529 303.4882 777.1974 810.9081 0 0 0 0 2
-0.3476626 -0.05360388 2.980002 219.9058 216.3129 849.6191 558.7677 0.2399694 0.7000011 0.6688294 0.07127756 2
-0.3613364 -0.3393445 3.148574 220.5649 249.1572 850.9557 653.9677 0.1018645 0.298928 0.2436071 0.9170176 2
-0.3918813 -0.6122444 3.258895 218.4803 278.5703 844.7856 738.9781 0.1042677 0.2772751 0.1674781 0.9403179 2
-0.3447265 -0.6813605 3.248184 223.6403 286.6198 859.91 762.1609 0 0 0 0 2
-0.3814697 0.4595191 2.81428 212.8908 149.9362 829.9158 366.4533 0.4385106 -0.05917181 0.8863857 -0.1361158 2
-0.3481399 -0.1767624 2.830068 217.567 232.5726 844.04 605.946 0 0 0 0 2
-0.4126344 -0.1371371 2.798348 208.6103 227.6621 818.2493 591.783 0 0 0 0 2
-0.2210823 -0.06690916 3.229667 237.5418 217.3045 899.3542 561.4713 0 0 0 0 1

get_raw_denoised_data.py

num_bodies=1
num_bodies!=1
根据帧的长度,高宽比,筛选有效帧,并且按照motion进行排序
Denoising based on frame length
Denoising based on spread
Sort bodies based on the motion
num_bodies=1
num_bodies=2
Update joints and colors
start
load raw_skes_data
num of actors
func: get_one_actor_points
func: get_two_actor_points
func: remove_missing_frames
save raw_denoised_joints.pkl,raw_denoised_colors.pkl
end
func: denoising_bodies_data
func: denoising_by_spread
func: denoising_by_spread
func: sort bodies_motions
获得denoised_bodies_data
num of bodies
joints,colors
log missing info
'''
	本文件主要文件为raw_denoised_joints.pkl和raw_denoised_colors.pkl
	denoised_data
	  ├── raw_denoised_colors.pkl
	  │   		├── raw_denoised_colors(list) 
	  │   		│   		├── joints(list)
	  │   		│   		│   		├── joints[:,1,25,2]
	  ├── raw_denoised_joints.pkl
	  │   		├── raw_denoised_joints(list) 
	  │   		│   		├── joints(list)
	  │   		│   		│   		├── joints[:,num_bodies x 25 x 3]
'''
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import os.path as osp
import numpy as np
import pickle
import logging

root_path = './'
raw_data_file = osp.join(root_path, 'raw_data', 'raw_skes_data.pkl')
save_path = osp.join(root_path, 'denoised_data')        # './denoised_data/'

if not osp.exists(save_path):
    os.mkdir(save_path)

rgb_ske_path = osp.join(save_path, 'rgb+ske')           # './denoised_data/rgb+ske'
if not osp.exists(rgb_ske_path):
    os.mkdir(rgb_ske_path)

actors_info_dir = osp.join(save_path, 'actors_info')        # per-sequence actor info: bodyID, interval, motion
if not osp.exists(actors_info_dir):
    os.mkdir(actors_info_dir)

missing_count = 0                           # global counter of sequences with missing frames
noise_len_thres = 11                        # drop bodies tracked for 11 frames or fewer
noise_spr_thres1 = 0.8                      # max allowed width/height ratio of a skeleton per frame
noise_spr_thres2 = 0.69754                  # max allowed fraction of noisy (wide) frames per body
noise_mot_thres_lo = 0.089925               # lower motion threshold
noise_mot_thres_hi = 2                      # upper motion threshold

# Logger recording 'Skeleton', 'bodyID', 'Motion', 'Length' for length-filtered bodies.
noise_len_logger = logging.getLogger('noise_length')
noise_len_logger.setLevel(logging.INFO)
noise_len_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_length.log')))
noise_len_logger.info('{:^20}\t{:^17}\t{:^8}\t{}'.format('Skeleton', 'bodyID', 'Motion', 'Length'))

# Logger recording bodies removed by the spread (width/height ratio) filter.
noise_spr_logger = logging.getLogger('noise_spread')
noise_spr_logger.setLevel(logging.INFO)
noise_spr_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_spread.log')))
noise_spr_logger.info('{:^20}\t{:^17}\t{:^8}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion', 'Rate'))

# Logger recording bodies removed by the motion-range filter.
noise_mot_logger = logging.getLogger('noise_motion')
noise_mot_logger.setLevel(logging.INFO)
noise_mot_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_motion.log')))
noise_mot_logger.info('{:^20}\t{:^17}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion'))

fail_logger_1 = logging.getLogger('noise_outliers_1')
fail_logger_1.setLevel(logging.INFO)
fail_logger_1.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_1.log')))         # one-subject actions that kept two bodies after denoising
fail_logger_2 = logging.getLogger('noise_outliers_2')
fail_logger_2.setLevel(logging.INFO)
fail_logger_2.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_2.log')))         # two-subject actions that kept only one body after denoising

missing_skes_logger = logging.getLogger('missing_frames')                                           # sequences with frames removed
missing_skes_logger.setLevel(logging.INFO)
missing_skes_logger.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes.log')))
missing_skes_logger.info('{:^20}\t{}\t{}'.format('Skeleton', 'num_frames', 'num_missing'))

missing_skes_logger1 = logging.getLogger('missing_frames_1')                                        # actor1 missing in more frames than actor2
missing_skes_logger1.setLevel(logging.INFO)
missing_skes_logger1.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_1.log')))
missing_skes_logger1.info('{:^20}\t{}\t{}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1',
                                                              'Actor2', 'Start', 'End'))

missing_skes_logger2 = logging.getLogger('missing_frames_2')                                        # actor2 missing in at least as many frames as actor1
missing_skes_logger2.setLevel(logging.INFO)
missing_skes_logger2.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_2.log')))
missing_skes_logger2.info('{:^20}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1', 'Actor2'))


def denoising_by_length(ske_name, bodies_data):
    """
    Drop bodies whose tracked frame count is too short.

    A bodyID whose interval length is less than or equal to noise_len_thres
    is removed from bodies_data (in place) and logged to noise_len_logger.

    Returns the pruned dict and a human-readable note of what was filtered.
    """
    noise_info = ''
    # Iterate over a shallow copy so entries can be deleted from the
    # original dict while looping.
    for bodyID, body_data in bodies_data.copy().items():
        length = len(body_data['interval'])
        if length > noise_len_thres:
            continue
        noise_info += 'Filter out: %s, %d (length).\n' % (bodyID, length)
        noise_len_logger.info('{}\t{}\t{:.6f}\t{:^6d}'.format(ske_name, bodyID,
                                                              body_data['motion'], length))
        del bodies_data[bodyID]
    if noise_info:
        noise_info += '\n'

    return bodies_data, noise_info


def get_valid_frames_by_spread(points):
    """
    Return indices of frames whose X-spread is small relative to their Y-spread.

    A frame is valid when (max(x) - min(x)) <= noise_spr_thres1 * (max(y) - min(y)),
    i.e. the skeleton is taller than it is wide.

    :param points: array of shape (num_frames, num_joints, >=2) — joints or colors
    :output: list of valid frame indices
    """
    valid = []
    for idx in range(points.shape[0]):
        xs = points[idx, :, 0]
        ys = points[idx, :, 1]
        x_spread = xs.max() - xs.min()
        y_spread = ys.max() - ys.min()
        if x_spread <= noise_spr_thres1 * y_spread:  # 0.8
            valid.append(idx)
    return valid


def denoising_by_spread(ske_name, bodies_data):
    """
    Drop bodies whose fraction of wide (noisy) frames is too high.

    For each bodyID, frames failing the X/Y-spread check are counted. If their
    ratio reaches noise_spr_thres2 the body is deleted from bodies_data (in
    place) and logged; otherwise the body's motion is recomputed over the
    valid frames only, taking the minimum with its previous value.

    bodies_data: contains at least 2 bodyIDs.
    Returns (bodies_data, noise_info, denoised_by_spr).
    """
    noise_info = ''
    denoised_by_spr = False  # whether any body was removed by this step

    # Iterate over a shallow copy so deletion from bodies_data is safe.
    for bodyID, body_data in bodies_data.copy().items():
        if len(bodies_data) == 1:  # only one body left — stop filtering
            break
        frames_3d = body_data['joints'].reshape(-1, 25, 3)
        valid_frames = get_valid_frames_by_spread(frames_3d)
        num_frames = len(body_data['interval'])
        num_noise = num_frames - len(valid_frames)
        if num_noise == 0:
            continue

        ratio = num_noise / float(num_frames)  # fraction of noisy frames
        motion = body_data['motion']
        if ratio >= noise_spr_thres2:  # 0.69754
            # Too noisy: remove this body entirely and record it.
            del bodies_data[bodyID]
            denoised_by_spr = True
            noise_info += 'Filter out: %s (spread rate >= %.2f).\n' % (bodyID, noise_spr_thres2)
            noise_spr_logger.info('%s\t%s\t%.6f\t%.6f' % (ske_name, bodyID, motion, ratio))
        else:
            # Keep the body but recompute motion using the valid frames only.
            kept = frames_3d[valid_frames]
            body_data['motion'] = min(motion, np.sum(np.var(kept.reshape(-1, 3), axis=0)))
            noise_info += '%s: motion %.6f -> %.6f\n' % (bodyID, motion, body_data['motion'])
            # TODO: Consider removing noisy frames for each bodyID

    if noise_info:
        noise_info += '\n'

    return bodies_data, noise_info, denoised_by_spr


def denoising_by_motion(ske_name, bodies_data, bodies_motion):
    """
    Keep the body with the largest motion; filter the rest by a motion range.

    Every body other than the top-motion one (which is always kept) is dropped
    and logged when its motion falls outside
    [noise_mot_thres_lo, noise_mot_thres_hi].

    Returns (denoised_bodies_data, noise_info) where denoised_bodies_data is
    a list of (bodyID, body_data) tuples in descending-motion order.
    """
    # Rank by motion, largest first; yields (bodyID, motion) tuples.
    ranked = sorted(bodies_motion.items(), key=lambda item: item[1], reverse=True)

    # The body with the largest motion is always reserved.
    top_id = ranked[0][0]
    denoised_bodies_data = [(top_id, bodies_data[top_id])]
    noise_info = ''

    for bodyID, motion in ranked[1:]:
        if noise_mot_thres_lo <= motion <= noise_mot_thres_hi:
            denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
        else:
            noise_info += 'Filter out: %s, %.6f (motion).\n' % (bodyID, motion)
            noise_mot_logger.info('{}\t{}\t{:.6f}'.format(ske_name, bodyID, motion))
    if noise_info:
        noise_info += '\n'

    return denoised_bodies_data, noise_info


def denoising_bodies_data(bodies_data):
    """
    Apply the heuristic denoising steps and rank the surviving bodies.

    Step 1 drops bodies tracked for too few frames; step 2 drops bodies with
    too high a ratio of wide (noisy) frames. The survivors are sorted by
    descending motion. These heuristics are not necessarily correct for all
    samples.

    Return:
      (denoised_bodies_data, noise_info): an iterable of (bodyID, body_data)
      tuples plus the accumulated filtering notes.
    """
    ske_name = bodies_data['name']
    bodies_data = bodies_data['data']

    # Step 1: Denoising based on frame length.
    bodies_data, noise_info_len = denoising_by_length(ske_name, bodies_data)
    if len(bodies_data) == 1:  # only one bodyID left after step 1
        return bodies_data.items(), noise_info_len

    # Step 2: Denoising based on spread.
    bodies_data, noise_info_spr, _ = denoising_by_spread(ske_name, bodies_data)
    if len(bodies_data) == 1:
        return bodies_data.items(), noise_info_len + noise_info_spr

    # Rank the remaining bodies by motion, largest first (stable sort keeps
    # insertion order for ties).
    ranked = sorted(bodies_data.items(), key=lambda item: item[1]['motion'], reverse=True)
    return ranked, noise_info_len + noise_info_spr

    # TODO: Consider denoising further by integrating the motion method
    # (see denoising_by_motion).


def get_one_actor_points(body_data, num_frames):
    """
    Build full-length joints/colors arrays for a single actor.

    For joints, each frame holds 75 (25 x 3) X-Y-Z values; for colors, each
    frame holds 25 (X, Y) pairs. Frames outside the actor's interval stay
    zero (joints) or NaN (colors).
    """
    first, last = body_data['interval'][0], body_data['interval'][-1]
    joints = np.zeros((num_frames, 75), dtype=np.float32)
    colors = np.full((num_frames, 1, 25, 2), np.nan, dtype=np.float32)
    joints[first:last + 1] = body_data['joints'].reshape(-1, 75)
    colors[first:last + 1, 0] = body_data['colors']
    return joints, colors


def remove_missing_frames(ske_name, joints, colors):
    """
    Drop frames in which every joint position is zero (no actor captured).

    For sequences with 2 actors' data, also log how many frames each actor is
    missing (debug only). joints is filtered down to the valid frames; the
    colors of missing frames are set to NaN (colors keeps its length).
    """
    num_frames = joints.shape[0]
    num_bodies = colors.shape[1]  # 1 or 2

    if num_bodies == 2:  # DEBUG bookkeeping for two-actor sequences
        # Frames where actor1 (first 75 values) / actor2 (last 75) is absent.
        missing_1 = np.where(joints[:, :75].sum(axis=1) == 0)[0]
        missing_2 = np.where(joints[:, 75:].sum(axis=1) == 0)[0]
        cnt1, cnt2 = len(missing_1), len(missing_2)

        start = 1 if 0 in missing_1 else 0               # actor1 absent at the first frame?
        end = 1 if num_frames - 1 in missing_1 else 0    # actor1 absent at the last frame?
        if max(cnt1, cnt2) > 0:
            if cnt1 > cnt2:
                missing_skes_logger1.info(
                    '{}\t{:^10d}\t{:^6d}\t{:^6d}\t{:^5d}\t{:^3d}'.format(
                        ske_name, num_frames, cnt1, cnt2, start, end))
            else:
                missing_skes_logger2.info(
                    '{}\t{:^10d}\t{:^6d}\t{:^6d}'.format(ske_name, num_frames, cnt1, cnt2))

    # A frame is valid when its joint data is not all-zero; for two-actor
    # sequences a missing frame means both actors are absent.
    frame_sums = joints.sum(axis=1)
    valid_indices = np.where(frame_sums != 0)[0]    # 0-based index
    missing_indices = np.where(frame_sums == 0)[0]
    num_missing = len(missing_indices)

    if num_missing > 0:  # Update joints and colors
        joints = joints[valid_indices]
        colors[missing_indices] = np.nan
        global missing_count
        missing_count += 1
        missing_skes_logger.info('{}\t{:^10d}\t{:^11d}'.format(ske_name, num_frames, num_missing))

    return joints, colors


def get_bodies_info(bodies_data):
    """Format a summary table (bodyID, interval, motion) over all bodies."""
    rows = ['{:^17}\t{}\t{:^8}\n'.format('bodyID', 'Interval', 'Motion')]
    for bodyID, body_data in bodies_data.items():
        first, last = body_data['interval'][0], body_data['interval'][-1]
        rows.append('{}\t{:^8}\t{:f}\n'.format(bodyID, str([first, last]),
                                               body_data['motion']))
    return ''.join(rows) + '\n'


def get_two_actors_points(bodies_data):
    """
    Get the first and second actor's joints positions and colors locations.

    After denoising, the body with the largest motion becomes actor1; each
    remaining body is merged into actor1's or actor2's slot depending on
    interval overlap.

    # Arguments:
        bodies_data (dict): 3 key-value pairs: 'name', 'data', 'num_frames'.
        bodies_data['data'] is also a dict, while the key is bodyID, the value is
        the corresponding body_data which is also a dict with 4 keys:
          - joints: raw 3D joints positions. Shape: (num_frames x 25, 3)
          - colors: raw 2D color locations. Shape: (num_frames, 25, 2)
          - interval: a list which records the frame indices.
          - motion: motion amount

    # Return:
        joints, colors.
    """
    ske_name = bodies_data['name']
    label = int(ske_name[-2:])  # action class from the name; classes >= 50 are two-subject actions
    num_frames = bodies_data['num_frames']
    # bodies_info: bodyID [start, end] motion
    bodies_info = get_bodies_info(bodies_data['data'])

    bodies_data, noise_info = denoising_bodies_data(bodies_data)  # Denoising data
    bodies_info += noise_info

    bodies_data = list(bodies_data)
    if len(bodies_data) == 1:  # Only left one actor after denoising
        if label >= 50:  # DEBUG: Denoising failed for two-subjects action
            fail_logger_2.info(ske_name)

        bodyID, body_data = bodies_data[0]
        joints, colors = get_one_actor_points(body_data, num_frames)
        bodies_info += 'Main actor: %s' % bodyID
    else:
        if label < 50:  # DEBUG: Denoising failed for one-subject action
            fail_logger_1.info(ske_name)

        joints = np.zeros((num_frames, 150), dtype=np.float32)
        colors = np.ones((num_frames, 2, 25, 2), dtype=np.float32) * np.nan

        bodyID, actor1 = bodies_data[0]  # the 1st actor with largest motion
        start1, end1 = actor1['interval'][0], actor1['interval'][-1]
        joints[start1:end1 + 1, :75] = actor1['joints'].reshape(-1, 75)
        colors[start1:end1 + 1, 0] = actor1['colors']
        actor1_info = '{:^17}\t{}\t{:^8}\n'.format('Actor1', 'Interval', 'Motion') + \
                      '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start1, end1]), actor1['motion'])
        del bodies_data[0]         # actor1's data is already written into joints/colors

        actor2_info = '{:^17}\t{}\t{:^8}\n'.format('Actor2', 'Interval', 'Motion')
        start2, end2 = [0, 0]  # initial interval for actor2 (virtual)

        while len(bodies_data) > 0:
            bodyID, actor = bodies_data[0]
            start, end = actor['interval'][0], actor['interval'][-1]
            if min(end1, end) - max(start1, start) <= 0:  # no overlap with actor1: merge into actor1's slot
                joints[start:end + 1, :75] = actor['joints'].reshape(-1, 75)
                colors[start:end + 1, 0] = actor['colors']
                actor1_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
                # Update the interval of actor1
                start1 = min(start, start1)
                end1 = max(end, end1)
            elif min(end2, end) - max(start2, start) <= 0:  # no overlap with actor2: merge into actor2's slot
                joints[start:end + 1, 75:] = actor['joints'].reshape(-1, 75)
                colors[start:end + 1, 1] = actor['colors']
                actor2_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
                # Update the interval of actor2
                start2 = min(start, start2)
                end2 = max(end, end2)
            # NOTE(review): a body overlapping both slots is silently discarded here.
            del bodies_data[0]

        bodies_info += ('\n' + actor1_info + '\n' + actor2_info)

    # Persist the per-sequence actor report for inspection.
    with open(osp.join(actors_info_dir, ske_name + '.txt'), 'w') as fw:
        fw.write(bodies_info + '\n')

    return joints, colors


def get_raw_denoised_data():
    """
    Get denoised data (joints positions and color locations) from raw skeleton sequences.

    For each frame of a skeleton sequence, an actor's 3D positions of 25 joints represented
    by an 2D array (shape: 25 x 3) is reshaped into a 75-dim vector by concatenating each
    3-dim (x, y, z) coordinates along the row dimension in joint order. Each frame contains
    two actor's joints positions constituting a 150-dim vector. If there is only one actor,
    then the last 75 values are filled with zeros. Otherwise, select the main actor and the
    second actor based on the motion amount. Each 150-dim vector as a row vector is put into
    a 2D numpy array where the number of rows equals the number of valid frames. All such
    2D arrays are put into a list and finally the list is serialized into a cPickle file.

    For the skeleton sequence which contains two or more actors (mostly corresponds to the
    last 11 classes), the filename and actors' information are recorded into log files.
    For better understanding, also generate RGB+skeleton videos for visualization.

    Reads module-level globals: ``raw_data_file``, ``save_path`` and ``missing_count``
    (the latter is presumably updated inside ``remove_missing_frames`` — TODO confirm).
    Writes ``raw_denoised_joints.pkl``, ``raw_denoised_colors.pkl`` and ``frames_cnt.txt``.
    """

    with open(raw_data_file, 'rb') as fr:  # load raw skeletons data
        raw_skes_data = pickle.load(fr)

    num_skes = len(raw_skes_data)
    print('Found %d available skeleton sequences.' % num_skes)

    raw_denoised_joints = []
    raw_denoised_colors = []
    frames_cnt = []

    for (idx, bodies_data) in enumerate(raw_skes_data):
        ske_name = bodies_data['name']
        print('Processing %s' % ske_name)
        num_bodies = len(bodies_data['data'])

        if num_bodies == 1:  # only 1 actor
            num_frames = bodies_data['num_frames']
            body_data = list(bodies_data['data'].values())[0]
            joints, colors = get_one_actor_points(body_data, num_frames)
        else:  # more than 1 actor, select two main actors
            joints, colors = get_two_actors_points(bodies_data)
            # Remove missing frames
            joints, colors = remove_missing_frames(ske_name, joints, colors)
            num_frames = joints.shape[0]  # Update
            # Visualize selected actors' skeletons on RGB videos.

        raw_denoised_joints.append(joints)
        raw_denoised_colors.append(colors)
        frames_cnt.append(num_frames)

        if (idx + 1) % 1000 == 0:
            print('Processed: %.2f%% (%d / %d), ' % \
                  (100.0 * (idx + 1) / num_skes, idx + 1, num_skes) + \
                  'Missing count: %d' % missing_count)

    raw_skes_joints_pkl = osp.join(save_path, 'raw_denoised_joints.pkl')
    with open(raw_skes_joints_pkl, 'wb') as f:
        pickle.dump(raw_denoised_joints, f, pickle.HIGHEST_PROTOCOL)

    raw_skes_colors_pkl = osp.join(save_path, 'raw_denoised_colors.pkl')
    with open(raw_skes_colors_pkl, 'wb') as f:
        pickle.dump(raw_denoised_colors, f, pickle.HIGHEST_PROTOCOL)

    # np.int was removed in NumPy 1.24; use the explicit fixed-width alias instead.
    frames_cnt = np.array(frames_cnt, dtype=np.int64)
    np.savetxt(osp.join(save_path, 'frames_cnt.txt'), frames_cnt, fmt='%d')

    print('Saved raw denoised positions of {} frames into {}'.format(np.sum(frames_cnt),
                                                                     raw_skes_joints_pkl))
    print('Found %d files that have missing data' % missing_count)

if __name__ == '__main__':
    # Entry point of get_raw_denoised_data.py: denoise the raw skeleton
    # sequences and serialize the results (joints / colors pickles + frame counts).
    get_raw_denoised_data()

seq_transformation.py

所有关节点坐标减去脊柱中心坐标,相当于归一化
CS根据表演者performer来划分训练集和测试集,CV是根据镜头划分的测试集和训练集
将每个label分为60类,是哪个类就将该类对应的位置置为1(即one-hot编码)
start
load camera,performer,label
每个文件多少帧,skeleton_name
raw_skes_joints_pkl
func: seq_translation
func: align_frames
func:split_dataset
end
减去中心坐标之后,对原先0的点重新补零
对于长度<300的帧补0
对只有一个人的帧,另一半补0
func: get_indices
func: one_hot_vector
'''
	本文件主要生成两个文件:NTU60_CS.npz和NTU60_CV.npz
	NTU60_CS.npz
	  ├── train_x
	  │   		├── skes_joints[train_indices]
	  ├── train_y
	  │   		├── one_hot_vector(train_labels)
	  ├── test_x
	  │   		├── skes_joints[test_indices] 
	  ├── test_y
	  │   		├── one_hot_vector(test_labels) 
'''
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import os.path as osp
import numpy as np
import pickle
import logging
import h5py
from sklearn.model_selection import train_test_split

# Module-level paths for the statistics files shipped with the NTU dataset
# and for the denoised intermediate data produced by get_raw_denoised_data.py.
root_path = './'
stat_path = osp.join(root_path, 'statistics')
setup_file = osp.join(stat_path, 'setup.txt')
camera_file = osp.join(stat_path, 'camera.txt')                             # camera id per sequence
performer_file = osp.join(stat_path, 'performer.txt')                       # performer (subject) id per sequence
replication_file = osp.join(stat_path, 'replication.txt')
label_file = osp.join(stat_path, 'label.txt')                               # action label per sequence
skes_name_file = osp.join(stat_path, 'skes_available_name.txt')             # skeleton file names

denoised_path = osp.join(root_path, 'denoised_data')
raw_skes_joints_pkl = osp.join(denoised_path, 'raw_denoised_joints.pkl')    # output of get_raw_denoised_data.py
frames_file = osp.join(denoised_path, 'frames_cnt.txt')

save_path = './'


# if not osp.exists(save_path):
#     os.mkdir(save_path)


def remove_nan_frames(ske_name, ske_joints, nan_logger):
    """Return ske_joints with every NaN-containing frame removed.

    Each dropped frame is logged (1-based frame index plus the flat indices
    of the NaN entries) through nan_logger.
    """
    kept_frames = []
    for frame_idx, frame in enumerate(ske_joints):
        nan_mask = np.isnan(frame)
        if nan_mask.any():
            nan_logger.info('{}\t{:^5}\t{}'.format(ske_name, frame_idx + 1,
                                                   np.where(nan_mask)[0]))
        else:
            kept_frames.append(frame_idx)

    return ske_joints[kept_frames]

def seq_translation(skes_joints):
    """Translate every sequence so that joint-2 (middle of the spine) of the
    first valid frame of actor1 becomes the origin.

    For two-actor sequences, the frames where an actor is entirely missing
    (all-zero 75-dim half) are re-zeroed after the translation so missing
    data stays exactly zero.

    Args:
        skes_joints: list of float32 arrays, shape (num_frames, 75) for one
            actor or (num_frames, 150) for two actors. Modified in place.

    Returns:
        The same list, with each array translated.
    """
    for idx, ske_joints in enumerate(skes_joints):
        num_frames = ske_joints.shape[0]
        num_bodies = 1 if ske_joints.shape[1] == 75 else 2
        if num_bodies == 2:
            missing_frames_1 = np.where(ske_joints[:, :75].sum(axis=1) == 0)[0]
            missing_frames_2 = np.where(ske_joints[:, 75:].sum(axis=1) == 0)[0]
            cnt1 = len(missing_frames_1)
            cnt2 = len(missing_frames_2)

        # Find the "real" first frame of actor1 (first frame that is not all-zero).
        i = 0
        while i < num_frames and not np.any(ske_joints[i, :75] != 0):
            i += 1
        if i == num_frames:
            # Actor1 never appears: nothing to translate. (The original code
            # would index out of bounds here.)
            continue

        origin = np.copy(ske_joints[i, 3:6])  # new origin: joint-2 (spine middle)

        # Subtract the origin from every joint of every frame (vectorized;
        # equivalent to the per-frame np.tile subtraction).
        ske_joints -= np.tile(origin, 25 * num_bodies)

        # Restore exact zeros for frames where an actor was missing.
        if num_bodies == 2 and cnt1 > 0:
            ske_joints[missing_frames_1, :75] = 0
        if num_bodies == 2 and cnt2 > 0:
            ske_joints[missing_frames_2, 75:] = 0

        skes_joints[idx] = ske_joints  # Update

    return skes_joints


def frame_translation(skes_joints, skes_name, frames_cnt):
    """Per-frame normalization: recenter each frame on joint-2 (spine middle)
    and rescale by the spine length, then drop NaN frames.

    The spine length is the distance between spine base (joint-1) and spine
    (joint-21), computed once per frame from the original coordinates.
    NaN frames are logged to ./nan_frames.log and removed.

    NOTE(review): frames_cnt is updated with the pre-removal frame count,
    matching the original behavior — confirm this is intended.
    """
    nan_logger = logging.getLogger('nan_skes')
    nan_logger.setLevel(logging.INFO)
    nan_logger.addHandler(logging.FileHandler("./nan_frames.log"))
    nan_logger.info('{}\t{}\t{}'.format('Skeleton', 'Frame', 'Joints'))

    for idx, ske_joints in enumerate(skes_joints):
        num_frames = ske_joints.shape[0]
        # Spine length per frame, from the untranslated coordinates.
        spine_len = np.sqrt(((ske_joints[:, 0:3] - ske_joints[:, 60:63]) ** 2).sum(axis=1))

        for f in range(num_frames):
            origin = ske_joints[f, 3:6]  # joint-2 of this frame
            if (ske_joints[f, 75:] == 0).all():
                # Single actor present: only normalize the first 75 dims.
                shift = np.tile(origin, 25)
                ske_joints[f, :75] = (ske_joints[f, :75] - shift) / spine_len[f] + shift
            else:
                shift = np.tile(origin, 50)
                ske_joints[f] = (ske_joints[f] - shift) / spine_len[f] + shift

        ske_joints = remove_nan_frames(skes_name[idx], ske_joints, nan_logger)
        frames_cnt[idx] = num_frames  # update valid number of frames
        skes_joints[idx] = ske_joints

    return skes_joints, frames_cnt


def align_frames(skes_joints, frames_cnt):
    """
    Align all sequences with the same frame length.

    Every sequence is zero-padded along the time axis up to the longest
    sequence (300 frames for NTU), and single-actor sequences are widened to
    150 dims by leaving the second actor's half as zeros.
    """
    max_frames = frames_cnt.max()  # 300
    padded = np.zeros((len(skes_joints), max_frames, 150), dtype=np.float32)

    for seq_idx, seq in enumerate(skes_joints):
        n = seq.shape[0]
        if seq.shape[1] == 75:  # single actor: second half stays zero
            padded[seq_idx, :n, :75] = seq
        else:
            padded[seq_idx, :n] = seq

    return padded


def one_hot_vector(labels):
    num_skes = len(labels)
    labels_vector = np.zeros((num_skes, 60))
    for idx, l in enumerate(labels):
        labels_vector[idx, l] = 1

    return labels_vector


def split_train_val(train_indices, method='sklearn', ratio=0.05):
    """
    Get validation set by splitting data randomly from training set with two methods.
    In fact, I thought these two methods are equal as they got the same performance.

    Args:
        train_indices: array of training indices to split.
        method: 'sklearn' uses train_test_split; anything else uses a numpy shuffle.
        ratio: fraction of indices assigned to the validation set.

    Returns:
        (train_indices, val_indices) tuple.
    """
    if method == 'sklearn':
        return train_test_split(train_indices, test_size=ratio, random_state=10000)

    # numpy fallback: shuffle in place, then carve off the validation slice.
    np.random.seed(10000)
    np.random.shuffle(train_indices)
    # Bug fix: the validation size was hard-coded to 0.05, ignoring `ratio`.
    val_num_skes = int(np.ceil(ratio * len(train_indices)))
    val_indices = train_indices[:val_num_skes]
    train_indices = train_indices[val_num_skes:]
    return train_indices, val_indices


def split_dataset(skes_joints, label, performer, camera, evaluation, save_path):
    """Split the aligned data into train/test sets and save them as an .npz file.

    CS (Cross-Subject) splits by performer id, CV (Cross-View) by camera id;
    labels are stored one-hot encoded.

    Args:
        skes_joints: aligned joints array, shape (num_skes, max_frames, 150).
        label: 0-based action labels, one per sequence.
        performer / camera: per-sequence subject / camera ids.
        evaluation: 'CS' or 'CV' (used in the output file name).
        save_path: directory the .npz file is written to.
    """
    train_indices, test_indices = get_indices(performer, camera, evaluation)

    # Gather labels for each split.
    train_labels = label[train_indices]
    test_labels = label[test_indices]

    train_x = skes_joints[train_indices]
    train_y = one_hot_vector(train_labels)
    test_x = skes_joints[test_indices]
    test_y = one_hot_vector(test_labels)

    # Bug fix: the save_path parameter was ignored and the file always
    # landed in the current working directory.
    save_name = osp.join(save_path, 'NTU60_%s.npz' % evaluation)
    np.savez(save_name, x_train=train_x, y_train=train_y, x_test=test_x, y_test=test_y)


def get_indices(performer, camera, evaluation='CS'):
    """Return (train_indices, test_indices) for the chosen evaluation protocol.

    CS (Cross-Subject) assigns sequences to train/test by performer (subject)
    id; CV (Cross-View) trains on cameras 2 and 3 and tests on camera 1.
    Indices are concatenated per id, preserving the original ordering.

    Args:
        performer: per-sequence subject ids (1~40).
        camera: per-sequence camera ids (1, 2, 3).
        evaluation: 'CS' for Cross-Subject, anything else for Cross-View.
    """
    # np.int was removed in NumPy 1.24; use an explicit integer dtype.
    test_indices = np.empty(0, dtype=np.int64)
    train_indices = np.empty(0, dtype=np.int64)

    if evaluation == 'CS':  # Cross Subject (Subject IDs)
        train_ids = [1,  2,  4,  5,  8,  9,  13, 14, 15, 16,
                     17, 18, 19, 25, 27, 28, 31, 34, 35, 38]
        test_ids = [3,  6,  7,  10, 11, 12, 20, 21, 22, 23,
                    24, 26, 29, 30, 32, 33, 36, 37, 39, 40]

        # Get indices of test data (grouped per subject id, in test_ids order)
        for test_id in test_ids:
            temp = np.where(performer == test_id)[0]  # 0-based index
            test_indices = np.hstack((test_indices, temp)).astype(np.int64)

        # Get indices of training data
        for train_id in train_ids:
            temp = np.where(performer == train_id)[0]  # 0-based index
            train_indices = np.hstack((train_indices, temp)).astype(np.int64)
    else:  # Cross View (Camera IDs)
        train_ids = [2, 3]
        test_ids = 1
        # Get indices of test data
        temp = np.where(camera == test_ids)[0]  # 0-based index
        test_indices = np.hstack((test_indices, temp)).astype(np.int64)

        # Get indices of training data
        for train_id in train_ids:
            temp = np.where(camera == train_id)[0]  # 0-based index
            train_indices = np.hstack((train_indices, temp)).astype(np.int64)

    return train_indices, test_indices


if __name__ == '__main__':
    # np.int was removed in NumPy 1.24 and np.string_ in NumPy 2.0;
    # use the builtin dtypes instead.
    camera = np.loadtxt(camera_file, dtype=int)  # camera id: 1, 2, 3
    performer = np.loadtxt(performer_file, dtype=int)  # subject id: 1~40
    label = np.loadtxt(label_file, dtype=int) - 1  # action label: 0~59

    frames_cnt = np.loadtxt(frames_file, dtype=int)  # valid frame count per sequence
    skes_name = np.loadtxt(skes_name_file, dtype=str)  # skeleton file names

    with open(raw_skes_joints_pkl, 'rb') as fr:
        skes_joints = pickle.load(fr)  # a list of (num_frames, 75 or 150) arrays

    # Recenter every sequence on joint-2 (spine middle) of the first valid frame.
    skes_joints = seq_translation(skes_joints)

    skes_joints = align_frames(skes_joints, frames_cnt)  # aligned to the same frame length

    # Produce both evaluation protocols: Cross-Subject and Cross-View.
    for evaluation in ('CS', 'CV'):
        split_dataset(skes_joints, label, performer, camera, evaluation, save_path)

你可能感兴趣的:(python,python)