Lire源码解析一

Lucene image retrieval是以图搜图的java开源框架,这几天没什么事,就读了点源码,并写了点注释,特在这分享给大家。

这里主要给出的是 BOVWBuilder.java、KMeans.java 和 Cluster.java。它们的作用是用词频对特征进行编码,用到的是BOF(bag of features)模型。原理是:提取N张图片的特征(比如sift),放在一起得到一个特征矩阵,然后对该矩阵进行kmeans聚类,就会得到若干个聚类中心;对于新来的一幅图像,分别计算它的每个特征点与哪个聚类中心最近,该聚类中心对应的计数就加1,这样就可以编码得到一个维数与聚类中心个数相等的向量。

一切都从BOVWBuilder中index函数开始...

BOVWBuilder.java(包含注释)

package lmc.imageretrieval.imageanalysis.bovw;

import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;

import javax.swing.ProgressMonitor;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.imageanalysis.LireFeature;
import lmc.imageretrieval.tools.DocumentBuilder;
import lmc.imageretrieval.utils.SerializationUtils;

import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexWriterConfig.OpenMode;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Version;

public class BOVWBuilder {
    // Reader on the Lucene index whose documents carry the local features.
    IndexReader reader;
    // number of documents used to build the vocabulary / clusters.
    private int numDocsForVocabulary = 500;
    // Number of clusters (visual words) k-means is asked to find.
    private int numClusters = 512;
    // Cluster centers: computed in index() or read back from clusterFile.
    private Cluster[] clusters = null;
    // Number format used for the console log output in index().
    DecimalFormat df = (DecimalFormat) NumberFormat.getNumberInstance();
    // Optional Swing progress monitor; null means no progress reporting.
    private ProgressMonitor pm = null;

    // Feature prototype; init() derives the field names from it and getFeatureInstance() clones its class.
    protected LireFeature lireFeature;
    // The four names below are assigned in init().
    protected String localFeatureFieldName;
    protected String visualWordsFieldName;
    protected String localFeatureHistFieldName;
    protected String clusterFile;

    // When true, createVisualWords() removes the local feature fields from each document it processes.
    public static boolean DELETE_LOCAL_FEATURES = true;
    /**
     * Creates a builder on the given reader; the vocabulary sample size and the number of
     * clusters keep their field defaults (500 and 512).
     * <p>
     * NOTE(review): this constructor leaves {@code lireFeature} unset while {@link #init()}
     * dereferences it -- presumably the reason for the deprecation; confirm.
     *
     * @param reader the reader used to open the Lucene index
     * @deprecated
     */
    public BOVWBuilder(IndexReader reader) {
        this.reader = reader;
    }

    /**
     * Creates a builder on the given reader with a custom vocabulary sample size. The number of
     * clusters keeps its field default (512).
     * <p>
     * NOTE(review): {@code lireFeature} stays unset here although {@link #init()} dereferences
     * it -- presumably the reason for the deprecation; confirm.
     *
     * @param reader               the reader used to open the Lucene index
     * @param numDocsForVocabulary the number of documents used for building the vocabulary (clusters)
     * @deprecated
     */
    public BOVWBuilder(IndexReader reader, int numDocsForVocabulary) {
        this.numDocsForVocabulary = numDocsForVocabulary;
        this.reader = reader;
    }

    /**
     * Creates a builder on the given reader with a custom vocabulary sample size and vocabulary
     * size. The number of clusters has to be lower than the number of local features found in the
     * sampled documents, otherwise {@link #index()} throws an exception.
     * <p>
     * NOTE(review): {@code lireFeature} stays unset here although {@link #init()} dereferences
     * it -- presumably the reason for the deprecation; confirm.
     *
     * @param reader               the index reader
     * @param numDocsForVocabulary the number of documents sampled for building the visual vocabulary
     * @param numClusters          the size of the visual vocabulary (number of k-means clusters)
     * @deprecated
     */
    public BOVWBuilder(IndexReader reader, int numDocsForVocabulary, int numClusters) {
        this.reader = reader;
        this.numClusters = numClusters;
        this.numDocsForVocabulary = numDocsForVocabulary;
    }

    /**
     * Creates a builder on the given reader for the given local feature. The vocabulary sample
     * size and the number of clusters keep their field defaults (500 and 512).
     *
     * @param reader      the index reader
     * @param lireFeature the local feature; its field name and feature name are read in {@link #init()}
     *                    and its class is instantiated in {@link #getFeatureInstance()}
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature) {
        this.lireFeature = lireFeature;
        this.reader = reader;
    }

    /**
     * Creates a builder on the given reader for the given local feature with a custom vocabulary
     * sample size. The number of clusters keeps its field default (512).
     *
     * @param reader               the index reader
     * @param lireFeature          the local feature used
     * @param numDocsForVocabulary the number of documents sampled for building the visual vocabulary
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary) {
        this.reader = reader;
        this.lireFeature = lireFeature;
        this.numDocsForVocabulary = numDocsForVocabulary;
    }

    /**
     * Creates a fully configured builder. The number of clusters has to be lower than the number
     * of local features found in the sampled documents, otherwise {@link #index()} throws an
     * exception.
     *
     * @param reader               the index reader
     * @param lireFeature          the local feature used
     * @param numDocsForVocabulary the number of documents sampled for building the visual vocabulary
     * @param numClusters          the size of the visual vocabulary (number of k-means clusters)
     */
    public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary, int numClusters) {
        this.reader = reader;
        this.lireFeature = lireFeature;
        this.numDocsForVocabulary = numDocsForVocabulary;
        this.numClusters = numClusters;
    }

    /**
     * Derives the Lucene field names and the cluster file path from the configured feature.
     * Requires {@code lireFeature} to be set, since it is dereferenced here.
     */
    protected void init() {
        String baseFieldName = lireFeature.getFieldName();
        clusterFile = "./clusters-bovw" + lireFeature.getFeatureName() + ".dat";
        localFeatureFieldName = baseFieldName;
        visualWordsFieldName = baseFieldName + DocumentBuilder.FIELD_NAME_BOVW;
        localFeatureHistFieldName = baseFieldName + DocumentBuilder.FIELD_NAME_BOVW_VECTOR;
    }

    /**
     * Uses an existing index, where each and every document should have a set of local features. A number of
     * random images (numDocsForVocabulary) is selected and clustered to get a vocabulary of visual words
     * (the cluster means). For all images a histogram on the visual words is created and added to the documents.
     * Pre-existing histograms are deleted, so this method can be used for re-indexing.
     * <p>
     * The method samples documents, runs at most 11 k-means steps (stopping earlier once the stress
     * changes by no more than the threshold), writes the clusters to {@code clusterFile} and then rewrites all
     * documents with eight worker threads. The index writer is closed on success and rolled back on
     * failure, so it does not stay open when an exception escapes.
     *
     * @throws java.io.IOException           if reading or writing the index or the cluster file fails
     * @throws UnsupportedOperationException if fewer local features than clusters were found
     */
    public void index() throws IOException {
        init();
        df.setMaximumFractionDigits(3);
        // find the documents for building the vocabulary:
        HashSet<Integer> docIDs = selectVocabularyDocs();
        KMeans k = new KMeans(numClusters);
        // Scratch list that is refilled for every document before it is handed to the KMeans object.
        LinkedList<double[]> features = new LinkedList<double[]>();
        // Needed for check whether the document is deleted.
        Bits liveDocs = MultiFields.getLiveDocs(reader);
        for (Integer docID : docIDs) {
            int nextDoc = docID;
            if (reader.hasDeletions() && !liveDocs.get(nextDoc)) continue; // if it is deleted, just ignore it.
            Document d = reader.document(nextDoc);
            features.clear();
            // all local feature fields of this document
            IndexableField[] fields = d.getFields(localFeatureFieldName);
            // identifier stored with the document, used to label the image inside KMeans
            String file = d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0];
            for (int j = 0; j < fields.length; j++) {
                LireFeature f = getFeatureInstance();
                // decode the stored descriptor
                f.setByteArrayRepresentation(fields[j].binaryValue().bytes, fields[j].binaryValue().offset, fields[j].binaryValue().length);
                features.add(f.getDoubleHistogram());
            }
            k.addImage(file, features);
        }
        if (pm != null) { // set to 5 of 100 before clustering starts.
            pm.setProgress(5);
            pm.setNote("Starting clustering");
        }
        if (k.getFeatureCount() < numClusters) {
            // this cannot work. You need more data points than clusters.
            // Report the total feature count; the scratch list only holds the last document's features.
            throw new UnsupportedOperationException("Only " + k.getFeatureCount() + " features found to cluster in " + numClusters + ". Try to use less clusters or more images.");
        }
        // do the clustering:
        System.out.println("Number of local features: " + df.format(k.getFeatureCount()));
        System.out.println("Starting clustering ...");
        k.init();
        System.out.println("Step.");
        double time = System.currentTimeMillis();
        double laststress = k.clusteringStep();

        if (pm != null) { // set to 8 of 100 after first step.
            pm.setProgress(8);
            pm.setNote("Step 1 finished");
        }

        System.out.println(getDuration(time) + " -> Next step.");
        time = System.currentTimeMillis();
        double newStress = k.clusteringStep();

        if (pm != null) { // set to 11 of 100 after second step.
            pm.setProgress(11);
            pm.setNote("Step 2 finished");
        }

        // critical part: Give the difference in between steps as a constraint for accuracy vs. runtime trade off.
        double threshold = Math.max(20d, (double) k.getFeatureCount() / 1000d);
        System.out.println("Threshold = " + df.format(threshold));
        int cstep = 3;
        // Keep iterating while the stress still changes by more than the threshold, up to step 11.
        while (Math.abs(newStress - laststress) > threshold && cstep < 12) {
            System.out.println(getDuration(time) + " -> Next step. Stress difference ~ |" + (int) newStress + " - " + (int) laststress + "| = " + df.format(Math.abs(newStress - laststress)));
            time = System.currentTimeMillis();
            laststress = newStress;
            newStress = k.clusteringStep();
            if (pm != null) { // progress grows by 3 per clustering step.
                pm.setProgress(cstep * 3 + 5);
                pm.setNote("Step " + cstep + " finished");
            }
            cstep++;
        }
        // Serializing clusters to a file on the disk ...
        clusters = k.getClusters();
        Cluster.writeClusters(clusters, clusterFile);
        //  create & store histograms:
        System.out.println("Creating histograms ...");
        time = System.currentTimeMillis();
        @SuppressWarnings("deprecation")
        IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2,
                new WhitespaceAnalyzer(Version.LUCENE_4_10_2));
        conf.setOpenMode(OpenMode.CREATE_OR_APPEND);
        IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf);
        boolean interrupted = false;
        boolean finished = false;
        try {
            if (pm != null) { // set to 50 of 100 after clustering.
                pm.setProgress(50);
                pm.setNote("Clustering finished");
            }
            // parallelized indexing: split the document range evenly over the worker threads
            LinkedList<Thread> threads = new LinkedList<Thread>();
            int numThreads = 8;
            int step = reader.maxDoc() / numThreads;
            for (int part = 0; part < numThreads; part++) {
                boolean lastPart = part == numThreads - 1;
                // the last worker takes the remainder of the range and reports the progress
                int to = lastPart ? reader.maxDoc() : (part + 1) * step;
                Thread t = new Thread(new Indexer(part * step, to, iw, lastPart ? pm : null));
                threads.add(t);
                t.start();
            }
            for (Thread worker : threads) {
                // Wait for every worker even if interrupted: committing while a worker is still
                // updating documents would lose its remaining updates.
                boolean joined = false;
                while (!joined) {
                    try {
                        worker.join();
                        joined = true;
                    } catch (InterruptedException e) {
                        interrupted = true;
                    }
                }
            }
            if (pm != null) { // set to 95 of 100 after indexing.
                pm.setProgress(95);
                pm.setNote("Indexing finished, optimizing index now.");
            }

            System.out.println(getDuration(time));
            iw.commit();
            // this one does the "old" commit(), it removes the deleted SURF features.
            iw.forceMerge(1);
            finished = true;
        } finally {
            try {
                // On failure discard the uncommitted changes and release the writer.
                if (finished) iw.close();
                else iw.rollback();
            } finally {
                // Restore the interrupt status only after all index I/O is done.
                if (interrupted) Thread.currentThread().interrupt();
            }
        }
        if (pm != null) { // set to 100 of 100 after optimization.
            pm.setProgress(100);
            pm.setNote("Indexing & optimization finished");
            pm.close();
        }
        System.out.println("Finished.");
    }

   // NOTE: this method is not called anywhere in this class.
    /**
     * Adds visual words to those documents of the index that do not have them yet, using the
     * clusters previously written to {@code clusterFile}. Deleted documents are skipped. The index
     * writer is closed on success and rolled back on failure, so it does not stay open when an
     * exception escapes.
     *
     * @throws IOException if reading the cluster file or accessing the index fails
     */
    public void indexMissing() throws IOException {
        init();
        // Reading clusters from disk:
        clusters = Cluster.readClusters(clusterFile);
        //  create & store histograms:
        System.out.println("Creating histograms ...");
        LireFeature f = getFeatureInstance();

        // Needed for check whether the document is deleted.
        Bits liveDocs = MultiFields.getLiveDocs(reader);

        // based on a bug report from Einav Itamar
        @SuppressWarnings("deprecation")
        IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2,
                new WhitespaceAnalyzer(Version.LUCENE_4_10_2));
        IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf);
        boolean finished = false;
        try {
            for (int i = 0; i < reader.maxDoc(); i++) {
                if (reader.hasDeletions() && !liveDocs.get(i)) continue; // if it is deleted, just ignore it.
                Document d = reader.document(i);
                // Only if there are no values yet:
                String[] existingWords = d.getValues(visualWordsFieldName);
                if (existingWords == null || existingWords.length == 0) {
                    createVisualWords(d, f);
                    // now write the new one. we use the identifier to update ;)
                    iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]), d);
                }
            }
            iw.commit();
            // added to permanently remove the deleted docs.
            iw.forceMerge(1);
            finished = true;
        } finally {
            // On failure discard the uncommitted changes and release the writer.
            if (finished) iw.close();
            else iw.rollback();
        }
        System.out.println("Finished.");
    }

    /**
     * Takes one single document, creates the visual words from its local features and adds them to
     * the document. The clusters are read from {@code clusterFile} on every call.
     *
     * @param d the document to use for adding the visual words
     * @return the same document instance, now carrying the visual word fields
     * @throws IOException if the cluster file cannot be read
     */
    public Document getVisualWords(Document d) throws IOException {
        // The field names and the cluster file are only assigned in init(); make sure that has
        // happened when this method is used without a preceding index()/indexMissing() call.
        if (clusterFile == null || localFeatureFieldName == null
                || visualWordsFieldName == null || localFeatureHistFieldName == null) {
            init();
        }
        clusters = Cluster.readClusters(clusterFile);
        LireFeature f = getFeatureInstance();
        createVisualWords(d, f);

        return d;
    }


    /**
     * Rescales the histogram in place so that its largest bin maps to 128, flooring every value.
     * Currently not called from this class.
     *
     * @param histogram the histogram to quantize in place
     */
    @SuppressWarnings("unused")
    private void quantize(double[] histogram) {
        double peak = 0;
        for (double value : histogram) {
            peak = Math.max(peak, value);
        }
        int bin = 0;
        while (bin < histogram.length) {
            histogram[bin] = (int) Math.floor((histogram[bin] * 128d) / peak);
            bin++;
        }
    }

    /**
     * Finds the cluster whose distance to the given feature is smallest. On ties the cluster with
     * the lowest index wins.
     *
     * @param f the feature to assign
     * @return the index of the nearest cluster
     */
    private int clusterForFeature(Histogram f) {
        int nearest = 0;
        double smallest = clusters[nearest].getDistance(f);
        for (int candidate = 1; candidate < clusters.length; candidate++) {
            double candidateDistance = clusters[candidate].getDistance(f);
            if (candidateDistance < smallest) {
                nearest = candidate;
                smallest = candidateDistance;
            }
        }
        return nearest;
    }

    /**
     * Renders a visual word histogram as text: for every bin its index is written in hexadecimal,
     * followed by a space, as many times as the (truncated) bin count says.
     *
     * @param hist the visual word histogram
     * @return the space separated visual word string
     */
    private String arrayToVisualWordString(double[] hist) {
        StringBuilder words = new StringBuilder(1024);
        for (int bin = 0; bin < hist.length; bin++) {
            String word = Integer.toHexString(bin);
            for (int remaining = (int) hist[bin]; remaining > 0; remaining--) {
                words.append(word).append(' ');
            }
        }
        return words.toString();
    }
        // Select the images (documents) used for clustering.
    /**
     * Randomly selects the document numbers used for building the vocabulary.
     * <p>
     * If {@code numDocsForVocabulary} is at least the number of documents, all document numbers
     * are returned. A negative {@code numDocsForVocabulary} selects half of the documents.
     * Otherwise the requested number of distinct documents is drawn uniformly at random.
     *
     * @return the selected document numbers
     * @throws IOException                   if a candidate document cannot be read
     * @throws UnsupportedOperationException if not enough documents could be selected
     */
    private HashSet<Integer> selectVocabularyDocs() throws IOException {
        int maxDocs = reader.maxDoc();
        int capacity = Math.min(numDocsForVocabulary, maxDocs);
        if (capacity < 0) capacity = maxDocs / 2; // negative means: take half of the documents
        HashSet<Integer> result = new HashSet<Integer>(capacity);
        if (capacity >= maxDocs) {
            // as many or more requested than available: use all of them
            for (int i = 0; i < maxDocs; i++) {
                result.add(i);
            }
            return result;
        }
        // Partial Fisher-Yates shuffle: every round swaps a random not yet drawn candidate to the
        // front, so each document is drawn at most once and the loop always terminates.
        int[] candidates = new int[maxDocs];
        for (int i = 0; i < maxDocs; i++) {
            candidates[i] = i;
        }
        for (int drawn = 0; drawn < maxDocs && result.size() < capacity; drawn++) {
            int pick = drawn + (int) (Math.random() * (maxDocs - drawn));
            int docNumber = candidates[pick];
            candidates[pick] = candidates[drawn];
            candidates[drawn] = docNumber;
            // only use documents that can actually be retrieved
            if (reader.document(docNumber) != null) {
                result.add(docNumber);
            }
        }
        if (result.size() < capacity) {
            throw new UnsupportedOperationException("Could not get the documents, maybe there are not enough documents in the index?");
        }
        return result;
    }

//    protected abstract LireFeature getFeatureInstance();

    /**
     * Creates a fresh feature object of the same class as the configured {@code lireFeature},
     * using its no-argument constructor.
     *
     * @return the new instance, or {@code null} if it could not be created (the stack trace is printed)
     */
    protected LireFeature getFeatureInstance() {
        try {
            return lireFeature.getClass().newInstance();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Worker that creates the visual words for the documents in {@code [start, end)} and writes
     * the updated documents back through the shared index writer. Deleted documents are skipped,
     * matching {@code index()} and {@code indexMissing()}.
     */
    private class Indexer implements Runnable {
        private final int start, end;
        private final IndexWriter iw;
        // Optional progress monitor; only the worker that gets one reports progress.
        private final ProgressMonitor pm;

        private Indexer(int start, int end, IndexWriter iw, ProgressMonitor pm) {
            this.start = start;
            this.end = end;
            this.iw = iw;
            this.pm = pm;
        }

        public void run() {
            // one feature object per worker, reused for every descriptor
            LireFeature f = getFeatureInstance();
            // null when the index has no deletions
            Bits liveDocs = MultiFields.getLiveDocs(reader);
            for (int i = start; i < end; i++) {
                // Reading a deleted document gives unspecified results and would re-add it.
                if (liveDocs != null && !liveDocs.get(i)) continue;
                try {
                    Document d = reader.document(i);
                    createVisualWords(d, f);
                    // the identifier field is used to replace the old version of the document
                    iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]), d);
                    if (pm != null) {
                        // map this worker's position onto the 50..95 part of the progress bar
                        double len = (double) (end - start);
                        double percent = (double) (i - start) / len * 45d + 50;
                        pm.setProgress((int) percent);
                        pm.setNote("Creating visual words, ~" + (int) percent + "% finished");
                    }
                } catch (IOException e) {
                    // best effort: report the failure and continue with the next document
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Builds the visual word histogram for one document and stores it in the document twice: as a
     * text field holding the visual word string and as a stored field holding the serialized
     * histogram. Existing values of both fields are removed first. If
     * {@code DELETE_LOCAL_FEATURES} is set, the local feature fields are removed afterwards.
     *
     * @param d the document to modify in place
     * @param f a reusable feature object; its content is overwritten for every local feature
     */
    private void createVisualWords(Document d, LireFeature f) {
        // a new array is zero filled, so every bin starts at 0
        double[] histogram = new double[numClusters];
        IndexableField[] localFeatures = d.getFields(localFeatureFieldName);
        // drop results of a previous run, if any
        d.removeField(visualWordsFieldName);
        d.removeField(localFeatureHistFieldName);

        // count, per cluster, the local features that are nearest to it
        for (IndexableField localFeature : localFeatures) {
            f.setByteArrayRepresentation(localFeature.binaryValue().bytes, localFeature.binaryValue().offset, localFeature.binaryValue().length);
            int visualWord = clusterForFeature((Histogram) f);
            histogram[visualWord]++;
        }
        String visualWords = arrayToVisualWordString(histogram);
        d.add(new TextField(visualWordsFieldName, visualWords, Field.Store.YES));
        d.add(new StoredField(localFeatureHistFieldName, SerializationUtils.toByteArray(histogram)));
        // remove local features to save some space if requested:
        if (DELETE_LOCAL_FEATURES) {
            d.removeFields(localFeatureFieldName);
        }
    }

    /**
     * Formats the time elapsed since the given start time as {@code mm:ss}.
     *
     * @param time the start time in milliseconds, as returned by {@code System.currentTimeMillis()}
     * @return the elapsed minutes and seconds, each with at least two digits
     */
    private String getDuration(double time) {
        double elapsedMinutes = (System.currentTimeMillis() - time) / (1000 * 60);
        int wholeMinutes = (int) elapsedMinutes;
        int seconds = (int) ((elapsedMinutes - Math.floor(elapsedMinutes)) * 60);
        return String.format("%02d:%02d", wholeMinutes, seconds);
    }

    /**
     * Sets the progress monitor that is updated while indexing; {@code null} disables progress reporting.
     *
     * @param pm the progress monitor to use, may be {@code null}
     */
    public void setProgressMonitor(ProgressMonitor pm) {
        this.pm = pm;
    }

}

KMeans.java(包含注释)

package lmc.imageretrieval.imageanalysis.bovw;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.utils.StatsUtils;

public class KMeans {
    // Images added through addImage(), each with its own list of feature vectors.
    protected List<Image> images = new LinkedList<Image>();
    // countAllFeatures: number of feature vectors passed to addImage(); numClusters: desired number of clusters.
    protected int countAllFeatures = 0, numClusters = 256;
    // Flat list of all NaN-free feature vectors; built in init(). Cluster members are indices into it.
    protected ArrayList<double[]> features = null;
    // The clusters; created in init() and updated by every clusteringStep().
    protected Cluster[] clusters = null;
    // Lazily built map from feature vector to cluster index, see createIndex().
    protected HashMap<double[], Integer> featureIndex = null;

    /**
     * Creates a k-means instance with the default number of clusters (256).
     */
    public KMeans() {

    }

    /**
     * Creates a k-means instance with the given number of clusters.
     *
     * @param numClusters the number of clusters to find
     */
    public KMeans(int numClusters) {
        this.numClusters = numClusters;
    }

    /**
     * Registers an image and its feature vectors for clustering and adds their number to the
     * overall feature count.
     *
     * @param identifier a name for the image, used in diagnostic output
     * @param features   the feature vectors of the image
     */
    public void addImage(String identifier, List<double[]> features) {
        Image image = new Image(identifier, features);
        images.add(image);
        countAllFeatures = countAllFeatures + features.size();
    }

    /**
     * Returns the number of feature vectors added so far through {@code addImage}. Vectors that
     * {@code init()} later drops because they contain NaN values are still included in this count.
     *
     * @return the number of added feature vectors
     */
    public int getFeatureCount() {
        return countAllFeatures;
    }

    /**
     * Prepares the clustering: collects all NaN-free feature vectors of the added images, checks
     * that there is enough data and picks the initial cluster means from the features.
     *
     * @throws IllegalStateException if there are not more features than clusters; clustering
     *                               cannot work then. (Earlier versions terminated the JVM here.)
     */
    public void init() {
        // create a set of all features:
        features = new ArrayList<double[]>(countAllFeatures);
        for (Image image : images) {
            if (image.features.size() > 0) {
                // vectors containing NaN are reported by hasNaNs() and left out
                for (double[] histogram : image.features) {
                    if (!hasNaNs(histogram)) features.add(histogram);
                }
            } else {
                System.err.println("Image with no features: " + image.identifier);
            }
        }
        // --- check if there are (i) enough images and (ii) enough features
        if (images.size() < 500) {   // warning only
            System.err.println("WARNING: Please note that this approach has been implemented for big data and *a lot of images*. " +
                    "You might not get appropriate results with a small number of images employed for constructing the visual vocabulary.");
        }
        if (features.size() < numClusters*2) {   // warning only
            System.err.println("WARNING: Please note that the number of local features, in this case " + features.size() + ", is " +
                    "smaller than the recommended minimum number, which is two times the number of visual words, in your case 2*" + numClusters +
                    ". Please adapt your data and either use images with more local features or more images for creating the visual vocabulary.");
        }
        if (features.size() < numClusters + 1) {
            System.err.println("CRITICAL: The number of features is smaller than the number of clusters. This cannot work as there has to be at least one " +
                    "feature per cluster. Aborting process now.");
            System.out.println("images: " + images.size());
            System.out.println("features: " + features.size());
            System.out.println("clusters: " + numClusters);
            // Fail with an exception instead of System.exit(): library code must not kill the host application.
            throw new IllegalStateException("Cannot cluster " + features.size() + " features into " + numClusters
                    + " clusters: at least " + (numClusters + 1) + " features are required.");
        }
        // find first clusters:
        clusters = new Cluster[numClusters];
        Set<Integer> medians = selectInitialMedians(numClusters);
        assert(medians.size() == numClusters); // this has to be the same ...
        Iterator<Integer> mediansIterator = medians.iterator();
        for (int i = 0; i < clusters.length; i++) {
            double[] descriptor = features.get(mediansIterator.next());
            clusters[i] = new Cluster(new double[descriptor.length]);   // implicitly setting the length of the mean array.
            // copy, so that later changes of the mean do not modify the feature vector
            System.arraycopy(descriptor, 0, clusters[i].mean, 0, descriptor.length);
        }
    }

    /**
     * Chooses the indices (into {@code features}) of the feature vectors that become the initial
     * cluster means. Delegates to {@code StatsUtils.drawSample}.
     * NOTE(review): presumably draws numClusters distinct random indices below features.size() --
     * confirm against StatsUtils.
     *
     * @param numClusters the number of indices to select
     * @return the selected feature indices
     */
    protected Set<Integer> selectInitialMedians(int numClusters) {
        return StatsUtils.drawSample(numClusters, features.size());
    }

    /**
     * Performs one k-means iteration: empties all clusters, assigns every feature to its nearest
     * cluster, recomputes the cluster means and returns the resulting overall stress. Repeat until
     * the stress is below a threshold or hardly changes between two subsequent steps.
     *
     * @return the overall stress after this step
     */
    public double clusteringStep() {
        for (Cluster cluster : clusters) {
            cluster.members.clear();
        }
        reOrganizeFeatures();
        recomputeMeans();
        double stress = overallStress();
        return stress;
    }

    /**
     * Checks whether the given vector contains a NaN value. If it does, a message is written to
     * standard error and the complete vector is dumped to standard output.
     *
     * @param histogram the vector to check
     * @return {@code true} if at least one entry is NaN
     */
    protected boolean hasNaNs(double[] histogram) {
        boolean found = false;
        for (double value : histogram) {
            if (Double.isNaN(value)) {
                found = true;
                break;
            }
        }
        if (!found) {
            return false;
        }
        System.err.println("Found a NaN in init");
        for (double value : histogram) {
            System.out.print(value + ", ");
        }
        System.out.println("");
        return true;
    }

    /**
     * Assigns every feature to the cluster with the smallest distance to it by adding the
     * feature's index to that cluster's members. On ties the cluster with the lowest index wins.
     */
    protected void reOrganizeFeatures() {
        for (int featureId = 0; featureId < features.size(); featureId++) {
            double[] feature = features.get(featureId);
            int nearest = 0;
            double nearestDistance = clusters[0].getDistance(feature);
            for (int c = 1; c < clusters.length; c++) {
                double distance = clusters[c].getDistance(feature);
                if (nearestDistance > distance) {
                    nearest = c;
                    nearestDistance = distance;
                }
            }
            clusters[nearest].members.add(featureId);
        }
    }

    /**
     * Recomputes the mean of every cluster as the component-wise average of its members. A cluster
     * with exactly one member takes over that member's values; a cluster without members is
     * re-seeded with a randomly chosen feature. Both cases are reported on standard error.
     */
    protected void recomputeMeans() {
        int dimensions = features.get(0).length;
        for (int c = 0; c < clusters.length; c++) {
            Cluster current = clusters[c];
            double[] centroid = current.mean;
            for (int d = 0; d < dimensions; d++) {
                double sum = 0;
                for (Integer member : current.members) {
                    sum += features.get(member)[d];
                }
                // with zero or one member the plain sum already is the result
                centroid[d] = current.members.size() > 1 ? sum / (double) current.members.size() : sum;
            }
            int memberCount = current.members.size();
            if (memberCount == 1) {
                System.err.println("** There is just one member in cluster " + c);
            } else if (memberCount < 1) {
                System.err.println("** There is NO member in cluster " + c);
                // re-seed the empty cluster with a random feature
                int replacement = (int) Math.floor(Math.random()*features.size());
                System.arraycopy(features.get(replacement), 0, clusters[c].mean, 0, clusters[c].mean.length);
            }
        }
    }

    /**
     * Computes the overall stress of the current clustering: for every cluster member the absolute
     * differences to the cluster mean are summed up over all dimensions (accumulated per member in
     * a {@code float}) and added to the total. Note that this is a sum of absolute differences,
     * not of squared ones.
     *
     * @return the overall stress
     */
    protected double overallStress() {
        double total = 0;
        int dimensions = features.get(0).length;
        for (Cluster cluster : clusters) {
            for (Integer member : cluster.members) {
                double[] feature = features.get(member);
                float memberStress = 0;
                for (int d = 0; d < dimensions; d++) {
                    memberStress += Math.abs(cluster.mean[d] - feature[d]);
                }
                total += memberStress;
            }
        }
        return total;
    }

    /**
     * Returns the internal cluster array (not a copy); {@code null} before {@code init()} has run.
     *
     * @return the clusters
     */
    public Cluster[] getClusters() {
        return clusters;
    }

    /**
     * Returns the internal list (not a copy) of the images added so far.
     *
     * @return the images
     */
    public List<Image> getImages() {
        return images;
    }

    /**
     * Returns the number of desired clusters.
     *
     * @return the configured number of clusters
     */
    public int getNumClusters() {
        return numClusters;
    }

    /**
     * Sets the number of desired clusters. The value is read when {@code init()} creates the
     * clusters.
     *
     * @param numClusters the number of clusters to find
     */
    public void setNumClusters(int numClusters) {
        this.numClusters = numClusters;
    }

    /**
     * Rebuilds {@code featureIndex}, the map from each clustered feature vector to the index of
     * the cluster it currently belongs to.
     *
     * @return the newly built map
     */
    private HashMap<double[], Integer> createIndex() {
        featureIndex = new HashMap<double[], Integer>(features.size());
        for (int clusterId = 0; clusterId < clusters.length; clusterId++) {
            for (Integer featureId : clusters[clusterId].members) {
                featureIndex.put(features.get(featureId), clusterId);
            }
        }
        return featureIndex;
    }

    /**
     * Used to find the cluster of a feature actually used in the clustering process (so
     * it is known by the k-means class). The lookup map is built on first use.
     * <p>
     * NOTE(review): {@code featureIndex} is keyed by {@code double[]} but queried with a
     * {@code Histogram}; it looks like the lookup can never match, in which case the
     * {@code null} result fails when unboxed to {@code int}. Confirm and fix once the
     * Histogram API is at hand.
     *
     * @param f the feature to search for
     * @return the index of the Cluster
     */
    public int getClusterOfFeature(Histogram f) {
        // build the feature -> cluster map lazily
        if (featureIndex == null) createIndex();
        return featureIndex.get(f);
    }
}

/**
 * Bundles the feature vectors of one image with the image's identifier and a
 * histogram ({@code localFeatureHistogram}) that callers fill and normalize.
 */
class Image {
    /** Value the largest histogram bin is scaled to by {@link #normalizeFeatureHistogram()}. */
    private static final int QUANT_MAX_HISTOGRAM = 256;

    /** Private copy of the feature vectors passed to the constructor. */
    public List<double[]> features;
    /** Identifier of the image as passed to the constructor. */
    public String identifier;
    /** Histogram of this image; {@code null} until initialized or set. */
    public float[] localFeatureHistogram = null;

    /**
     * Creates an image entry.
     *
     * @param identifier the identifier of the image
     * @param features   the feature vectors of the image; the list is copied, the
     *                   arrays inside it are shared
     */
    Image(String identifier, List<double[]> features) {
        this.features = new LinkedList<double[]>(features);
        this.identifier = identifier;
    }

    /**
     * @return the current histogram, or {@code null} if none has been created yet
     */
    public float[] getLocalFeatureHistogram() {
        return localFeatureHistogram;
    }

    /**
     * @param localFeatureHistogram the histogram to use from now on (stored as is, not copied)
     */
    public void setLocalFeatureHistogram(float[] localFeatureHistogram) {
        this.localFeatureHistogram = localFeatureHistogram;
    }

    /**
     * Replaces the histogram with a new all-zero one.
     *
     * @param bins the number of bins of the new histogram
     */
    public void initHistogram(int bins) {
        // A freshly allocated float array is already filled with zeros.
        localFeatureHistogram = new float[bins];
    }

    /**
     * Scales the histogram in place so that its largest bin becomes
     * {@code QUANT_MAX_HISTOGRAM}. A histogram whose largest bin is zero is left
     * untouched, because dividing by that maximum would turn every bin into NaN.
     */
    public void normalizeFeatureHistogram() {
        float max = 0;
        for (float bin : localFeatureHistogram) {
            max = Math.max(bin, max);
        }
        if (max == 0) {
            return;
        }
        for (int i = 0; i < localFeatureHistogram.length; i++) {
            localFeatureHistogram[i] = (localFeatureHistogram[i] * QUANT_MAX_HISTOGRAM) / max;
        }
    }

    /**
     * Prints all histogram bins on one line of standard output, each followed by a space.
     */
    public void printHistogram() {
        for (float bin : localFeatureHistogram) {
            System.out.print(bin + " ");
        }
        System.out.println("");
    }
}

Cluster.java(包含注释)

package lmc.imageretrieval.imageanalysis.bovw;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;

import lmc.imageretrieval.imageanalysis.Histogram;
import lmc.imageretrieval.utils.MetricsUtils;
import lmc.imageretrieval.utils.SerializationUtils;

/**
 * A single cluster of the k-means vocabulary: its mean vector, the ids of the
 * features currently assigned to it and a stress value maintained by callers.
 * Also offers static helpers to store a set of cluster means in a file and to
 * load them again.
 */
public class Cluster implements Comparable<Object> {
    /** Mean (center) vector of the cluster. */
    double[] mean;
    /** Ids of the features assigned to this cluster. */
    HashSet<Integer> members = new HashSet<Integer>();

    private double stress = 0;

    /**
     * Creates a cluster whose mean has 4 * 4 * 8 = 128 components, all zero.
     */
    public Cluster() {
        // NOTE(review): 128 looks like the SIFT descriptor size - confirm.
        // A freshly allocated double array is already zero-filled.
        this.mean = new double[4 * 4 * 8];
    }

    /**
     * Creates a cluster around the given mean.
     *
     * @param mean the mean vector; stored as is, not copied
     */
    public Cluster(double[] mean) {
        this.mean = mean;
    }

    /**
     * @return all member ids, each followed by {@code ", "}, then all mean
     *         components, each followed by {@code ';'}
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(512);
        for (Integer member : members) {
            sb.append(member).append(", ");
        }
        for (double component : mean) {
            sb.append(component).append(';');
        }
        return sb.toString();
    }

    /**
     * Orders clusters by descending number of members.
     *
     * @param o the cluster to compare with; must be a {@code Cluster}
     * @return a negative value if this cluster has more members than {@code o},
     *         zero if both have equally many, a positive value otherwise
     * @throws ClassCastException if {@code o} is not a {@code Cluster}
     */
    @Override
    public int compareTo(Object o) {
        // Both sizes are non-negative, so the subtraction cannot overflow.
        return ((Cluster) o).members.size() - members.size();
    }

    /**
     * @param f the feature whose distance to the mean is wanted
     * @return the distance between the mean and the double histogram of {@code f}
     */
    public double getDistance(Histogram f) {
        return getDistance(f.getDoubleHistogram());
    }

    /**
     * @param f the feature vector whose distance to the mean is wanted
     * @return the result of {@code MetricsUtils.distL2(mean, f)}
     */
    public double getDistance(double[] f) {
        // MetricsUtils.distL1(mean, f) is the alternative that was tried here.
        return MetricsUtils.distL2(mean, f);
    }

    /**
     * Creates a byte array representation from the clusters mean.
     *
     * @return the clusters mean as byte array.
     */
    public byte[] getByteRepresentation() {
        return SerializationUtils.toByteArray(mean);
    }

    /**
     * Replaces the mean with the vector decoded from the given bytes.
     *
     * @param data the byte representation of a mean vector
     */
    public void setByteRepresentation(byte[] data) {
        mean = SerializationUtils.toDoubleArray(data);
    }

    /**
     * Writes the means of the given clusters to a file: first the number of
     * clusters, then the length of a mean vector (taken from the first cluster,
     * 0 for an empty array), then the byte representation of every mean.
     *
     * @param clusters the clusters to store
     * @param file     the path of the file to write
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeClusters(Cluster[] clusters, String file) throws IOException {
        // An empty array has no first mean to take the length from; store 0 instead
        // of failing after the file has already been created.
        int meanLength = clusters.length == 0 ? 0 : clusters[0].getMean().length;
        FileOutputStream fout = new FileOutputStream(file);
        try {
            fout.write(SerializationUtils.toBytes(clusters.length));
            fout.write(SerializationUtils.toBytes(meanLength));
            for (int i = 0; i < clusters.length; i++) {
                fout.write(clusters[i].getByteRepresentation());
            }
        } finally {
            // Release the file handle even when a write fails.
            fout.close();
        }
    }

    /**
     * Reads clusters from a file laid out as written by
     * {@link #writeClusters(Cluster[], String)}. The length of the mean vectors is
     * taken from the file header.
     * <p>
     * If the file ends in the middle of the mean data, a warning is printed to
     * standard error and reading continues with whatever the buffer holds.
     *
     * @param file the path of the file to read
     * @return the clusters stored in the file; only their means are set
     * @throws IOException if the file cannot be opened or read, or if it ends
     *                     before the two header values are complete
     */
    public static Cluster[] readClusters(String file) throws IOException {
        FileInputStream fin = new FileInputStream(file);
        try {
            byte[] header = new byte[4];
            if (readFully(fin, header) != header.length) {
                throw new IOException("Truncated cluster file, number of clusters missing: " + file);
            }
            Cluster[] result = new Cluster[SerializationUtils.toInt(header)];
            if (readFully(fin, header) != header.length) {
                throw new IOException("Truncated cluster file, length of the means missing: " + file);
            }
            int size = SerializationUtils.toInt(header);
            // Each mean component is stored in 8 bytes.
            byte[] buffer = new byte[size * 8];
            for (int i = 0; i < result.length; i++) {
                int bytesRead = readFully(fin, buffer);
                if (bytesRead != buffer.length) System.err.println("Didn't read enough bytes ...");
                result[i] = new Cluster(SerializationUtils.toDoubleArray(buffer));
            }
            return result;
        } finally {
            // Release the file handle even when reading fails.
            fin.close();
        }
    }

    /**
     * Fills the buffer from the stream, repeating the read because a single
     * {@code read} call may deliver fewer bytes than requested.
     *
     * @return the number of bytes stored in the buffer; less than its length only
     *         if the end of the stream was reached first
     */
    private static int readFully(FileInputStream in, byte[] buffer) throws IOException {
        int total = 0;
        while (total < buffer.length) {
            int count = in.read(buffer, total, buffer.length - total);
            if (count < 0) {
                break;
            }
            total += count;
        }
        return total;
    }

    /**
     * @return the stress value last set for this cluster (0 initially)
     */
    public double getStress() {
        return stress;
    }

    /**
     * @param stress the stress value to store for this cluster
     */
    public void setStress(double stress) {
        this.stress = stress;
    }

    /**
     * @return the internal set of member ids itself (no copy is made)
     */
    public HashSet<Integer> getMembers() {
        return members;
    }

    /**
     * @param members the set of member ids to use from now on (stored as is, not copied)
     */
    public void setMembers(HashSet<Integer> members) {
        this.members = members;
    }

    /**
     * Returns the cluster mean
     *
     * @return the cluster mean vector
     */
    public double[] getMean() {
        return mean;
    }
}


你可能感兴趣的:(Lucene,源码解析,LIRE,以图搜图)