SVM + HOG 图片分类（Java 版）

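参照 http://blog.csdn.net/m_wbcg/article/details/75092947 ，用 OpenCV 的 Java 接口做了一个图像二分类的小测试，在此记录一下。训练列表和测试列表都是 UTF-8 文本文件，每行格式为“图片路径<Tab>标签”（预测时标签列不会被使用）。每张图片先转为灰度图并缩放到 64×64，再提取 HOG 特征（1764 维），最后用线性核 SVM 进行训练和预测。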

/**
 * Binary image classifier built from HOG features and a linear-kernel OpenCV SVM.
 *
 * <p>Sample list files are read as UTF-8 text with one sample per line: an image path, a tab,
 * and (for training) a numeric class label. Blank lines are skipped. Every image goes through
 * the same preprocessing ({@link #computeHog}) so training and prediction see identical
 * features.
 *
 * <p>Requires the OpenCV Java bindings and their native library to be loadable via
 * {@code System.loadLibrary(Core.NATIVE_LIBRARY_NAME)}.
 */
public class Svm_train {
    /** Default training list: one {@code <image path>\t<label>} entry per line. */
    private static final String TRAIN_LIST = "D:/mnist_data/traindata.txt";
    /** Default list of images to classify: image path in the first tab-separated column. */
    private static final String TEST_LIST = "D:/mnist_data/traindata_cc.txt";
    /** Default location the trained model is saved to and loaded from. */
    private static final String MODEL_PATH = "D:/mnist_data/svm_java";
    /** Upper bound on SVM optimiser iterations. */
    private static final int ITERATION_NUM = 10000;
    /** Side length every image is resized to; also the HOG detection window size. */
    private static final int IMAGE_SIZE = 64;

    /** Trains on the default training list and saves the model to the default model path. */
    public void svm_train() {
        svm_train(TRAIN_LIST, MODEL_PATH);
    }

    /**
     * Trains a C-SVC with a linear kernel on the HOG features of every image in a list file and
     * saves the resulting model.
     *
     * <p>Any failure (unreadable list, missing image, missing label, OpenCV error) aborts
     * training and is reported on stderr; nothing is thrown to the caller.
     *
     * @param trainList UTF-8 list file with {@code <image path>\t<label>} per line
     * @param modelPath file the trained model is saved to
     */
    public void svm_train(String trainList, String modelPath) {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        try {
            ArrayList<String[]> samples = readSampleList(trainList);
            if (samples.isEmpty()) {
                throw new IllegalArgumentException("no samples in " + trainList);
            }
            HOGDescriptor hog = newHog();
            // Derived from the HOG parameters instead of hard-coding 1764.
            int featureDim = (int) hog.getDescriptorSize();

            // One HOG vector per row (CV_32F) and one integer class id per row (CV_32S),
            // the layout SVM.train expects with Ml.ROW_SAMPLE for classification.
            Mat dataMat = new Mat(samples.size(), featureDim, CvType.CV_32FC1);
            Mat labelMat = new Mat(samples.size(), 1, CvType.CV_32SC1);
            for (int row = 0; row < samples.size(); row++) {
                String[] fields = samples.get(row);
                if (fields.length < 2) {
                    throw new IllegalArgumentException("missing label for " + fields[0]);
                }
                System.out.println(fields[0] + "\tprocess");
                dataMat.put(row, 0, computeHog(hog, fields[0]));
                labelMat.put(row, 0, Float.parseFloat(fields[1]));
            }

            SVM svm = SVM.create();
            svm.setType(SVM.C_SVC);
            svm.setKernel(SVM.LINEAR);
            svm.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, ITERATION_NUM, 1e-6));
            svm.train(dataMat, Ml.ROW_SAMPLE, labelMat);

            svm.save(modelPath);
        } catch (Exception e) {
            System.err.println("read err:" + e);
        }
    }

    /** Classifies the default test list with the model stored at the default model path. */
    public void svm_predict() {
        svm_predict(TEST_LIST, MODEL_PATH);
    }

    /**
     * Classifies every image listed in a file with a previously saved model and prints
     * {@code <image path>\t<predicted label>} to stdout for each one.
     *
     * <p>Any failure (model cannot be loaded, unreadable list, missing image) stops processing
     * and is reported on stderr; nothing is thrown to the caller.
     *
     * @param testList UTF-8 list file; only the first tab-separated column (image path) is used
     * @param modelPath model file previously written by {@link #svm_train(String, String)}
     */
    public void svm_predict(String testList, String modelPath) {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        try {
            // load is static; loading inside the try reports a bad model path like other errors.
            SVM model = SVM.load(modelPath);
            HOGDescriptor hog = newHog();
            for (String[] fields : readSampleList(testList)) {
                float[] descriptor = computeHog(hog, fields[0]);
                Mat sample = new Mat(1, descriptor.length, CvType.CV_32FC1);
                sample.put(0, 0, descriptor);
                float p = model.predict(sample);
                System.out.println(fields[0] + "\t" + p);
            }
        } catch (Exception e) {
            System.err.println("read err:" + e);
        }
    }

    /**
     * Reads a UTF-8 list file and splits each non-blank line on tabs; the first field of every
     * returned entry is the image path.
     */
    private static ArrayList<String[]> readSampleList(String listFile) throws java.io.IOException {
        ArrayList<String[]> entries = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(listFile)), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank / trailing lines instead of failing on them
                }
                entries.add(line.split("\t"));
            }
        }
        return entries;
    }

    /**
     * Builds the HOG extractor shared by training and prediction: 64x64 window, 16x16 blocks
     * moved in 8x8 steps, 8x8 cells and 9 orientation bins. That gives 7x7 block positions of
     * 2x2 cells x 9 bins each, i.e. 1764 features per image.
     */
    private static HOGDescriptor newHog() {
        return new HOGDescriptor(new Size(IMAGE_SIZE, IMAGE_SIZE), new Size(16, 16),
                new Size(8, 8), new Size(8, 8), 9);
    }

    /**
     * Loads an image, converts it to grayscale, resizes it to the HOG window and returns its
     * HOG descriptor.
     *
     * @throws IllegalArgumentException if the image cannot be read
     */
    private static float[] computeHog(HOGDescriptor hog, String imagePath) {
        // imread with default flags yields a 3-channel BGR image, hence COLOR_BGR2GRAY below.
        Mat src = Imgcodecs.imread(imagePath);
        if (src.empty()) {
            throw new IllegalArgumentException("no such picture: " + imagePath);
        }
        Mat gray = new Mat();
        Imgproc.cvtColor(src, gray, Imgproc.COLOR_BGR2GRAY);
        Mat resized = new Mat();
        Imgproc.resize(gray, resized, new Size(IMAGE_SIZE, IMAGE_SIZE));

        MatOfFloat descriptorMat = new MatOfFloat();
        hog.compute(resized, descriptorMat);
        float[] descriptor = descriptorMat.toArray();

        // Free native buffers eagerly; otherwise they linger until the Java GC finalizes them.
        src.release();
        gray.release();
        resized.release();
        descriptorMat.release();
        return descriptor;
    }

    /** Trains with the default paths, then classifies the default test list. */
    public static void main(String[] args) {
        Svm_train st = new Svm_train();
        st.svm_train();
        st.svm_predict();
    }
}

你可能感兴趣的:(log)