Image classification tool using k-means + canopy + VGG16

Workflow

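  1. Extract each image's fc2-layer vector from the VGG16 model and save it to the image.db file
  2. Use canopy clustering with Euclidean distance to get a rough estimate of k
  3. Classify the images with the k-means algorithm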
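The tool's own Canopy class (called in Classify below) is not shown in this post. As a rough illustration of step 2, here is a minimal, self-contained sketch of canopy clustering with Euclidean distance. The class name CanopySketch and the thresholds t1 > t2 are assumptions for illustration, not the author's implementation; for L2-normalized vectors all distances fall in [0, 2], which bounds sensible threshold values.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical helper for illustration; not the tool's Canopy class.
final class CanopySketch {

    /** Plain Euclidean distance between two equal-length vectors. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Canopy clustering: t1 is the loose threshold, t2 the tight one (t1 > t2).
     * A point may fall into several canopies; the number of canopies is the k estimate.
     */
    static List<List<double[]>> canopies(List<double[]> points, double t1, double t2) {
        if (t1 <= t2) {
            throw new IllegalArgumentException("t1 must be greater than t2");
        }
        List<double[]> candidates = new ArrayList<>(points);
        List<List<double[]>> result = new ArrayList<>();
        while (!candidates.isEmpty()) {
            // The first remaining candidate becomes the center of a new canopy.
            double[] center = candidates.remove(0);
            List<double[]> canopy = new ArrayList<>();
            canopy.add(center);
            for (Iterator<double[]> it = candidates.iterator(); it.hasNext(); ) {
                double[] p = it.next();
                double d = euclidean(center, p);
                if (d < t1) {
                    canopy.add(p); // loosely close: member of this canopy
                }
                if (d < t2) {
                    it.remove();   // tightly close: may not start a canopy of its own
                }
            }
            result.add(canopy);
        }
        return result;
    }

    /** Rough estimate of k: how many canopies the data splits into. */
    static int estimateK(List<double[]> points, double t1, double t2) {
        return canopies(points, t1, t2).size();
    }

    public static void main(String[] args) {
        List<double[]> pts = Arrays.asList(
                new double[]{0, 0}, new double[]{0.1, 0},
                new double[]{5, 5}, new double[]{5.1, 5});
        System.out.println("Estimated k = " + estimateK(pts, 2.0, 1.0)); // prints "Estimated k = 2"
    }
}
```

The estimate is only as good as t2: a smaller t2 leaves more points free to seed new canopies and therefore yields a larger k.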

Extracting the image feature vector (code taken from the book 《自制AI图像搜索引擎》, "Build Your Own AI Image Search Engine")

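    // Returns the 4096-d fc2 activation of VGG16 for one image
    // (vgg16Model is the loaded DL4J ComputationGraph)
    private INDArray getImgFeature(File imgFile) throws IOException {
        // Load the image as a 224x224, 3-channel tensor, the input size VGG16 expects
        NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
        INDArray imageArray = loader.asMatrix(imgFile);
        // Apply VGG16's mean-subtraction preprocessing in place
        DataNormalization scaler = new VGG16ImagePreProcessor();
        scaler.transform(imageArray);
        // Inference-mode forward pass; activations are keyed by layer name
        Map<String, INDArray> map = vgg16Model.feedForward(imageArray, false);
        return map.get("fc2");
    }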

    // Convert the 1x4096 feature to double[]. Parsing indArr.toString() is fragile:
    // depending on the ND4J version it can abbreviate large arrays and rounds values.
    private double[] INDArray2DoubleArray(INDArray indArr) {
        return indArr.toDoubleVector();
    }
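Before clustering, the Classify code below L2-normalizes every vector with Utils.normalizeL2, whose source is not shown. The sketch below assumes it simply scales each vector to unit Euclidean length; the class L2Norm and its helpers are hypothetical. It also shows why this pairs well with Euclidean distance: for unit vectors, ||a - b||² = 2 - 2·cos(a, b), so Euclidean clustering orders pairs exactly as cosine similarity would.

```java
// Hypothetical stand-in for Utils.normalizeL2 (the real helper is not shown).
final class L2Norm {

    /** Returns a copy of v scaled to unit Euclidean length; an all-zero vector is returned unchanged. */
    static double[] normalizeL2(double[] v) {
        double sum = 0;
        for (double x : v) {
            sum += x * x;
        }
        double norm = Math.sqrt(sum);
        double[] out = v.clone();
        if (norm == 0) {
            return out;
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= norm;
        }
        return out;
    }

    /** Dot product; equals cosine similarity when both inputs are unit vectors. */
    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** Squared Euclidean distance. */
    static double squaredDistance(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    public static void main(String[] args) {
        double[] a = normalizeL2(new double[]{3, 4}); // about [0.6, 0.8]
        double[] b = normalizeL2(new double[]{4, 3}); // about [0.8, 0.6]
        // For unit vectors: ||a - b||^2 = 2 - 2 * (a . b)
        System.out.println(squaredDistance(a, b) + " ~ " + (2 - 2 * dot(a, b)));
    }
}
```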

public class Classify {

    public static List<Vector> getVectorListFromDB(String dbPath) {
        DB db = DBMaker.fileDB(dbPath).make();
        ConcurrentMap<String, double[]> map = db.hashMap("feat_map", Serializer.STRING, Serializer.DOUBLE_ARRAY).open();
        List<Vector> vecs = new ArrayList<Vector>();
        for (String key : map.keySet()) {
            double[] val = map.get(key);
            // L2-normalize so that Euclidean distance between vectors tracks cosine similarity
            val = Utils.normalizeL2(val);
            Vector vec = new Vector(key, val);
            vecs.add(vec);
        }
        db.close();
        return vecs;
    }

    public static void kmeans(String dbPath, String fromPath, String distPath) throws IOException {
        long time = System.currentTimeMillis();
        List<Vector> dataset = Classify.getVectorListFromDB(dbPath);
        // Step 2: canopy clustering gives a rough estimate of k
        Canopy canopy = new Canopy(new ArrayList<>(dataset));
        int m = canopy.cluster();
        System.out.println("Estimated " + m + " clusters");
        // Step 3: k-means with k = m
        Kmeans cu = new Kmeans(new ArrayList<>(dataset));
        cu.execute(m);

        ArrayList<ArrayList<Vector>> clusters = cu.getCluster();
        int i = 0;
        for (ArrayList<Vector> cluster : clusters) {
            // Each cluster goes to its own numbered folder: distPath/0, distPath/1, ...
            for (Vector vector : cluster) {
                System.out.println("copy " + vector.getKey() + " -> " + i);
                moveFile(vector.getKey(), i, fromPath, distPath);
            }
            i++;
        }

        System.out.println("Elapsed " + (System.currentTimeMillis() - time) / 1000 + " s");
    }

    // Despite its name, this copies (does not move) the image into distPath/<cluster index>/
    public static void moveFile(String from, int to, String fromPath, String distPath) throws IOException {
        File toF = new File(distPath + File.separator + to);
        if (!toF.isDirectory()) {
            toF.mkdirs(); // mkdirs also creates distPath itself if it is missing
        }
        FileUtils.copyFile(new File(fromPath + File.separator + from), new File(toF, from));
    }

}
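The Kmeans class used above is not shown either. For reference, here is a minimal sketch of standard k-means (Lloyd's algorithm) with centroids initialised from k randomly chosen points; the class KMeansSketch and its seed parameter are hypothetical, and the tool's initialisation may differ. Random initialisation is the usual reason k-means results change from run to run; fixing the seed, as here, makes a run repeatable.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of Lloyd's k-means; not the tool's Kmeans class.
final class KMeansSketch {

    static double sqDist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    /** Returns the cluster index (0..k-1) of each point. */
    static int[] cluster(double[][] points, int k, int maxIter, long seed) {
        int n = points.length;
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("need 1 <= k <= number of points");
        }
        int dim = points[0].length;
        // Initialise centroids with k distinct, randomly chosen points.
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[order.get(c)].clone();
        }
        int[] assign = new int[n];
        Arrays.fill(assign, -1);
        for (int iter = 0; iter < maxIter; iter++) {
            // Assignment step: each point joins its nearest centroid.
            boolean changed = false;
            for (int i = 0; i < n; i++) {
                int best = 0;
                double bestD = sqDist(points[i], centroids[0]);
                for (int c = 1; c < k; c++) {
                    double d = sqDist(points[i], centroids[c]);
                    if (d < bestD) {
                        bestD = d;
                        best = c;
                    }
                }
                if (assign[i] != best) {
                    assign[i] = best;
                    changed = true;
                }
            }
            if (!changed) {
                break; // converged: no point switched cluster
            }
            // Update step: move each centroid to the mean of its points.
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int i = 0; i < n; i++) {
                counts[assign[i]]++;
                for (int j = 0; j < dim; j++) {
                    sums[assign[i]][j] += points[i][j];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // empty cluster keeps its previous centroid
                }
                for (int j = 0; j < dim; j++) {
                    centroids[c][j] = sums[c][j] / counts[c];
                }
            }
        }
        return assign;
    }

    public static void main(String[] args) {
        double[][] pts = {{0, 0}, {0, 0.1}, {10, 10}, {10, 10.1}};
        System.out.println(Arrays.toString(cluster(pts, 2, 100, 42L)));
    }
}
```

In the tool, k would come from the canopy estimate and each resulting cluster index maps to one output folder.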

Prebuilt tool (packaged jar)

Usage (command-line options)

        Option help = new Option("h", false, "Show help");
        Option model = Option.builder("m").hasArg().argName("model").desc("Model path").build();
        Option database = Option.builder("d").hasArg().argName("database").desc("Path of the image feature database").build();
        Option img = Option.builder("i").hasArg().argName("imgdir").desc("Full path of the image folder used to build the feature database").build();
        Option dist = Option.builder("t").hasArg().argName("dist").desc("Target folder for the classified images").build();
        Option useDb = Option.builder("b").hasArg(false).desc("Reuse the existing db").build();

Examples

Classify the images in the mp folder into the test folder:
java -jar .\GenerateImgsFeatDBTool-1.0-SNAPSHOT.jar -m D:\work\vgg16.zip -d D:\pic\image.db -i D:\pic\mp -t D:\pic\test

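k-means results vary somewhat from run to run. If you are not satisfied with the classification, rerun it with the -b flag: this reuses the db file generated last time instead of extracting the features again, which makes classification faster.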
java -jar .\GenerateImgsFeatDBTool-1.0-SNAPSHOT.jar -m D:\work\vgg16.zip -d D:\pic\image.db -i D:\pic\mp -t D:\pic\test -b

Link: https://pan.baidu.com/s/1hNf4YarZ5nnZkjmcMnPrng
Extraction code: yuc5
