在Android中实现airtest的特征点识别类

这是 Airtest 特征点识别的 Java 语言版本实现,其中的参数我尚未细调。

代码基于opencv 4.2, 编译参考:https://blog.csdn.net/enlangs/article/details/105344970

import org.opencv.calib3d.Calib3d;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.DMatch;
import org.opencv.core.Mat;
import org.opencv.core.MatOfDMatch;
import org.opencv.core.MatOfKeyPoint;
import org.opencv.core.MatOfPoint2f;
import org.opencv.core.Point;
import org.opencv.core.Rect;
import org.opencv.core.Size;
import org.opencv.features2d.BFMatcher;
import org.opencv.features2d.DescriptorMatcher;
import org.opencv.features2d.Feature2D;
import org.opencv.features2d.KAZE;
import org.opencv.imgproc.Imgproc;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import static org.opencv.core.Core.NORM_L1;

/**
 * 基于特征点的识别基类: KAZE.
 */
/**
 * Feature-point based image matching: a Java port of Airtest's keypoint matching base
 * class, using KAZE as the detector and a brute-force NORM_L1 matcher.
 *
 * <p>Locates the region of {@code im_search} (the template) inside {@code im_source}
 * (the screen capture) and, optionally, scores the match with a template-matching
 * confidence that is compared against {@code threshold}.
 */
public class KeypointMatching {
    private static final String TAG = KeypointMatching.class.getSimpleName();

    // Method name used in log output.
    public static final String METHOD_NAME = "KAZE";
    // FILTER_RATIO: Lowe's ratio-test threshold for keeping distinctive matches (0-1, suggested 0.4-0.6).
    private static final double FILTER_RATIO = 0.59;
    // Confidence assigned when only a single matching point pair exists. Kept for parity with
    // the Python reference; the single-pair branch currently returns null before using it.
    private static final double ONE_POINT_CONFI = 0.5;


    // Screen image that is searched inside.
    private Mat im_source;
    // Template image that is searched for.
    private Mat im_search;
    // Minimum confidence for find_best_result() to report a match.
    private double threshold = 0.8;
    // true: confidence is computed per BGR channel; false: on the grayscale images.
    private boolean rgb = true;
    protected Feature2D detector;
    protected DescriptorMatcher matcher;
    // When false, find_best_result() returns the raw region without computing confidence.
    public boolean needCalConfidence = false;


    public KeypointMatching(Mat im_source, Mat im_search) {
        // The delegating default uses rgb=false (grayscale confidence), matching Airtest.
        this(im_source, im_search, 0.8, false);
    }


    /**
     * @param im_source screen image to search inside
     * @param im_search template image to look for
     * @param threshold minimum acceptable confidence (0-1)
     * @param rgb       whether confidence is computed per BGR channel
     */
    public KeypointMatching(Mat im_source, Mat im_search, double threshold, boolean rgb) {
        this.im_source = im_source;
        this.im_search = im_search;
        this.threshold = threshold;
        this.rgb = rgb;
    }


    /**
     * Detects keypoints in both images, matches their descriptors and fills {@code good}
     * with the ratio-test survivors, de-duplicated on the source-image side.
     *
     * <p>Reference:
     * https://stackoverflow.com/questions/35428440/java-opencv-extracting-good-matches-from-knnmatch
     *
     * @param kp_sch output: keypoints of the template image
     * @param kp_src output: keypoints of the source image
     * @param good   output: filtered match list
     * @throws NoMatchPointError if either image yields fewer than two keypoints
     */
    private void get_key_points(MatOfKeyPoint kp_sch, MatOfKeyPoint kp_src, LinkedList<DMatch> good) throws NoMatchPointError {
        this.init_detector();

        Mat des_src = new Mat();
        Mat des_sch = new Mat();

        this.get_keypoints_and_descriptors(this.im_source, kp_src, des_src);
        this.get_keypoints_and_descriptors(this.im_search, kp_sch, des_sch);

        // knnMatch with k=2 requires at least k features in both the query and train image.
        if (kp_src.toList().size() < 2 || kp_sch.toList().size() < 2) {
            throw new NoMatchPointError("Not enough feature points in input images !");
        }

        // Match descriptors.
        List<MatOfDMatch> knnMatches = new LinkedList<>();
        this.match_keypoints(des_sch, des_src, knnMatches);

        // Lowe's ratio test: keep a match only when its best distance is clearly smaller than
        // the runner-up's, i.e. the feature is distinctive. (Not suitable for multi-target search.)
        good.clear();
        for (int i = 0; i < knnMatches.size(); i++) {
            if (knnMatches.get(i).rows() > 1) {
                DMatch[] matches = knnMatches.get(i).toArray();
                if (matches[0].distance < FILTER_RATIO * matches[1].distance) {
                    good.add(matches[0]);
                }
            }
        }

        // De-duplicate on the source side: several template points may legitimately map to one
        // source point, but one source point must not be counted multiple times — keep only the
        // first match for each distinct source-image point.
        List<DMatch> good_diff = new LinkedList<>();
        List<Point> diff_good_point = new ArrayList<>();
        for (DMatch m : good) {
            Point diff_point = kp_src.toList().get(m.trainIdx).pt;
            if (!diff_good_point.contains(diff_point)) {
                good_diff.add(m);
                diff_good_point.add(diff_point);
            }
        }

        good.clear();
        good.addAll(good_diff);
    }

    /**
     * Matches descriptor sets (descriptor matching step).
     *
     * @param des_sch descriptors of the template image (query)
     * @param des_src descriptors of the source image (train)
     * @param matches output: k-nearest-neighbour matches
     */
    protected void match_keypoints(Mat des_sch, Mat des_src, List<MatOfDMatch> matches) {
        // k=2: retrieve the two best candidate matches for every query descriptor,
        // as required by the ratio test in get_key_points().
        this.matcher.knnMatch(des_sch, des_src, matches, 2);
    }

    /**
     * Detects keypoints and computes their descriptors for one image.
     *
     * @param im_source    input image data
     * @param keypoints    output: detected keypoints
     * @param descriptors1 output: corresponding descriptors
     */
    protected void get_keypoints_and_descriptors(Mat im_source, MatOfKeyPoint keypoints, Mat descriptors1) {
        Mat mask = new Mat();
        this.detector.detectAndCompute(im_source, mask, keypoints, descriptors1);
    }

    /** Initializes the keypoint detector and matcher objects. */
    protected void init_detector() {
        this.detector = KAZE.create();
        // Brute-force matcher; KAZE descriptors are float, so NORM_L1/NORM_L2 apply
        // (NORM_HAMMING is for binary descriptors and is not usable here).
        this.matcher = new BFMatcher(NORM_L1);
    }


    /**
     * Runs the KAZE pipeline and returns the single best matching region.
     *
     * @return the recognition result, or null when no acceptable match is found
     * @throws MatchException when the detected region fails the sanity checks
     */
    public OriginResult find_best_result() throws MatchException {
        // Step 1: validate the input images.
        // BUGFIX: the original passed im_search twice, so im_source was never validated.
        if (!this.check_image_valid(this.im_source, this.im_search)) {
            return null;
        }

        // Step 2: detect keypoints and build the filtered match list (good, kp_sch, kp_src).
        MatOfKeyPoint kp_src = new MatOfKeyPoint();
        MatOfKeyPoint kp_sch = new MatOfKeyPoint();
        LinkedList<DMatch> good = new LinkedList<>();
        this.get_key_points(kp_sch, kp_src, good);

        OriginResult origin_result;
        // Step 3: derive the target region from the match pairs.
        if (good.size() <= 1) {
            // 0 pairs: no region at all; 1 pair: a region cannot be derived. Give up.
            return null;
        } else if (good.size() <= 3) {
            // 2 or 3 pairs: derive the region geometrically from the point pairs.
            if (good.size() == 2) {
                origin_result = this.handle_two_good_points(kp_sch, kp_src, good);
            } else {
                origin_result = this.handle_three_good_points(kp_sch, kp_src, good);
            }
            // Degenerate geometry (axis-aligned / coincident points) yields no result.
            if (origin_result == null) {
                return null;
            }
        } else {
            // >= 4 pairs: map the template corners into the source image via a homography.
            origin_result = this.many_good_pts(kp_sch, kp_src, good);
        }

        if (!needCalConfidence) {
            return origin_result;
        }

        // Step 4: compute the confidence of the detected region and apply the threshold.
        // Sanity check first: regions under 5 pixels or scaled beyond 5x are rejected (throws).
        this.target_error_check(origin_result);
        // Crop the actual detected region (x_min..x_max, y_min..y_max) and resize it to the
        // template size before scoring, as the Python reference does. BUGFIX: the original
        // cropped a template-sized rect at (x_min, y_min), which ignores scaling and can run
        // past the source image bounds, making submat() throw.
        Rect region = new Rect((int) origin_result.x_min, (int) origin_result.y_min,
                (int) (origin_result.x_max - origin_result.x_min),
                (int) (origin_result.y_max - origin_result.y_min));
        Mat target_img = this.im_source.submat(region);
        Mat resize_img = new Mat();
        Imgproc.resize(target_img, resize_img, new Size(origin_result.w, origin_result.h));
        origin_result.confidence = this.cal_confidence(resize_img);

        this.generate_result(origin_result, origin_result.confidence);

        System.out.println(origin_result);

        if (origin_result.confidence >= this.threshold) {
            return origin_result;
        } else {
            return null;
        }
    }

    // Placeholder kept for parity with the Python reference (would fill the result payload).
    private void generate_result(OriginResult origin_result, double confidence) {
    }

    /**
     * Computes the homography that maps template coordinates onto source coordinates.
     * See https://www.cnblogs.com/573177885qq/p/4789199.html
     *
     * @param pts_sch      template-image points
     * @param pts_src      source-image points
     * @param matches_mask output: RANSAC inlier mask (one row per input point)
     * @return the 3x3 homography matrix
     */
    private Mat find_homography(List<Point> pts_sch, List<Point> pts_src, Mat matches_mask) {
        // Convert the point lists into the MatOfPoint2f form findHomography expects.
        MatOfPoint2f pts1Mat_sch = new MatOfPoint2f();
        pts1Mat_sch.fromList(pts_sch);
        MatOfPoint2f pts2Mat_src = new MatOfPoint2f();
        pts2Mat_src.fromList(pts_src);

        // RANSAC with a 5px reprojection threshold, 2000 iterations, 0.995 confidence.
        return Calib3d.findHomography(pts1Mat_sch, pts2Mat_src, Calib3d.RANSAC, 5.0, matches_mask, 2000, 0.995);
    }

    /**
     * Handles the case of exactly two matched point pairs.
     *
     * @param kp_sch       template-image keypoints
     * @param kp_src       source-image keypoints
     * @param good_matches the match list (size 2)
     * @return the recognition result, or null for degenerate geometry
     */
    private OriginResult handle_two_good_points(MatOfKeyPoint kp_sch, MatOfKeyPoint kp_src, LinkedList<DMatch> good_matches) {
        Point pts_sch1 = kp_sch.toArray()[good_matches.get(0).queryIdx].pt;
        Point pts_sch2 = kp_sch.toArray()[good_matches.get(1).queryIdx].pt;
        Point pts_src1 = kp_src.toArray()[good_matches.get(0).trainIdx].pt;
        Point pts_src2 = kp_src.toArray()[good_matches.get(1).trainIdx].pt;
        return get_origin_result_with_two_points(pts_sch1, pts_sch2, pts_src1, pts_src2);
    }

    /**
     * Handles the case of exactly three matched point pairs by reducing them to two:
     * point 1 and the midpoint of points 2 and 3 (on both the template and source side).
     *
     * @param kp_sch       template-image keypoints
     * @param kp_src       source-image keypoints
     * @param good_matches the match list (size 3)
     * @return the recognition result, or null for degenerate geometry
     */
    private OriginResult handle_three_good_points(MatOfKeyPoint kp_sch, MatOfKeyPoint kp_src, LinkedList<DMatch> good_matches) {
        // Mind the kp_sch/kp_src and queryIdx/trainIdx pairing.
        Point pts_sch1 = kp_sch.toArray()[good_matches.get(0).queryIdx].pt;
        Point pts_sch2 = kp_sch.toArray()[good_matches.get(1).queryIdx].pt;
        Point pts_sch3 = kp_sch.toArray()[good_matches.get(2).queryIdx].pt;
        pts_sch2.x = (pts_sch2.x + pts_sch3.x) / 2;
        pts_sch2.y = (pts_sch2.y + pts_sch3.y) / 2;
        Point pts_src1 = kp_src.toArray()[good_matches.get(0).trainIdx].pt;
        Point pts_src2 = kp_src.toArray()[good_matches.get(1).trainIdx].pt;
        Point pts_src3 = kp_src.toArray()[good_matches.get(2).trainIdx].pt;
        pts_src2.x = (pts_src2.x + pts_src3.x) / 2;
        pts_src2.y = (pts_src2.y + pts_src3.y) / 2;
        return get_origin_result_with_two_points(pts_sch1, pts_sch2, pts_src1, pts_src2);
    }

    /**
     * Handles >= 4 matched pairs: estimates a homography, re-estimates it from the RANSAC
     * inliers only, and projects the template corners into the source image.
     *
     * @param kp_sch       template-image keypoints
     * @param kp_src       source-image keypoints
     * @param good_matches the match list (size >= 4)
     * @return the recognition result
     */
    private OriginResult many_good_pts(MatOfKeyPoint kp_sch, MatOfKeyPoint kp_src, LinkedList<DMatch> good_matches) {
        List<Point> pts_sch = new ArrayList<>();
        List<Point> pts_src = new ArrayList<>();
        for (DMatch good_match : good_matches) {
            pts_sch.add(kp_sch.toList().get(good_match.queryIdx).pt);
            pts_src.add(kp_src.toList().get(good_match.trainIdx).pt);
        }

        Mat matches_mask = new Mat();
        this.find_homography(pts_sch, pts_src, matches_mask);

        // Keep only the RANSAC inliers (assumes most of 'good' is correct, which the
        // ratio test is meant to guarantee).
        LinkedList<DMatch> selected = new LinkedList<>();
        for (int i = 0; i < good_matches.size(); i++) {
            if (matches_mask.get(i, 0)[0] != 0.0) {
                selected.add(good_matches.get(i));
            }
        }

        // Re-estimate a more precise homography M1 from the inliers alone.
        List<Point> sch_pts = new ArrayList<>();
        List<Point> img_pts = new ArrayList<>();
        for (DMatch dMatch : selected) {
            sch_pts.add(kp_sch.toList().get(dMatch.queryIdx).pt);
            img_pts.add(kp_src.toList().get(dMatch.trainIdx).pt);
        }
        Mat mask = new Mat();
        Mat M1 = this.find_homography(sch_pts, img_pts, mask);

        // Project the four template corners through M1 to find the target region's corners
        // in the source image.
        long w = this.im_search.cols(), h = this.im_search.rows();
        long w_s = this.im_source.cols(), h_s = this.im_source.rows();
        Mat objCorners = new Mat(4, 1, CvType.CV_32FC2);
        float[] objCornersData = new float[(int) (objCorners.total() * objCorners.channels())];
        objCorners.get(0, 0, objCornersData);
        // Corner order: top-left, bottom-left, bottom-right, top-right.
        objCornersData[0] = 0;
        objCornersData[1] = 0;
        objCornersData[2] = 0;
        objCornersData[3] = h - 1;
        objCornersData[4] = w - 1;
        objCornersData[5] = h - 1;
        objCornersData[6] = w - 1;
        objCornersData[7] = 0;
        objCorners.put(0, 0, objCornersData);
        Mat sceneCorners = new Mat();

        // https://www.cnblogs.com/573177885qq/p/4789199.html
        Core.perspectiveTransform(objCorners, sceneCorners, M1);
        float[] sceneCornersData = new float[(int) (sceneCorners.total() * sceneCorners.channels())];
        sceneCorners.get(0, 0, sceneCornersData);

        Point lt = new Point(sceneCornersData[0], sceneCornersData[1]), br = new Point(sceneCornersData[4], sceneCornersData[5]);

        // The corners may fall outside the source image, but (by linearity of the refined
        // homography) the centre point stays inside.
        Point middle_point = new Point((int) (lt.x + br.x) / 2.0, (int) (lt.y + br.y) / 2.0);

        // The projected rectangle may be flipped; normalise so "top-left" really is top-left.
        double x_min = (int) Math.min(lt.x, br.x), x_max = (int) Math.max(lt.x, br.x);
        double y_min = (int) Math.min(lt.y, br.y), y_max = (int) Math.max(lt.y, br.y);

        // Clamp the rectangle to the source image bounds ([0, w_s-1] x [0, h_s-1]).
        x_min = Math.max(x_min, 0);
        x_max = Math.max(x_max, 0);
        x_min = Math.min(x_min, w_s - 1);
        x_max = Math.min(x_max, w_s - 1);

        y_min = Math.max(y_min, 0);
        y_max = Math.max(y_max, 0);
        y_min = Math.min(y_min, h_s - 1);
        y_max = Math.min(y_max, h_s - 1);

        // Region corners in order TL, BL, BR, TR: (x_min,y_min)(x_min,y_max)(x_max,y_max)(x_max,y_min)
        return new OriginResult(middle_point, x_min, x_max, y_min, y_max, w, h);
    }


    /**
     * Builds the recognition result from two valid matched point pairs by estimating a
     * per-axis scale and expanding the mapped template centre into a rectangle.
     *
     * @param pts_sch1 template-image point 1
     * @param pts_sch2 template-image point 2
     * @param pts_src1 source-image point 1
     * @param pts_src2 source-image point 2
     * @return the recognition result, or null for degenerate geometry
     */
    private OriginResult get_origin_result_with_two_points(Point pts_sch1, Point pts_sch2, Point pts_src1, Point pts_src2) {
        // Midpoint of the two source points (coordinates in im_source).
        Point middle_point = new Point((int) (pts_src1.x + pts_src2.x) / 2.0, (int) (pts_src1.y + pts_src2.y) / 2.0);

        // Degenerate geometry: the scale computation below would divide by zero.
        if (pts_sch1.x == pts_sch2.x || pts_sch1.y == pts_sch2.y || pts_src1.x == pts_src2.x || pts_src1.y == pts_src2.y) {
            return null;
        }
        // Per-axis scale factors, used to expand the target region from the centre point.
        long width = this.im_search.cols(), height = this.im_search.rows();
        // BUGFIX: the source-image bounds must come from im_source (the original read im_search).
        long width_s = this.im_source.cols(), height_s = this.im_source.rows();

        double x_scale = Math.abs(1.0 * (pts_src2.x - pts_src1.x) / (pts_sch2.x - pts_sch1.x));
        // BUGFIX: y_scale was copy-pasted from x_scale and used the x deltas.
        double y_scale = Math.abs(1.0 * (pts_src2.y - pts_src1.y) / (pts_sch2.y - pts_sch1.y));
        // Correct middle_point: it must be the mapped template centre, not the keypoint midpoint.
        Point sch_middle_point = new Point((pts_sch1.x + pts_sch2.x) / 2, (pts_sch1.y + pts_sch2.y) / 2);
        middle_point.x = (int) middle_point.x - (sch_middle_point.x - width / 2.0) * x_scale;
        middle_point.y = (int) middle_point.y - (sch_middle_point.y - height / 2.0) * y_scale;
        // Clamp the centre inside the source image (top-left of the image is (0,0)).
        middle_point.x = Math.max(middle_point.x, 0);
        middle_point.x = Math.min(middle_point.x, width_s - 1);
        middle_point.y = Math.max(middle_point.y, 0);
        middle_point.y = Math.min(middle_point.y, height_s - 1);

        // Expand the centre into a rectangle clamped to the source bounds.
        // Corner order TL -> BL -> BR -> TR; rotation is not handled.
        double x_min = (int) Math.max(middle_point.x - (width * x_scale) / 2, 0);
        // BUGFIX: x_max must be clamped with min() against the right edge; the original
        // used max(), which forced x_max to at least width_s - 1.
        double x_max = (int) Math.min(middle_point.x + (width * x_scale) / 2, width_s - 1);
        double y_min = (int) Math.max(middle_point.y - (height * y_scale) / 2, 0);
        double y_max = (int) Math.min(middle_point.y + (height * y_scale) / 2, height_s - 1);

        return new OriginResult(middle_point, x_min, x_max, y_min, y_max, width, height);
    }

    /** Returns true when both images are non-null and non-empty. */
    private boolean check_image_valid(Mat source, Mat search) {
        return source != null && !source.empty() && search != null && !search.empty();
    }

    /**
     * Validates that the recognized region is physically plausible.
     *
     * @param origin_result the recognition result to check
     * @throws MatchResultCheckError when the region is under 5 pixels wide/tall, or its size
     *                               differs from the template by more than a factor of 5
     */
    private void target_error_check(OriginResult origin_result) throws MatchResultCheckError {
        double tar_width = origin_result.x_max - origin_result.x_min, tar_height = origin_result.y_max - origin_result.y_min;
        // A real screenshot region can hardly be smaller than 5 pixels in either dimension.
        if (tar_width < 5 || tar_height < 5) {
            throw new MatchResultCheckError("In src_image, Taget area: width or height < 5 pixel.");
        }
        // A 5x size difference between region and template cannot occur on real screens.
        if (tar_width < 0.2 * origin_result.w || tar_width > 5 * origin_result.w || tar_height < 0.2 * origin_result.h || tar_height > 5 * origin_result.h) {
            throw new MatchResultCheckError("Target area is 5 times bigger or 0.2 times smaller than sch_img.");
        }
    }

    /**
     * Computes the match confidence against the template.
     *
     * @param resize_img the detected region, resized to the template's dimensions
     * @return confidence in [0, 1]
     */
    private double cal_confidence(Mat resize_img) {
        double confidence;
        if (this.rgb) {
            confidence = cal_rgb_confidence(this.im_search, resize_img);
        } else {
            confidence = cal_ccoeff_confidence(this.im_search, resize_img);
        }
        // Map the correlation score from [-1, 1] into [0, 1].
        confidence = (1 + confidence) / 2;
        return confidence;
    }

    /**
     * Computes the similarity of two equally sized colour images as a perceptually
     * weighted sum of per-channel template-matching scores.
     *
     * @return confidence
     */
    private double cal_rgb_confidence(Mat img_src_rgb, Mat img_sch_rgb) {
        // Perceptual luminance weights for the B, G, R channels.
        double[] weight = {0.114, 0.587, 0.299};

        List<Mat> src_bgr = new ArrayList<>();
        Core.split(img_src_rgb, src_bgr);

        List<Mat> sch_bgr = new ArrayList<>();
        Core.split(img_sch_rgb, sch_bgr);

        // Per-channel confidence scores.
        double[] bgr_confidence = {0, 0, 0};
        for (int i = 0; i < 3; i++) {
            Mat dst = new Mat();
            // BUGFIX: TM_CCORR is unnormalised and unbounded, so the result could never be
            // compared against a 0-1 threshold; the Airtest reference uses TM_CCOEFF_NORMED,
            // which yields a score in [-1, 1] per channel.
            Imgproc.matchTemplate(src_bgr.get(i), sch_bgr.get(i), dst, Imgproc.TM_CCOEFF_NORMED);
            bgr_confidence[i] = Core.minMaxLoc(dst).maxVal;
        }
        // Weighted confidence across the three channels.
        return bgr_confidence[0] * weight[0] + bgr_confidence[1] * weight[1] + bgr_confidence[2] * weight[2];
    }

    /**
     * Computes the similarity of two images in grayscale using TM_CCOEFF_NORMED.
     *
     * @param img_src_rgb source image
     * @param img_sch_rgb template image
     * @return confidence in [-1, 1]
     */
    private double cal_ccoeff_confidence(Mat img_src_rgb, Mat img_sch_rgb) {
        Mat im_source_gray = new Mat(), im_search_gray = new Mat(), dst = new Mat();
        img_mat_rgb_2_gray(img_src_rgb, im_source_gray);
        img_mat_rgb_2_gray(img_sch_rgb, im_search_gray);
        Imgproc.matchTemplate(im_source_gray, im_search_gray, dst, Imgproc.TM_CCOEFF_NORMED);
        return Core.minMaxLoc(dst).maxVal;
    }

    /**
     * Converts a BGR image to grayscale.
     *
     * @param img_mat input colour image (BGR channel order)
     * @param dst     output grayscale image
     */
    private void img_mat_rgb_2_gray(Mat img_mat, Mat dst) {
        Imgproc.cvtColor(img_mat, dst, Imgproc.COLOR_BGR2GRAY);
    }
}

参考:
airtest源码:https://github.com/AirtestProject/Airtest/blob/41a17899d26fc622d9374a70fcec35fdd2bab6c7/airtest/aircv/keypoint_base.py#L30
opencv java demo:https://github.com/opencv/opencv/blob/master/samples/java/tutorial_code/features2D/feature_homography/SURFFLANNMatchingHomographyDemo.java

你可能感兴趣的:(在Android中实现airtest的特征点识别类)