cam_to_ir_label.py关键代码讲解

def _work(process_id, infer_dataset, args):
    visualize_intermediate_cam = False
    databin = infer_dataset[process_id]
    infer_data_loader = DataLoader(databin, shuffle=False, num_workers=0, pin_memory=False)

    for iter, pack in enumerate(infer_data_loader):
        img_name = voc12.dataloader.decode_int_filename(pack['name'][0])
        img = pack['img'][0].numpy()
        cam_dict = np.load(os.path.join(args.cam_out_dir, img_name + '.npy'), allow_pickle=True).item()

        cams = cam_dict['high_res']
        keys = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant')  # 从1开始编号,将0保留给背景类别。

        # 1. find confident fg & bg
        fg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_fg_thres)
        # 在 cams 数组的第一个维度(高度)前面添加1个元素,其余维度不添加。这是为了在前景区域周围添加一个边界,以便后续处理。
        # args.conf_fg_thres定义前景区域的置信度阈值。
        # 在 CAM 的周围添加一个边界,并使用阈值来填充边界区域,使得前景区域的置信度超过阈值,而背景区域低于阈值。
        fg_conf_cam = np.argmax(fg_conf_cam, axis=0)
        # 找到每个像素位置在不同类别中哪个类别的置信度最高,从而得到一个二维数组,表示在前景区域中哪个类别具有最高的置信度。


        pred = imutils.crf_inference_label(img, fg_conf_cam, n_labels=keys.shape[0])
        # 通过 CRF 推理对前景区域的置信度图像进行细化,得到更精确的图像标签,以便更准确地预测前景类别。

        fg_conf = keys[pred]  # 将每个像素位置在 keys 数组中查找对应的预测类别标识符,从而得到一个表示前景区域中每个像素预测类别的数组
        bg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_bg_thres)
        bg_conf_cam = np.argmax(bg_conf_cam, axis=0)
        pred = imutils.crf_inference_label(img, bg_conf_cam, n_labels=keys.shape[0])
        bg_conf = keys[pred]

        # 2. combine confident fg & bg
        conf = fg_conf.copy()
        conf[fg_conf == 0] = 255  # 数组中前景类别为背景的位置的值设置为 255,以进行标记。
        conf[bg_conf + fg_conf == 0] = 0  # 数组中前景和背景类别都不存在的位置的值设置为 0

        imageio.imwrite(os.path.join(args.ir_label_out_dir, img_name + '.png'), conf.astype(np.uint8))


        if process_id == args.num_workers - 1 and iter % (len(databin) // 20) == 0:
            print("%d " % ((5 * iter + 1) // (len(databin) // 20)), end='')

你可能感兴趣的:(人工智能,语义分割,图像处理,pytorch,python)