Dex-Net Dataset Generation: Source Code Analysis and Annotations

The code that generates the dataset consumed directly by GQ-CNN lives in the generate_gqcnn_dataset.py file under the tools directory of the Dex-Net module. Its two most important steps are:

  1. Preprocessing the grasp poses and running collision checks
  2. Rendering the images and storing them

Below is that code, annotated to aid understanding.
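
The excerpts below also use several variables (phi_offsets, approach_dist, delta_approach, max_grasp_approach_table_angle, table_offset, stable_pose_min_p) that the script derives from its YAML configuration before this point. A minimal sketch of that setup; the exact config key names are assumptions rather than guaranteed by this excerpt:

    import numpy as np

    # Hypothetical config keys -- the exact YAML layout may differ between versions.
    table_offset = config['collision_checking']['table_offset']
    approach_dist = config['collision_checking']['approach_dist']
    delta_approach = config['collision_checking']['delta_approach']
    stable_pose_min_p = config['stable_pose_min_p']
    max_grasp_approach_table_angle = np.deg2rad(config['table_alignment']['max_approach_table_angle'])

    # rotations of the grasp about its approach axis to try during collision checking
    phi_offsets = np.deg2rad(np.linspace(config['table_alignment']['min_approach_offset'],
                                         config['table_alignment']['max_approach_offset'],
                                         config['table_alignment']['num_approach_offset_samples']))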

Part 1

# 1. Precompute a set of valid grasps for each stable pose of each object
# Constraint 1: the approach must be perpendicular to the table
# Constraint 2: the approach path must be collision-free

# load the grasp cache if one already exists
grasp_cache_filename = os.path.join(output_dir, CACHE_FILENAME)
if os.path.exists(grasp_cache_filename):
    logging.info('Loading grasp candidates from file')
    # pkl is cPickle, used to save and load Python objects
    candidate_grasps_dict = pkl.load(open(grasp_cache_filename, 'rb'))
# if no cache exists, recompute the grasp data and apply the constraints
else:        
    # dictionary that stores the candidate grasps
    candidate_grasps_dict = {}
    
    # iterate over every object in every dataset
    for dataset in datasets:
        logging.info('Reading dataset %s' %(dataset.name))
        for obj in dataset:
            # skip objects that are not targeted for computation
            if obj.key not in target_object_keys[dataset.name]:
                continue

            # initialize the candidate grasp storage for this object as a dict
            candidate_grasps_dict[obj.key] = {}

            # set up the collision checker; GraspCollisionChecker is defined in Dex-Net
            collision_checker = GraspCollisionChecker(gripper)
            collision_checker.set_graspable_object(obj)

            # load all stable poses of the object mesh
            stable_poses = dataset.stable_poses(obj.key)
            for i, stable_pose in enumerate(stable_poses):
                # process this stable pose only if its probability is high enough
                if stable_pose.p > stable_pose_min_p:
                    candidate_grasps_dict[obj.key][stable_pose.id] = []

                    # set up the table in the collision checker
                    T_obj_stp = stable_pose.T_obj_table.as_frames('obj', 'stp')
                    T_obj_table = obj.mesh.get_T_surface_obj(T_obj_stp, delta=table_offset).as_frames('obj', 'table')
                    T_table_obj = T_obj_table.inverse()
                    collision_checker.set_table(table_mesh_filename, T_table_obj)

                    # load all stored grasps for this object and gripper
                    grasps = dataset.grasps(obj.key, gripper=gripper.name)
                    logging.info('Aligning %d grasps for object %s in stable %s' %(len(grasps), obj.key, stable_pose.id))

                    # align all grasps with the stable pose (table frame)
                    aligned_grasps = [grasp.perpendicular_table(stable_pose) for grasp in grasps]

                    # check which aligned grasps are valid
                    logging.info('Checking collisions for %d grasps for object %s in stable %s' %(len(grasps), obj.key, stable_pose.id))
                    for aligned_grasp in aligned_grasps:
                        # check the angle between the grasp approach axis and the table normal,
                        # and skip grasps that are not aligned (angle above the maximum)
                        _, grasp_approach_table_angle, _ = aligned_grasp.grasp_angles_from_stp_z(stable_pose)
                        perpendicular_table = (np.abs(grasp_approach_table_angle) < max_grasp_approach_table_angle)
                        if not perpendicular_table: 
                            continue

                        # check whether the approach path of the grasp is collision-free
                        collision_free = False
                        # for each offset angle (here e.g. 10 degrees): rotate the grasp about
                        # its approach axis by the offset, then re-check for collisions
                        for phi_offset in phi_offsets:
                            rotated_grasp = aligned_grasp.grasp_y_axis_offset(phi_offset)
                            # collides_along_approach steps the gripper along the approach
                            # direction in increments of delta_approach over approach_dist
                            collides = collision_checker.collides_along_approach(rotated_grasp, approach_dist, delta_approach)
                            # Author's question: why does a single collision-free offset make the
                            # grasp collision-free? Presumably because the approach rotation is
                            # free at execution time, so the grasp is reachable as soon as any
                            # sampled offset avoids collision.
                            if not collides:
                                collision_free = True
                                break
                
                        # store the generated grasp pose in the cache structure
                        candidate_grasps_dict[obj.key][stable_pose.id].append(GraspInfo(aligned_grasp, collision_free))

                        # visualize the grasp if display is enabled in the config
                        if collision_free and config['vis']['candidate_grasps']:
                            logging.info('Grasp %d' %(aligned_grasp.id))
                            vis.figure()
                            vis.gripper_on_object(gripper, aligned_grasp, obj, stable_pose.T_obj_world)
                            vis.show()
                            
    # save the cache to disk
    logging.info('Saving to file')
    pkl.dump(candidate_grasps_dict, open(grasp_cache_filename, 'wb'))
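
The cache written above is a nested dict: candidate_grasps_dict[object_key][stable_pose_id] holds a list of GraspInfo(grasp, collision_free) records. A minimal sketch for inspecting it, in the same Python 2 style as the source:

    import cPickle as pkl

    candidate_grasps_dict = pkl.load(open(grasp_cache_filename, 'rb'))
    for obj_key, poses in candidate_grasps_dict.iteritems():
        for stable_pose_id, grasp_infos in poses.iteritems():
            num_free = sum(1 for g in grasp_infos if g.collision_free)
            print('%s / %s: %d of %d grasps collision-free' % (obj_key, stable_pose_id, num_free, len(grasp_infos)))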

Part 2

 # 2. For every grasp in the dataset, render an image dataset and associate the gripper pose with image coordinates

    # set up variables
    # object category map
    obj_category_map = {}
    # stable pose category map
    pose_category_map = {}

    cur_pose_label = 0
    cur_obj_label = 0
    cur_image_label = 0
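    # Not shown in this excerpt: tensor_dataset and tensor_datapoint are created
    # earlier in the script from the config. A sketch, assuming autolab_core's
    # TensorDataset API (field names, shapes and dtypes come from the 'tensors'
    # section of the YAML config):
    #     tensor_dataset = TensorDataset(output_dir, config['tensors'])
    #     tensor_datapoint = tensor_dataset.datapoint_template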
                
    # render images of every stable pose of every object in the dataset
    render_modes = [RenderMode.SEGMASK, RenderMode.DEPTH_SCENE]
    for dataset in datasets:
        logging.info('Generating data for dataset %s' %(dataset.name))
        
        # iterate over every object
        object_keys = dataset.object_keys
        for obj_key in object_keys:
            obj = dataset[obj_key]
            if obj.key not in target_object_keys[dataset.name]:
                continue

            # load the stable poses of the object mesh
            stable_poses = dataset.stable_poses(obj.key)
            for i, stable_pose in enumerate(stable_poses):

                # render images only if the stable pose is valid
                if stable_pose.p > stable_pose_min_p:
                    # log progress
                    logging.info('Rendering images for object %s in %s' %(obj.key, stable_pose.id))

                    # add to category maps
                    if obj.key not in obj_category_map.keys():
                        obj_category_map[obj.key] = cur_obj_label
                    pose_category_map['%s_%s' %(obj.key, stable_pose.id)] = cur_pose_label

                    # load the candidate grasps and their metrics; candidate_grasps_dict is
                    # the candidate grasp dictionary generated in Part 1
                    candidate_grasp_info = candidate_grasps_dict[obj.key][stable_pose.id]
                    candidate_grasps = [g.grasp for g in candidate_grasp_info]
                    # grasp_metrics directly returns the metric values for every grasp -- a key step
                    grasp_metrics = dataset.grasp_metrics(obj.key, candidate_grasps, gripper=gripper.name)

                    # compute the object pose relative to the table
                    T_obj_stp = stable_pose.T_obj_table.as_frames('obj', 'stp')
                    T_obj_stp = obj.mesh.get_T_surface_obj(T_obj_stp)

                    # sample images using a random variable over camera and table poses
                    T_table_obj = RigidTransform(from_frame='table', to_frame='obj')
                    # the SceneObject class merely bundles the mesh and its pose
                    scene_objs = {'table': SceneObject(table_mesh, T_table_obj)}
                    # this random variable is defined in the meshpy module and is analyzed later
                    # env_rv_params holds the camera and environment sampling parameters
                    urv = UniformPlanarWorksurfaceImageRandomVariable(obj.mesh,
                                                                      render_modes,
                                                                      'camera',
                                                                      env_rv_params,
                                                                      stable_pose=stable_pose,
                                                                      scene_objs=scene_objs)
                    
                    render_start = time.time()
                    # urv.rvs is defined in the parent class RandomVariable in the
                    # autolab_core package; internally it calls the subclass's sample
                    # method, which is analyzed later
                    # size is the number of samples to draw
                    render_samples = urv.rvs(size=image_samples_per_stable_pose)
                    render_stop = time.time()
                    logging.info('Rendering images took %.3f sec' %(render_stop - render_start))

                    # tally total amount of data
                    num_grasps = len(candidate_grasps)
                    num_images = image_samples_per_stable_pose 
                    num_save = num_images * num_grasps
                    logging.info('Saving %d datapoints' %(num_save))

                    # for every candidate grasp on the object, compute the projection of the gripper pose into image space
                    for render_sample in render_samples:
                        # read the rendered images
                        binary_im = render_sample.renders[RenderMode.SEGMASK].image
                        depth_im_table = render_sample.renders[RenderMode.DEPTH_SCENE].image
                        # read the camera parameters
                        T_stp_camera = render_sample.camera.object_to_camera_pose
                        shifted_camera_intr = render_sample.camera.camera_intr

                        # read the pixel coordinates of the image center
                        cx = depth_im_table.center[1]
                        cy = depth_im_table.center[0]

                        # compute the intrinsics of the virtual camera for the final
                        # cropped and rescaled images
                        camera_intr_scale = float(im_final_height) / float(im_crop_height)
                        cropped_camera_intr = shifted_camera_intr.crop(im_crop_height, im_crop_width, cy, cx)
                        final_camera_intr = cropped_camera_intr.resize(camera_intr_scale)
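                        # What crop() and resize() do to the intrinsics, in standard
                        # pinhole bookkeeping (a sketch -- the perception library
                        # implements this, and pixel-center conventions may shift the
                        # result by half a pixel): cropping an im_crop_width x
                        # im_crop_height window centered at (cx, cy) moves the
                        # principal point,
                        #     cx' = intr.cx - cx + im_crop_width / 2
                        #     cy' = intr.cy - cy + im_crop_height / 2
                        # and resizing multiplies fx, fy, cx', cy' by camera_intr_scale.
                        # With a 96x96 crop resized to 32x32 (assumed config values),
                        # camera_intr_scale = 32/96 = 1/3.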

                        # generate a cropped training image for every grasp; candidate_grasp_info
                        # is the grasp/collision dictionary generated in Part 1
                        for grasp_info in candidate_grasp_info:
                            # read the grasp and its collision-free label
                            grasp = grasp_info.grasp
                            collision_free = grasp_info.collision_free
                            
                            # get the grasp pose
                            # T_obj_camera is the transform from the object frame to the camera frame
                            T_obj_camera = T_stp_camera * T_obj_stp.as_frames('obj', T_stp_camera.from_frame)
                            # project the grasp into camera space; the Grasp2D class lives in the gqcnn package and is analyzed later
                            grasp_2d = grasp.project_camera(T_obj_camera, shifted_camera_intr)

                            # translate and rotate the image so the grasp center lies at the
                            # image center and the grasp axis lies along the x axis
                            dx = cx - grasp_2d.center.x
                            dy = cy - grasp_2d.center.y
                            translation = np.array([dy, dx])

                            binary_im_tf = binary_im.transform(translation, grasp_2d.angle)
                            depth_im_tf_table = depth_im_table.transform(translation, grasp_2d.angle)

                            # crop the images to the target size
                            binary_im_tf = binary_im_tf.crop(im_crop_height, im_crop_width)
                            depth_im_tf_table = depth_im_tf_table.crop(im_crop_height, im_crop_width)

                            # resize the images to the final size
                            binary_im_tf = binary_im_tf.resize((im_final_height, im_final_width), interp='nearest')
                            depth_im_tf_table = depth_im_tf_table.resize((im_final_height, im_final_width))

                            # assemble the grasp pose vector
                            # np.r_ concatenates the components into a single 1-D array
                            hand_pose = np.r_[grasp_2d.center.y,                           # grasp center row (pixels)
                                              grasp_2d.center.x,                           # grasp center column (pixels)
                                              grasp_2d.depth,                              # grasp depth along the camera axis (meters)
                                              grasp_2d.angle,                              # grasp axis angle in the image plane (radians)
                                              grasp_2d.center.y - shifted_camera_intr.cy,  # row offset from the principal point
                                              grasp_2d.center.x - shifted_camera_intr.cx,  # column offset from the principal point
                                              grasp_2d.width_px]                           # gripper width (pixels)
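                            # Of these seven entries, the standard parallel-jaw GQ-CNN
                            # training setup (an assumption about the downstream config)
                            # feeds only the depth, hand_pose[2], to the network's pose
                            # stream; the rest are bookkeeping for reconstructing the
                            # grasp in image space.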
         

                            # store all the data in the datapoint buffer
                            tensor_datapoint['depth_ims_tf_table'] = depth_im_tf_table.raw_data
                            tensor_datapoint['obj_masks'] = binary_im_tf.raw_data
                            tensor_datapoint['hand_poses'] = hand_pose
                            tensor_datapoint['collision_free'] = collision_free
                            tensor_datapoint['obj_labels'] = cur_obj_label
                            tensor_datapoint['pose_labels'] = cur_pose_label
                            tensor_datapoint['image_labels'] = cur_image_label

                            for metric_name, metric_val in grasp_metrics[grasp.id].iteritems():
                                # zero out the metric for grasps in collision so they become negative examples
                                coll_free_metric = (1 * collision_free) * metric_val
                                tensor_datapoint[metric_name] = coll_free_metric
                            tensor_dataset.add(tensor_datapoint)

                        # update the image counter label
                        cur_image_label += 1

                    # update the stable pose counter label
                    cur_pose_label += 1

                    # explicit garbage collection
                    gc.collect()

            # update the object counter label
            cur_obj_label += 1

            # explicit garbage collection
            gc.collect()

    # flush the remaining data in memory to disk
    # note: np.savez_compressed differs from np.savez -- the former compresses the data, the latter does not
    tensor_dataset.flush()

    # save the object category and pose category maps, which link the stored tensors back to the original object names
    obj_cat_filename = os.path.join(output_dir, 'object_category_map.json')
    json.dump(obj_category_map, open(obj_cat_filename, 'w'))
    pose_cat_filename = os.path.join(output_dir, 'pose_category_map.json')
    json.dump(pose_category_map, open(pose_cat_filename, 'w'))
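
After the flush, the dataset lives on disk as chunked, compressed .npz tensors plus the two JSON maps above. A sketch of reading one chunk back; the tensors/ directory layout and zero-padded chunk indices follow TensorDataset's conventions and may differ by version:

    import numpy as np

    # one .npz file per field per chunk of datapoints (filenames are assumptions)
    depth_ims = np.load('tensors/depth_ims_tf_table_00000.npz')['arr_0']
    hand_poses = np.load('tensors/hand_poses_00000.npz')['arr_0']
    print(depth_ims.shape)   # e.g. (datapoints_per_file, 32, 32, 1)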
