# -*- coding: utf-8 -*-
# Homography数据集制作 (Homography dataset generation)

# coding=utf-8

import cv2
from PIL import Image
from pylab import *
import numpy as np
import os

# Input/output locations and run settings. Each can be overridden through an
# environment variable so the script can be reused without editing the source;
# the defaults are the original hard-coded values.
RAW_IMAGE_PATH = os.environ.get("HOMOGRAPHY_RAW_IMAGE_PATH", "/home/wgp/视频/DJI_0071/")  # directory of input frames
SAVE_FILE_PATH = os.environ.get("HOMOGRAPHY_SAVE_FILE_PATH", "/home/wgp/视频/ha/")  # directory receiving the annotation files
GT_FILE_NAME = "gt.txt"  # per-pair corner offsets (destination - source)
PTS_FILE_NAME = "pts1.txt"  # per-pair source point coordinates
FILE_LIST_NAME = "filelist.txt"  # per-pair image-name list
MODE = os.environ.get("HOMOGRAPHY_MODE", "train")  # "train" selects the train_ files; any other value selects test_
INTERVAL = int(os.environ.get("HOMOGRAPHY_INTERVAL", "3"))  # gap (in sorted-list positions) between the two frames of a pair


def draw_points(image1, image2):
    """Interactively collect four clicked points on each of two images.

    Each image is displayed with pylab's ``imshow`` and the user clicks four
    points on it through ``ginput``; the clicked coordinates are echoed.

    Args:
        image1: path of the left (source) image.
        image2: path of the right (destination) image.

    Returns:
        A pair ``(x1, x2)``: the point lists returned by ``ginput`` for the
        left and right image respectively.

    Raises:
        ValueError: if fewer than four points were collected for an image
            (for example when the user ends the input early).
    """
    x1 = _pick_four_points(image1, '左图:请选择四个点点击-->')
    x2 = _pick_four_points(image2, '右图:请选择四个点点击-->')
    return x1, x2


def _pick_four_points(image_path, prompt):
    """Show one image, let the user click four points on it and return them."""
    imshow(array(Image.open(image_path)))
    # Single-argument print() produces the same output on Python 2 and 3.
    print(prompt)
    # timeout=0 waits indefinitely; with ginput's default timeout a slow user
    # would silently get fewer than four points, which only fails much later
    # inside the homography computation.
    points = ginput(4, timeout=0)
    print('选中四点坐标为: %s' % (points,))
    if len(points) != 4:
        raise ValueError('expected 4 points on %s, got %d' % (image_path, len(points)))
    return points


def compute_homography(sourcePoints, destinationPoints):
    """Return the perspective transform mapping sourcePoints onto destinationPoints.

    Both point sets are converted to float32 arrays and passed to
    ``cv2.getPerspectiveTransform`` (which, per OpenCV, expects exactly four
    point correspondences).
    """
    src = np.float32(sourcePoints)
    dst = np.float32(destinationPoints)
    return cv2.getPerspectiveTransform(src, dst)


def compute_save_ground_truth(sourcePoints, destinationPoints, pts1_file, gt_file, filenames_file, index):
    """Append one image pair's annotations to the three output files.

    Each call appends exactly one line to each file.

    Args:
        sourcePoints: points picked on the left image (sequence of (x, y)).
        destinationPoints: corresponding points on the right image, same shape.
        pts1_file: path of the file receiving the flattened source points
            (float32, space separated, written with ``np.savetxt``).
        gt_file: path of the file receiving the flattened offsets
            ``destinationPoints - sourcePoints`` (float32, space separated).
        filenames_file: path of the file-list file; receives
            ``"<index>.jpg <index>.jpg"``.
        index: value used to build the file-list entry.
    """
    pts1 = np.array(sourcePoints).flatten().astype(np.float32)
    gt = np.subtract(np.array(destinationPoints), np.array(sourcePoints)).flatten().astype(np.float32)

    name = str(index) + '.jpg'
    # NOTE(review): both columns carry the same name, presumably because the
    # left and right images are stored in separate directories under identical
    # names -- confirm against whatever consumes this list.
    line = '%s %s\n' % (name, name)

    # Context managers guarantee each handle is flushed and closed (the
    # original never closed them). The file list is opened in binary append
    # mode, so the line is written as encoded bytes; writing a str there
    # raises TypeError on Python 3.
    with open(gt_file, 'ab') as f_gt:
        np.savetxt(f_gt, [gt], delimiter=' ')
    with open(pts1_file, 'ab') as f_pts1:
        np.savetxt(f_pts1, [pts1], delimiter=' ')
    with open(filenames_file, 'ab') as f_file_list:
        f_file_list.write(line.encode('utf-8'))


def get_all_images(file_path):
    """Return the entries of ``file_path`` sorted by their numeric file stem.

    Entries are first filtered through ``remove_hidden_file``. Every
    remaining name must have an integer stem (e.g. ``12.jpg``); otherwise
    ``int`` raises ValueError.
    """
    raw_image_list = remove_hidden_file(os.listdir(file_path))
    # splitext strips an extension of any length (.jpg, .jpeg, .png, ...);
    # slicing off the last four characters only worked for 3-letter ones.
    raw_image_list.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return raw_image_list


def remove_hidden_file(directory_list):
    """Remove '.', '..' and '.DS_Store' from ``directory_list`` in place.

    Only the first occurrence of each name is removed. The same list object
    is returned so the call can be chained.
    """
    # Previously '..' triggered removal of '.', which raised ValueError
    # whenever '.' was not (or no longer) in the list.
    for unwanted in ('.', '..', '.DS_Store'):
        if unwanted in directory_list:
            directory_list.remove(unwanted)
    return directory_list


def gen_data(interval):
    """Interactively annotate frame pairs and append the results to disk.

    The numerically sorted images of RAW_IMAGE_PATH are paired as
    ``(i, i + interval)``. For each pair the user picks points on both frames.
    The resulting homography is printed, and the points, offsets and
    file-list entry are appended to the ``train_*`` files (MODE == 'train')
    or the ``test_*`` files (any other MODE) under SAVE_FILE_PATH.

    Args:
        interval: distance, in sorted-list positions, between the two
            frames of a pair.
    """
    if not os.path.exists(SAVE_FILE_PATH):
        os.makedirs(SAVE_FILE_PATH)

    prefix = "train_" if MODE == 'train' else "test_"
    f_pts1 = os.path.join(SAVE_FILE_PATH, prefix + PTS_FILE_NAME)
    f_gt = os.path.join(SAVE_FILE_PATH, prefix + GT_FILE_NAME)
    f_file_list = os.path.join(SAVE_FILE_PATH, prefix + FILE_LIST_NAME)

    raw_image_list = get_all_images(RAW_IMAGE_PATH)
    for img_index in range(len(raw_image_list) - interval):
        # os.path.join works whether or not RAW_IMAGE_PATH ends with a
        # separator; plain concatenation silently broke without one.
        left_image = os.path.join(RAW_IMAGE_PATH, raw_image_list[img_index])
        right_image = os.path.join(RAW_IMAGE_PATH, raw_image_list[img_index + interval])

        sourcePoints, destinationPoints = draw_points(left_image, right_image)

        H = compute_homography(sourcePoints, destinationPoints)
        # One pre-formatted argument prints the same text on Python 2 and 3
        # (the two-argument form printed a tuple on Python 2).
        print("单应矩阵: %s" % (H,))

        # NOTE(review): the file-list entry is derived from img_index (the
        # position in the sorted list), not from the actual file name --
        # confirm this matches how the images are exported.
        compute_save_ground_truth(sourcePoints, destinationPoints, f_pts1, f_gt, f_file_list, img_index)


if __name__ == '__main__':
    # Script entry point: run the interactive annotation loop using the
    # configured frame gap.
    gen_data(interval=INTERVAL)

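# 你可能感兴趣的:(工具代码)  ("You may also be interested in: (utility code)" -- apparently a leftover tag from the page this script was copied from)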