The code that generates the dataset consumed directly by GQ-CNN lives in the generate_gqcnn_dataset.py file under the tools directory of the Dex-Net module. Its two most important steps are: (1) precomputing a set of valid grasps for every stable pose of every object, and (2) rendering an image dataset for every grasp and associating the gripper pose with image coordinates.
Below is that code, annotated with comments to aid understanding.
# 1. Precompute a set of valid grasps for every stable pose of every object
# Constraint 1: perpendicular to the table
# Constraint 2: no collisions along the approach direction
# load an existing grasp cache if there is one
grasp_cache_filename = os.path.join(output_dir, CACHE_FILENAME)
if os.path.exists(grasp_cache_filename):
    logging.info('Loading grasp candidates from file')
    # pkl is cPickle, used to save and load Python objects
    candidate_grasps_dict = pkl.load(open(grasp_cache_filename, 'rb'))
# if no cache exists, re-read and compute the grasp data and apply the constraints
else:
    # dictionary that stores the grasps
    candidate_grasps_dict = {}
    # iterate over every dataset and every object
    for dataset in datasets:
        logging.info('Reading dataset %s' %(dataset.name))
        for obj in dataset:
            # skip objects that are not targeted for computation
            if obj.key not in target_object_keys[dataset.name]:
                continue
            # initialize the candidate grasp store for this object as a dictionary
            candidate_grasps_dict[obj.key] = {}
            # set up the collision checker; GraspCollisionChecker is defined in Dex-Net
            collision_checker = GraspCollisionChecker(gripper)
            collision_checker.set_graspable_object(obj)
            # read all stable poses of the object mesh
            stable_poses = dataset.stable_poses(obj.key)
            for i, stable_pose in enumerate(stable_poses):
                # only process the stable pose if it is valid (probable enough)
                if stable_pose.p > stable_pose_min_p:
                    candidate_grasps_dict[obj.key][stable_pose.id] = []
                    # set up the table in the collision checker
                    T_obj_stp = stable_pose.T_obj_table.as_frames('obj', 'stp')
                    T_obj_table = obj.mesh.get_T_surface_obj(T_obj_stp, delta=table_offset).as_frames('obj', 'table')
                    T_table_obj = T_obj_table.inverse()
                    collision_checker.set_table(table_mesh_filename, T_table_obj)
                    # read all stored grasp poses for this object and gripper
                    grasps = dataset.grasps(obj.key, gripper=gripper.name)
                    logging.info('Aligning %d grasps for object %s in stable %s' %(len(grasps), obj.key, stable_pose.id))
                    # align all grasp poses with the stable pose
                    aligned_grasps = [grasp.perpendicular_table(stable_pose) for grasp in grasps]
                    # check grasp validity
                    logging.info('Checking collisions for %d grasps for object %s in stable %s' %(len(grasps), obj.key, stable_pose.id))
                    for aligned_grasp in aligned_grasps:
                        # check the angle between the grasp approach axis and the table
                        # normal, and skip unaligned grasps (angle above the maximum)
                        _, grasp_approach_table_angle, _ = aligned_grasp.grasp_angles_from_stp_z(stable_pose)
                        perpendicular_table = (np.abs(grasp_approach_table_angle) < max_grasp_approach_table_angle)
                        if not perpendicular_table:
                            continue
                        # check each approach path for collisions
                        collision_free = False
                        # for each approach offset angle (only 10 degrees in this config):
                        # rotate the original grasp by the offset and re-check for collisions
                        for phi_offset in phi_offsets:
                            rotated_grasp = aligned_grasp.grasp_y_axis_offset(phi_offset)
                            collides = collision_checker.collides_along_approach(rotated_grasp, approach_dist, delta_approach)
                            # I don't quite understand this: why is a single collision-free
                            # offset angle enough to declare the grasp collision-free??
                            # (see the note after this loop)
                            if not collides:
                                collision_free = True
                                break
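                        # Note (my reading; the code does not explain it): phi_offsets enumerates
                        # alternative approach directions for the same grasp axis, so the grasp is
                        # physically reachable as long as ANY one offset is collision-free; hence a
                        # single non-colliding offset suffices to set collision_free. Also note the
                        # rotated_grasp that passed the check is discarded; only the original
                        # aligned_grasp is stored below.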
                        # store the resulting grasp pose in a buffer data structure
                        candidate_grasps_dict[obj.key][stable_pose.id].append(GraspInfo(aligned_grasp, collision_free))
                        # visualize the grasp pose if visualization is enabled
                        if collision_free and config['vis']['candidate_grasps']:
                            logging.info('Grasp %d' %(aligned_grasp.id))
                            vis.figure()
                            vis.gripper_on_object(gripper, aligned_grasp, obj, stable_pose.T_obj_world)
                            vis.show()
    # save the buffered data to the cache file
    logging.info('Saving to file')
    pkl.dump(candidate_grasps_dict, open(grasp_cache_filename, 'wb'))
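# An aside on the cache layout written above (inferred from the code; the keys
# shown are hypothetical):
#   candidate_grasps_dict['mug_01']['pose_3'] -> [GraspInfo(grasp=..., collision_free=True), ...]
# i.e. object key -> stable pose id -> list of GraspInfo records, which is
# exactly what step 2 reads back below.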
# 2. For every grasp pose in the dataset, render an image dataset and associate
#    the gripper pose with image coordinates
# set up variables
# object category map
obj_category_map = {}
# pose category map
pose_category_map = {}
cur_pose_label = 0
cur_obj_label = 0
cur_image_label = 0
# render images for every stable pose of every object in the dataset
render_modes = [RenderMode.SEGMASK, RenderMode.DEPTH_SCENE]
for dataset in datasets:
    logging.info('Generating data for dataset %s' %(dataset.name))
    # iterate over every object
    object_keys = dataset.object_keys
    for obj_key in object_keys:
        obj = dataset[obj_key]
        if obj.key not in target_object_keys[dataset.name]:
            continue
        # read in the stable poses of the object mesh
        stable_poses = dataset.stable_poses(obj.key)
        for i, stable_pose in enumerate(stable_poses):
            # render images only if the stable pose is valid
            if stable_pose.p > stable_pose_min_p:
                # log progress
                logging.info('Rendering images for object %s in %s' %(obj.key, stable_pose.id))
                # add to category maps
                if obj.key not in obj_category_map.keys():
                    obj_category_map[obj.key] = cur_obj_label
                pose_category_map['%s_%s' %(obj.key, stable_pose.id)] = cur_pose_label
                # read in the candidate grasps and metrics; candidate_grasps_dict is
                # the candidate grasp dictionary generated in step 1
                candidate_grasp_info = candidate_grasps_dict[obj.key][stable_pose.id]
                candidate_grasps = [g.grasp for g in candidate_grasp_info]
                # the grasp_metrics method directly produces the metric values of
                # every grasp pose; this is the key step
                grasp_metrics = dataset.grasp_metrics(obj.key, candidate_grasps, gripper=gripper.name)
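                # Judging from how it is consumed below (grasp_metrics[grasp.id].iteritems()),
                # the returned structure is a dict keyed by grasp id that maps to a
                # {metric_name: value} dict, e.g. (hypothetical metric name and value):
                #   grasp_metrics[0] -> {'robust_ferrari_canny': 0.002, ...}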
                # compute the object pose relative to the table
                T_obj_stp = stable_pose.T_obj_table.as_frames('obj', 'stp')
                T_obj_stp = obj.mesh.get_T_surface_obj(T_obj_stp)
                # sample images using a random variable
                T_table_obj = RigidTransform(from_frame='table', to_frame='obj')
                # the SceneObject class simply bundles the two parameters, mesh and pose, together
                scene_objs = {'table': SceneObject(table_mesh, T_table_obj)}
                # this class is defined in the meshpy module and is explained in detail later;
                # env_rv_params holds the camera/environment sampling parameters
                urv = UniformPlanarWorksurfaceImageRandomVariable(obj.mesh,
                                                                  render_modes,
                                                                  'camera',
                                                                  env_rv_params,
                                                                  stable_pose=stable_pose,
                                                                  scene_objs=scene_objs)
                render_start = time.time()
                # urv.rvs is defined in the parent class RandomVariable in the autolab_core
                # package; internally it just calls the subclass's sample method, which is
                # analyzed in detail later
                # size is the number of samples to draw
                render_samples = urv.rvs(size=image_samples_per_stable_pose)
                render_stop = time.time()
                logging.info('Rendering images took %.3f sec' %(render_stop - render_start))
                # tally the total amount of data
                num_grasps = len(candidate_grasps)
                num_images = image_samples_per_stable_pose
                num_save = num_images * num_grasps
                logging.info('Saving %d datapoints' %(num_save))
                # for every candidate grasp on the object, compute the projection
                # of the gripper into image space
                for render_sample in render_samples:
                    # read the rendered images
                    binary_im = render_sample.renders[RenderMode.SEGMASK].image
                    depth_im_table = render_sample.renders[RenderMode.DEPTH_SCENE].image
                    # read the camera parameters
                    T_stp_camera = render_sample.camera.object_to_camera_pose
                    shifted_camera_intr = render_sample.camera.camera_intr
                    # read the image center in pixel coordinates
                    cx = depth_im_table.center[1]
                    cy = depth_im_table.center[0]
                    # compute intrinsics for the virtual camera of the final
                    # cropped and rescaled images
                    camera_intr_scale = float(im_final_height) / float(im_crop_height)
                    cropped_camera_intr = shifted_camera_intr.crop(im_crop_height, im_crop_width, cy, cx)
                    final_camera_intr = cropped_camera_intr.resize(camera_intr_scale)
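                    # Worked example (the concrete numbers are assumptions about the config):
                    # with im_crop_height = 96 and im_final_height = 32,
                    # camera_intr_scale = 32/96 = 1/3: the intrinsics are first re-centered
                    # on the crop window and then scaled by 1/3 so they stay consistent
                    # with the 32x32 rescaled image.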
                    # generate a thumbnail image for every grasp; candidate_grasp_info is
                    # the grasp/collision dictionary generated in step 1
                    for grasp_info in candidate_grasp_info:
                        # read the grasp and its collision flag
                        grasp = grasp_info.grasp
                        collision_free = grasp_info.collision_free
                        # get the grasp pose
                        # T_obj_camera is the transform from the object frame to the camera frame
                        T_obj_camera = T_stp_camera * T_obj_stp.as_frames('obj', T_stp_camera.from_frame)
                        # project the grasp into camera space; the Grasp2D class lives in
                        # the gqcnn package and is analyzed later
                        grasp_2d = grasp.project_camera(T_obj_camera, shifted_camera_intr)
                        # translate and rotate the images so they are centered on the grasp
                        # center and the grasp axis lies along the image x axis;
                        # the translation is in (row, column) order, hence [dy, dx]
                        dx = cx - grasp_2d.center.x
                        dy = cy - grasp_2d.center.y
                        translation = np.array([dy, dx])
                        binary_im_tf = binary_im.transform(translation, grasp_2d.angle)
                        depth_im_tf_table = depth_im_table.transform(translation, grasp_2d.angle)
                        # crop the images to the target size
                        binary_im_tf = binary_im_tf.crop(im_crop_height, im_crop_width)
                        depth_im_tf_table = depth_im_tf_table.crop(im_crop_height, im_crop_width)
                        # rescale the images to the final target size
                        binary_im_tf = binary_im_tf.resize((im_final_height, im_final_width), interp='nearest')
                        depth_im_tf_table = depth_im_tf_table.resize((im_final_height, im_final_width))
                        # assemble the hand pose vector;
                        # np.r_ concatenates its arguments along the first axis
                        hand_pose = np.r_[grasp_2d.center.y,
                                          grasp_2d.center.x,
                                          grasp_2d.depth,
                                          grasp_2d.angle,
                                          grasp_2d.center.y - shifted_camera_intr.cy,
                                          grasp_2d.center.x - shifted_camera_intr.cx,
                                          grasp_2d.width_px]
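                        # The seven entries of hand_pose, in order (my annotation, inferred
                        # from the expressions above): grasp center row (v) and column (u),
                        # grasp depth from the camera, in-plane rotation angle, row/column
                        # offsets of the center from the camera principal point, and the
                        # gripper opening width in pixels. np.r_ simply joins the scalars
                        # into one 1-D array, e.g. np.r_[1.0, 2.0, 3.0] -> array([1., 2., 3.]).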
                        # store all the data into the buffer
                        tensor_datapoint['depth_ims_tf_table'] = depth_im_tf_table.raw_data
                        tensor_datapoint['obj_masks'] = binary_im_tf.raw_data
                        tensor_datapoint['hand_poses'] = hand_pose
                        tensor_datapoint['collision_free'] = collision_free
                        tensor_datapoint['obj_labels'] = cur_obj_label
                        tensor_datapoint['pose_labels'] = cur_pose_label
                        tensor_datapoint['image_labels'] = cur_image_label
                        # .iteritems() is Python 2; under Python 3 this would be .items()
                        for metric_name, metric_val in grasp_metrics[grasp.id].iteritems():
                            # zero out the metric for colliding grasps: (1 * collision_free) is 1 or 0
                            coll_free_metric = (1 * collision_free) * metric_val
                            tensor_datapoint[metric_name] = coll_free_metric
                        tensor_dataset.add(tensor_datapoint)
                    # update the image counter label
                    cur_image_label += 1
                # update the stable pose counter label
                cur_pose_label += 1
                # force explicit garbage collection
                gc.collect()
        # update the object counter label
        cur_obj_label += 1
        # force explicit garbage collection
        gc.collect()
# flush the remaining in-memory data to disk
# note that np.savez_compressed differs from np.savez: the former compresses
# the arrays while the latter does not
tensor_dataset.flush()
# save the object category and pose category map files, which relate the stored
# final data back to the original object names
obj_cat_filename = os.path.join(output_dir, 'object_category_map.json')
json.dump(obj_category_map, open(obj_cat_filename, 'w'))
pose_cat_filename = os.path.join(output_dir, 'pose_category_map.json')
json.dump(pose_category_map, open(pose_cat_filename, 'w'))
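To sanity-check the generated dataset, the saved tensors can be read back directly with numpy. The sketch below is my own and makes assumptions about the on-disk layout (a tensors/ subdirectory with chunked files such as depth_ims_tf_table_00000.npz, each storing its array under numpy's default key 'arr_0'); the output path is hypothetical.

import json
import os

import numpy as np

output_dir = '/path/to/generated/dataset'  # hypothetical location

# each field is written as a series of compressed .npz chunks, whose single
# array lives under the default key 'arr_0'
depth_ims = np.load(os.path.join(output_dir, 'tensors', 'depth_ims_tf_table_00000.npz'))['arr_0']
hand_poses = np.load(os.path.join(output_dir, 'tensors', 'hand_poses_00000.npz'))['arr_0']
print(depth_ims.shape)   # e.g. (datapoints_per_file, 32, 32, 1)
print(hand_poses.shape)  # (datapoints_per_file, 7): the vector assembled above

# the category maps saved at the end of the script recover the original object names
with open(os.path.join(output_dir, 'object_category_map.json')) as f:
    obj_category_map = json.load(f)
print(obj_category_map)  # {object key: integer obj_label}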