Lucene image retrieval是以图搜图的java开源框架,这几天没什么事,就读了点源码,并写了点注释,特在这分享给大家。
这里主要给出的是BOVWBuilder.java、Kmeans.java及Cluster.java。其作用是用词频对特征进行编码,用到的是BOF(bag of features)模型。原理是:提取N张图片的局部特征(比如sift),把它们放在一起得到一个特征矩阵,然后对该矩阵进行kmeans聚类,就会得到若干个聚类中心;对于新来的一幅图像,分别计算它的每个特征点与哪个聚类中心最近,该聚类中心对应的计数就加1,这样就可以把图像编码成维数与聚类中心个数相等的向量。
一切都从BOVWBuilder中index函数开始...
BOVWBuilder.java(包含注释)
package lmc.imageretrieval.imageanalysis.bovw; import java.io.File; import java.io.IOException; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import javax.swing.ProgressMonitor; import lmc.imageretrieval.imageanalysis.Histogram; import lmc.imageretrieval.imageanalysis.LireFeature; import lmc.imageretrieval.tools.DocumentBuilder; import lmc.imageretrieval.utils.SerializationUtils; import org.apache.lucene.analysis.core.WhitespaceAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StoredField; import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig.OpenMode; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.Term; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; public class BOVWBuilder { IndexReader reader; // number of documents used to build the vocabulary / clusters. private int numDocsForVocabulary = 500; private int numClusters = 512; private Cluster[] clusters = null; DecimalFormat df = (DecimalFormat) NumberFormat.getNumberInstance(); private ProgressMonitor pm = null; protected LireFeature lireFeature; protected String localFeatureFieldName; protected String visualWordsFieldName; protected String localFeatureHistFieldName; protected String clusterFile; public static boolean DELETE_LOCAL_FEATURES = true; /** * * @param reader * @deprecated */ public BOVWBuilder(IndexReader reader) { this.reader = reader; } /** * Creates a new instance of the BOVWBuilder using the given reader. 
The numDocsForVocabulary * indicates how many documents of the index are used to build the vocabulary (clusters). * * @param reader the reader used to open the Lucene index, * @param numDocsForVocabulary gives the number of documents for building the vocabulary (clusters). * @deprecated */ public BOVWBuilder(IndexReader reader, int numDocsForVocabulary) { this.reader = reader; this.numDocsForVocabulary = numDocsForVocabulary; } /** * Creates a new instance of the BOVWBuilder using the given reader. The numDocsForVocabulary * indicates how many documents of the index are used to build the vocabulary (clusters). The numClusters gives * the number of clusters k-means should find. Note that this number should be lower than the number of features, * otherwise an exception will be thrown while indexing. * * @param reader the index reader * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary * @param numClusters the size of the visual vocabulary * @deprecated */ public BOVWBuilder(IndexReader reader, int numDocsForVocabulary, int numClusters) { this.numDocsForVocabulary = numDocsForVocabulary; this.numClusters = numClusters; this.reader = reader; } /** * Creates a new instance of the BOVWBuilder using the given reader. TODO: write * * @param reader the index reader * @param lireFeature lireFeature used */ public BOVWBuilder(IndexReader reader, LireFeature lireFeature) { this.reader = reader; this.lireFeature = lireFeature; } /** * Creates a new instance of the BOVWBuilder using the given reader. The numDocsForVocabulary * indicates how many documents of the index are used to build the vocabulary (clusters). 
* TODO: write * * @param reader the index reader * @param lireFeature lireFeature used * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary */ public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary) { this.numDocsForVocabulary = numDocsForVocabulary; this.reader = reader; this.lireFeature = lireFeature; } /** * Creates a new instance of the BOVWBuilder using the given reader. The numDocsForVocabulary * indicates how many documents of the index are used to build the vocabulary (clusters). The numClusters gives * the number of clusters k-means should find. Note that this number should be lower than the number of features, * otherwise an exception will be thrown while indexing. TODO: write * * @param reader the index reader * @param lireFeature lireFeature used * @param numDocsForVocabulary the number of documents that should be sampled for building the visual vocabulary * @param numClusters the size of the visual vocabulary */ public BOVWBuilder(IndexReader reader, LireFeature lireFeature, int numDocsForVocabulary, int numClusters) { this.numDocsForVocabulary = numDocsForVocabulary; this.numClusters = numClusters; this.reader = reader; this.lireFeature = lireFeature; } protected void init() { localFeatureFieldName = lireFeature.getFieldName(); visualWordsFieldName = lireFeature.getFieldName() + DocumentBuilder.FIELD_NAME_BOVW; localFeatureHistFieldName = lireFeature.getFieldName()+ DocumentBuilder.FIELD_NAME_BOVW_VECTOR; clusterFile = "./clusters-bovw" + lireFeature.getFeatureName() + ".dat"; } /** * Uses an existing index, where each and every document should have a set of local features. A number of * random images (numDocsForVocabulary) is selected and clustered to get a vocabulary of visual words * (the cluster means). For all images a histogram on the visual words is created and added to the documents. 
* Pre-existing histograms are deleted, so this method can be used for re-indexing. * * @throws java.io.IOException */ public void index() throws IOException { init(); df.setMaximumFractionDigits(3); // find the documents for building the vocabulary: HashSet<Integer> docIDs = selectVocabularyDocs(); //选择全部要进行聚类的文档docment的id KMeans k = new KMeans(numClusters); // fill the KMeans object: LinkedList<double[]> features = new LinkedList<double[]>(); // Needed for check whether the document is deleted. Bits liveDocs = MultiFields.getLiveDocs(reader); for (Iterator<Integer> iterator = docIDs.iterator(); iterator.hasNext(); ) { int nextDoc = iterator.next(); if (reader.hasDeletions() && !liveDocs.get(nextDoc)) continue; // if it is deleted, just ignore it. Document d = reader.document(nextDoc); // 取出该文档 features.clear(); IndexableField[] fields = d.getFields(localFeatureFieldName); // 取出sift特征点 String file = d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]; // 取出该图片路径名字 for (int j = 0; j < fields.length; j++) { LireFeature f = getFeatureInstance(); // 取出descriptor f.setByteArrayRepresentation(fields[j].binaryValue().bytes, fields[j].binaryValue().offset, fields[j].binaryValue().length); // copy the data over to new array ... 没有用 //double[] feat = new double[f.getDoubleHistogram().length]; //System.arraycopy(f.getDoubleHistogram(), 0, feat, 0, feat.length); features.add(f.getDoubleHistogram()); } k.addImage(file, features); // 将descriptor与图片相关联 } if (pm != null) { // set to 5 of 100 before clustering starts. pm.setProgress(5); pm.setNote("Starting clustering"); } if (k.getFeatureCount() < numClusters) { // 总的特征数小于聚类中心个数,则抛出异常 // this cannot work. You need more data points than clusters. throw new UnsupportedOperationException("Only " + features.size() + " features found to cluster in " + numClusters + ". 
Try to use less clusters or more images."); } // do the clustering: System.out.println("Number of local features: " + df.format(k.getFeatureCount())); System.out.println("Starting clustering ..."); k.init(); // 聚类中心初始化 System.out.println("Step."); double time = System.currentTimeMillis(); double laststress = k.clusteringStep(); // 进行聚类,并获得sum of squared error if (pm != null) { // set to 8 of 100 after first step. pm.setProgress(8); pm.setNote("Step 1 finished"); } System.out.println(getDuration(time) + " -> Next step."); time = System.currentTimeMillis(); double newStress = k.clusteringStep(); // 第二步聚类 if (pm != null) { // set to 11 of 100 after second step. pm.setProgress(11); pm.setNote("Step 2 finished"); } // critical part: Give the difference in between steps as a constraint for accuracy vs. runtime trade off. double threshold = Math.max(20d, (double) k.getFeatureCount() / 1000d); // 如果两次sse小于20 迭代停止 System.out.println("Threshold = " + df.format(threshold)); int cstep = 3; while (Math.abs(newStress - laststress) > threshold && cstep < 12) { // 迭代次数超过12次,迭代停止 System.out.println(getDuration(time) + " -> Next step. Stress difference ~ |" + (int) newStress + " - " + (int) laststress + "| = " + df.format(Math.abs(newStress - laststress))); time = System.currentTimeMillis(); laststress = newStress; newStress = k.clusteringStep(); if (pm != null) { // set to XX of 100 after second step. pm.setProgress(cstep * 3 + 5); pm.setNote("Step " + cstep + " finished"); } cstep++; } // Serializing clusters to a file on the disk ... 
clusters = k.getClusters(); // 得到聚类中心 // for (int i = 0; i < clusters.length; i++) { // Cluster cluster = clusters[i]; // System.out.print(cluster.getMembers().size() + ", "); // } // System.out.println(); Cluster.writeClusters(clusters, clusterFile); // 将聚类中心点写入文本文件 // create & store histograms: System.out.println("Creating histograms ..."); time = System.currentTimeMillis(); // int[] tmpHist = new int[numClusters]; @SuppressWarnings("deprecation") IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2, new WhitespaceAnalyzer(Version.LUCENE_4_10_2)); conf.setOpenMode(OpenMode.CREATE_OR_APPEND); IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf); if (pm != null) { // set to 50 of 100 after clustering. pm.setProgress(50); pm.setNote("Clustering finished"); } // parallelized indexing LinkedList<Thread> threads = new LinkedList<Thread>(); // 线程队列 int numThreads = 8; // 设置了8个线程 // careful: copy reader to RAM for faster access when reading ... // reader = IndexReader.open(new RAMDirectory(reader.directory()), true); int step = reader.maxDoc() / numThreads; // 对每个线程分配一定数量的任务 for (int part = 0; part < numThreads; part++) { Indexer indexer = null; if (part < numThreads - 1) indexer = new Indexer(part * step, (part + 1) * step, iw, null); else indexer = new Indexer(part * step, reader.maxDoc(), iw, pm); Thread t = new Thread(indexer); threads.add(t); t.start(); } for (Iterator<Thread> iterator = threads.iterator(); iterator.hasNext(); ) { Thread next = iterator.next(); try { next.join(); } catch (InterruptedException e) { e.printStackTrace(); } } if (pm != null) { // set to 50 of 100 after clustering. pm.setProgress(95); pm.setNote("Indexing finished, optimizing index now."); } System.out.println(getDuration(time)); iw.commit(); // this one does the "old" commit(), it removes the deleted SURF features. iw.forceMerge(1); iw.close(); if (pm != null) { // set to 50 of 100 after clustering. 
pm.setProgress(100); pm.setNote("Indexing & optimization finished"); pm.close(); } System.out.println("Finished."); } // 此函数没有用 public void indexMissing() throws IOException { init(); // Reading clusters from disk: clusters = Cluster.readClusters(clusterFile); // create & store histograms: System.out.println("Creating histograms ..."); LireFeature f = getFeatureInstance(); // Needed for check whether the document is deleted. Bits liveDocs = MultiFields.getLiveDocs(reader); // based on bug report from Einav Itamar <[email protected]> @SuppressWarnings("deprecation") IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_4_10_2, new WhitespaceAnalyzer(Version.LUCENE_4_10_2)); IndexWriter iw = new IndexWriter(((DirectoryReader) reader).directory(), conf); for (int i = 0; i < reader.maxDoc(); i++) { if (reader.hasDeletions() && !liveDocs.get(i)) continue; // if it is deleted, just ignore it. Document d = reader.document(i); // Only if there are no values yet: if (d.getValues(visualWordsFieldName) == null || d.getValues(visualWordsFieldName).length == 0) { createVisualWords(d, f); // now write the new one. we use the identifier to update ;) iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]), d); } } iw.commit(); // added to permanently remove the deleted docs. iw.forceMerge(1); iw.close(); System.out.println("Finished."); } /** * Takes one single document and creates the visual words and adds them to the document. The same document is returned. 
* * @param d the document to use for adding the visual words * @return * @throws IOException */ public Document getVisualWords(Document d) throws IOException { // 得到文档d所对应的bow特征 clusters = Cluster.readClusters(clusterFile); // 读入聚类中心 LireFeature f = getFeatureInstance(); createVisualWords(d, f); // 创建bow特征 return d; } @SuppressWarnings("unused") // 没有用了 private void quantize(double[] histogram) { double max = 0; for (int i = 0; i < histogram.length; i++) { max = Math.max(max, histogram[i]); } for (int i = 0; i < histogram.length; i++) { histogram[i] = (int) Math.floor((histogram[i] * 128d) / max); } } /** * Find the appropriate cluster for a given feature. * * @param f * @return the index of the cluster. */ private int clusterForFeature(Histogram f) { // 找到一个特征点最近的聚类中心并返回该聚类中心的下标 double distance = clusters[0].getDistance(f); double tmp; int result = 0; for (int i = 1; i < clusters.length; i++) { tmp = clusters[i].getDistance(f); if (tmp < distance) { distance = tmp; result = i; } } return result; } private String arrayToVisualWordString(double[] hist) { // 以这种string类型进行存储,感觉没什么用啊 StringBuilder sb = new StringBuilder(1024); for (int i = 0; i < hist.length; i++) { int visualWordIndex = (int) hist[i]; for (int j = 0; j < visualWordIndex; j++) { // sb.append('v'); sb.append(Integer.toHexString(i)); sb.append(' '); } } return sb.toString(); } // 选择图片进行聚类 private HashSet<Integer> selectVocabularyDocs() throws IOException { // need to make sure that this is not running forever ... 
int loopCount = 0; float maxDocs = reader.maxDoc(); // 返回总文档数量 int capacity = (int) Math.min(numDocsForVocabulary, maxDocs); if (capacity < 0) capacity = (int) (maxDocs / 2); // 如果是-1 则选择一半文档 HashSet<Integer> result = new HashSet<Integer>(capacity); int tmpDocNumber, tmpIndex; LinkedList<Integer> docCandidates = new LinkedList<Integer>(); // three cases: // // either it's more or the same number as documents if (numDocsForVocabulary >= maxDocs) { // 指定数量大于已有的,则将已有全部用来聚类 for (int i = 0; i < maxDocs; i++) { result.add(i); } return result; } else if (numDocsForVocabulary >= maxDocs - 100) { // 在[maxDocs-100, maxDocs]之间, for (int i = 0; i < maxDocs; i++) { result.add(i); // 先全部加入 } while (result.size() > numDocsForVocabulary) { // 随机踢出掉多余的图片,使数量为numDocForVocabulary result.remove((int) Math.floor(Math.random() * result.size())); } return result; } else { // 不满足上面几种情况即numDocForVocabulary在[1, maxDocs-100]之间 for (int i = 0; i < maxDocs; i++) { docCandidates.add(i); // 先将全部加入 } for (int r = 0; r < capacity; r++) { // capacity就等于numDocForVocabulary boolean worksFine = false; do { tmpIndex = (int) Math.floor(Math.random() * (double) docCandidates.size()); tmpDocNumber = docCandidates.get(tmpIndex); docCandidates.remove(tmpIndex); // 该文档是否存在及是否已经包含 // check if the selected doc number is valid: not null, not deleted and not already chosen. worksFine = (reader.document(tmpDocNumber) != null) && !result.contains(tmpDocNumber); } while (!worksFine); result.add(tmpDocNumber); // need to make sure that this is not running forever ... 
if (loopCount++ > capacity * 100) throw new UnsupportedOperationException("Could not get the documents, maybe there are not enough documents in the index?"); } return result; } } // protected abstract LireFeature getFeatureInstance(); protected LireFeature getFeatureInstance() { LireFeature result = null; try { result = lireFeature.getClass().newInstance(); } catch (InstantiationException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } return result; } private class Indexer implements Runnable { // 建索引的线程类 私有的 int start, end; IndexWriter iw; ProgressMonitor pm = null; private Indexer(int start, int end, IndexWriter iw, ProgressMonitor pm) { this.start = start; this.end = end; this.iw = iw; this.pm = pm; } public void run() { // 线程运行函数 LireFeature f = getFeatureInstance(); // 得到feature的实例 for (int i = start; i < end; i++) { try { Document d = reader.document(i); // 得到第i个文档 createVisualWords(d, f); iw.updateDocument(new Term(DocumentBuilder.FIELD_NAME_IDENTIFIER, d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0]), d); if (pm != null) { double len = (double) (end - start); double percent = (double) (i - start) / len * 45d + 50; pm.setProgress((int) percent); pm.setNote("Creating visual words, ~" + (int) percent + "% finished"); } // } } catch (IOException e) { e.printStackTrace(); } } } } private void createVisualWords(Document d, LireFeature f) { double[] tmpHist = new double[numClusters]; Arrays.fill(tmpHist, 0d); IndexableField[] fields = d.getFields(localFeatureFieldName); // remove the fields if they are already there ... 
// 从索引中移除以下两个字段以防已经存在 d.removeField(visualWordsFieldName); d.removeField(localFeatureHistFieldName); // find the appropriate cluster for each feature: for (int j = 0; j < fields.length; j++) { // 获取该描述符 f.setByteArrayRepresentation(fields[j].binaryValue().bytes, fields[j].binaryValue().offset, fields[j].binaryValue().length); tmpHist[clusterForFeature((Histogram) f)]++; // 得到每一个特征点所对应的最近聚类中心就+1 } //quantize(tmpHist); // tmpHist就是最终的结果 d.add(new TextField(visualWordsFieldName, arrayToVisualWordString(tmpHist), Field.Store.YES)); // 以字符串的形式进行存储,没什么用 d.add(new StoredField(localFeatureHistFieldName, SerializationUtils.toByteArray(tmpHist))); // 转换成字节类型进行存储 // remove local features to save some space if requested: if (DELETE_LOCAL_FEATURES) { d.removeFields(localFeatureFieldName); // 移除原有的field } // for debugging .. // System.out.println(d.getValues(DocumentBuilder.FIELD_NAME_IDENTIFIER)[0] + " " + Arrays.toString(tmpHist)); } private String getDuration(double time) { double min = (System.currentTimeMillis() - time) / (1000 * 60); double sec = (min - Math.floor(min)) * 60; return String.format("%02d:%02d", (int) min, (int) sec); } public void setProgressMonitor(ProgressMonitor pm) { this.pm = pm; } }
KMeans.java(包含注释)
package lmc.imageretrieval.imageanalysis.bovw; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Set; import lmc.imageretrieval.imageanalysis.Histogram; import lmc.imageretrieval.utils.StatsUtils; public class KMeans { protected List<Image> images = new LinkedList<Image>(); protected int countAllFeatures = 0, numClusters = 256; protected ArrayList<double[]> features = null; protected Cluster[] clusters = null; protected HashMap<double[], Integer> featureIndex = null; public KMeans() { } public KMeans(int numClusters) { this.numClusters = numClusters; } public void addImage(String identifier, List<double[]> features) { // 加入image images.add(new Image(identifier, features)); countAllFeatures += features.size(); } public int getFeatureCount() { return countAllFeatures; } public void init() { // 聚类中心初始化 // create a set of all features: features = new ArrayList<double[]>(countAllFeatures); for (Image image : images) { if (image.features.size() > 0) // 将所有的descriptor放入features中 for (double[] histogram : image.features) { if (!hasNaNs(histogram)) features.add(histogram); } else { System.err.println("Image with no features: " + image.identifier); } } // --- check if there are (i) enough images and (ii) enough features if (images.size() < 500) { // 图片数量小于500 错误 System.err.println("WARNING: Please note that this approach has been implemented for big data and *a lot of images*. " + "You might not get appropriate results with a small number of images employed for constructing the visual vocabulary."); } if (features.size() < numClusters*2) { // 特征点个数不能小于聚类中心的两倍 System.err.println("WARNING: Please note that the number of local features, in this case " + features.size() + ", is" + "smaller than the recommended minimum number, which is two times the number of visual words, in your case 2*" + numClusters + ". 
Please adapt your data and either use images with more local features or more images for creating the visual vocabulary."); } if (features.size() < numClusters + 1) { //特征点个数不能小于聚类中心+1 System.err.println("CRITICAL: The number of features is smaller than the number of clusters. This cannot work as there has to be at least one " + "feature per cluster. Aborting process now."); System.out.println("images: " + images.size()); System.out.println("features: " + features.size()); System.out.println("clusters: " + numClusters); System.exit(1); } // find first clusters: clusters = new Cluster[numClusters]; // 初始的聚类中心 Set<Integer> medians = selectInitialMedians(numClusters); assert(medians.size() == numClusters); // this has to be the same ... Iterator<Integer> mediansIterator = medians.iterator(); for (int i = 0; i < clusters.length; i++) { double[] descriptor = features.get(mediansIterator.next()); clusters[i] = new Cluster(new double[descriptor.length]); // implicitly setting the length of the mean array. System.arraycopy(descriptor, 0, clusters[i].mean, 0, descriptor.length); } } protected Set<Integer> selectInitialMedians(int numClusters) { return StatsUtils.drawSample(numClusters, features.size()); } /** * Do one step and return the overall stress (squared error). You should do this until * the error is below a threshold or doesn't change a lot in between two subsequent steps. 
* * @return */ public double clusteringStep() { // 聚类迭代 for (int i = 0; i < clusters.length; i++) { clusters[i].members.clear(); // 清空该聚类中心所有的成员 } reOrganizeFeatures(); // 重新计算每个样本点到聚类中心的距离,重新分配 recomputeMeans(); // 重新计算聚类中心的大小 return overallStress(); // 返回sum of squared 迭代结束指标 } protected boolean hasNaNs(double[] histogram) { // 判断是否有not a number boolean hasNaNs = false; for (int i = 0; i < histogram.length; i++) { if (Double.isNaN(histogram[i])) { hasNaNs = true; break; } } if (hasNaNs) { System.err.println("Found a NaN in init"); // System.out.println("image.identifier = " + image.identifier); for (int j = 0; j < histogram.length; j++) { double v = histogram[j]; System.out.print(v + ", "); } System.out.println(""); } return hasNaNs; } /** * Re-shuffle all features. */ protected void reOrganizeFeatures() { // 重新计算每个点到聚类中心的距离,该点归属于哪一个聚类中心 for (int k = 0; k < features.size(); k++) { // 看k属于哪个聚类中心最近 double[] f = features.get(k); Cluster best = clusters[0]; double minDistance = clusters[0].getDistance(f); for (int i = 1; i < clusters.length; i++) { double v = clusters[i].getDistance(f); // 采用的是欧式距离 if (minDistance > v) { best = clusters[i]; minDistance = v; } } best.members.add(k); } } /** * Computes the mean per cluster (averaged vector) */ protected void recomputeMeans() { // 重新计算聚类中心 int length = features.get(0).length; for (int i = 0; i < clusters.length; i++) { Cluster cluster = clusters[i]; double[] mean = cluster.mean; for (int j = 0; j < length; j++) { mean[j] = 0; for (Integer member : cluster.members) { mean[j] += features.get(member)[j]; } if (cluster.members.size() > 1) mean[j] = mean[j] / (double) cluster.members.size(); } if (cluster.members.size() == 1) { // 该聚类中心只含有一个点 System.err.println("** There is just one member in cluster " + i); } else if (cluster.members.size() < 1) { // 该聚类中心没有点 System.err.println("** There is NO member in cluster " + i); // fill it with a random member?!? 
int index = (int) Math.floor(Math.random()*features.size()); // 重新随机选择一个点作为该聚类中心 System.arraycopy(features.get(index), 0, clusters[i].mean, 0, clusters[i].mean.length); } } } /** * Squared error in classification. * * @return */ protected double overallStress() { // 计算聚类中的sum of squared double v = 0; int length = features.get(0).length; for (int i = 0; i < clusters.length; i++) { for (Integer member : clusters[i].members) { float tmpStress = 0; for (int j = 0; j < length; j++) { // if (Float.isNaN(features.get(member).descriptor[j])) System.err.println("Error: there is a NaN in cluster " + i + " at member " + member); tmpStress += Math.abs(clusters[i].mean[j] - features.get(member)[j]); } v += tmpStress; } } return v; } public Cluster[] getClusters() { return clusters; } public List<Image> getImages() { return images; } /** * Set the number of desired clusters. * * @return */ public int getNumClusters() { return numClusters; } public void setNumClusters(int numClusters) { this.numClusters = numClusters; } private HashMap<double[], Integer> createIndex() { featureIndex = new HashMap<double[], Integer>(features.size()); for (int i = 0; i < clusters.length; i++) { Cluster cluster = clusters[i]; for (Iterator<Integer> fidit = cluster.members.iterator(); fidit.hasNext(); ) { int fid = fidit.next(); featureIndex.put(features.get(fid), i); } } return featureIndex; } /** * Used to find the cluster of a feature actually used in the clustering process (so * it is known by the k-means class). 
* * @param f the feature to search for * @return the index of the Cluster */ public int getClusterOfFeature(Histogram f) { if (featureIndex == null) createIndex(); return featureIndex.get(f); } } class Image { public List<double[]> features; public String identifier; public float[] localFeatureHistogram = null; private final int QUANT_MAX_HISTOGRAM = 256; Image(String identifier, List<double[]> features) { this.features = new LinkedList<double[]>(); this.features.addAll(features); this.identifier = identifier; } public float[] getLocalFeatureHistogram() { return localFeatureHistogram; } public void setLocalFeatureHistogram(float[] localFeatureHistogram) { this.localFeatureHistogram = localFeatureHistogram; } public void initHistogram(int bins) { localFeatureHistogram = new float[bins]; for (int i = 0; i < localFeatureHistogram.length; i++) { localFeatureHistogram[i] = 0; } } public void normalizeFeatureHistogram() { // 对聚类中心进行归一化 float max = 0; for (int i = 0; i < localFeatureHistogram.length; i++) { max = Math.max(localFeatureHistogram[i], max); } for (int i = 0; i < localFeatureHistogram.length; i++) { localFeatureHistogram[i] = (localFeatureHistogram[i] * QUANT_MAX_HISTOGRAM) / max; } } public void printHistogram() { for (int i = 0; i < localFeatureHistogram.length; i++) { System.out.print(localFeatureHistogram[i] + " "); } System.out.println(""); } }
Cluster.java(包含注释)
package lmc.imageretrieval.imageanalysis.bovw; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import lmc.imageretrieval.imageanalysis.Histogram; import lmc.imageretrieval.utils.MetricsUtils; import lmc.imageretrieval.utils.SerializationUtils; public class Cluster implements Comparable<Object> { double[] mean; HashSet<Integer> members = new HashSet<Integer>(); private double stress = 0; public Cluster() { this.mean = new double[4 * 4 * 8]; Arrays.fill(mean, 0f); } public Cluster(double[] mean) { this.mean = mean; } public String toString() { StringBuilder sb = new StringBuilder(512); for (Integer integer : members) { sb.append(integer); sb.append(", "); } for (int i = 0; i < mean.length; i++) { sb.append(mean[i]); sb.append(';'); } return sb.toString(); } public int compareTo(Object o) { return ((Cluster) o).members.size() - members.size(); } public double getDistance(Histogram f) { return getDistance(f.getDoubleHistogram()); } public double getDistance(double[] f) { // L1 // return MetricsUtils.distL1(mean, f); // L2 return MetricsUtils.distL2(mean, f); } /** * Creates a byte array representation from the clusters mean. * * @return the clusters mean as byte array. */ public byte[] getByteRepresentation() { return SerializationUtils.toByteArray(mean); } public void setByteRepresentation(byte[] data) { mean = SerializationUtils.toDoubleArray(data); } public static void writeClusters(Cluster[] clusters, String file) throws IOException { // 将聚类中心写入磁盘上 FileOutputStream fout = new FileOutputStream(file); fout.write(SerializationUtils.toBytes(clusters.length)); // 聚类中心个数 fout.write(SerializationUtils.toBytes((clusters[0].getMean()).length)); // 聚类中心点的长度 for (int i = 0; i < clusters.length; i++) { fout.write(clusters[i].getByteRepresentation()); // 写入每个聚类中心 } fout.close(); } // TODO: re-visit here to make the length variable (depending on the actual feature size). 
public static Cluster[] readClusters(String file) throws IOException { // 从磁盘上读取聚类中心 FileInputStream fin = new FileInputStream(file); byte[] tmp = new byte[4]; fin.read(tmp, 0, 4); Cluster[] result = new Cluster[SerializationUtils.toInt(tmp)]; fin.read(tmp, 0, 4); int size = SerializationUtils.toInt(tmp); tmp = new byte[size * 8]; for (int i = 0; i < result.length; i++) { int bytesRead = fin.read(tmp, 0, size * 8); if (bytesRead != size * 8) System.err.println("Didn't read enough bytes ..."); result[i] = new Cluster(); result[i].setByteRepresentation(tmp); } fin.close(); return result; } public double getStress() { return stress; } public void setStress(double stress) { this.stress = stress; } public HashSet<Integer> getMembers() { return members; } public void setMembers(HashSet<Integer> members) { this.members = members; } /** * Returns the cluster mean * * @return the cluster mean vector */ public double[] getMean() { return mean; } }