深度学习-文档分类

 
  
本文使用 ParagraphVectors 方法做文档分类:先用一批带类别标签的文档训练模型,再预测没有类别的文档属于哪个类别。这里简单介绍一下 ParagraphVectors 模型:每篇文档被映射为一个唯一的向量,由文档矩阵中的一列表示;每个 word 也类似地被映射为一个向量,由另一个(词)矩阵中的一列表示。预测下一个 word 时,把文档向量与上下文的词向量拼接(连接)起来作为输入。可以说,ParagraphVectors 是在 word2vec 的基础上增加了一组 paragraph 输入列向量、并与词向量一起训练而构成的模型。
public class ParagraphVectorsClassifierExample { ParagraphVectors paragraphVectors ;//声明 ParagraphVectors类
LabelAwareIterator iterator ;//声明要实现的迭代器接口,用来识别句子或文档及标签,这里假定所有的文档已变成字符串或词表的形式 TokenizerFactory tokenizerFactory ;//声明字符串分割器 private static final Logger log = LoggerFactory. getLogger (ParagraphVectorsClassifierExample. class ) ; public static void main (String[] args) throws Exception { ParagraphVectorsClassifierExample app = new ParagraphVectorsClassifierExample() ;//又是这种写法,构建实现类 app.makeParagraphVectors() ;//调用构建模型方法 app.checkUnlabeledData() ;//检查标签数据 /* Your output should be like this: Document 'health' falls into the following categories: health: 0.29721372296220205 science: 0.011684473733853906 finance: -0.14755302887323793 Document 'finance' falls into the following categories: health: -0.17290237675941766 science: -0.09579267574606627 finance: 0.4460859189453788 so,now we know categories for yet unseen documents */ } void makeParagraphVectors () throws Exception { ClassPathResource resource = new ClassPathResource( "paravec/labeled" ) ;//弄一个带标签的文档路径 // build a iterator for our dataset iterator = new FileLabelAwareIterator.Builder()//实现 LabelAwareIterator接口,添加数据源,构成迭代器 .addSourceFolder(resource.getFile()) .build() ; tokenizerFactory = new DefaultTokenizerFactory() ;//构建逗号分割器 tokenizerFactory .setTokenPreProcessor( new CommonPreprocessor()) ; // ParagraphVectors training configuration paragraphVectors = new ParagraphVectors.Builder()// ParagraphVectors继承 Word2Vec, Word2Vec继承 SequenceVectors,
配置ParagraphVectors的学习率,最小学习率,批大小,步数,迭代器,同时构建词和文档,词分割器
.learningRate( 0.025 ) .minLearningRate( 0.001 ) .batchSize( 1000 ) .epochs( 20 ) .iterate( iterator ) .trainWordVectors( true ) .tokenizerFactory( tokenizerFactory ) .build() ; // Start model training paragraphVectors .fit() ;//模型定型 } void checkUnlabeledData () throws FileNotFoundException { /* At this point we assume that we have model built and we can check//这里假定模型已经构建好,现在预测无标签的文档属于哪个类,我们装载无标签文档并对其进行检测 which categories our unlabeled document falls into. So we'll start loading our unlabeled documents and checking them */ ClassPathResource unClassifiedResource = new ClassPathResource( "paravec/unlabeled" ) ;//构建无标签文档读取器 FileLabelAwareIterator unClassifiedIterator = new FileLabelAwareIterator.Builder() .addSourceFolder(unClassifiedResource.getFile()) .build() ; /* Now we'll iterate over unlabeled data, and check which label it could be assigned to//预测未标记文档,很多情况一个文档可能对应多个类别,只不过每个类别值有高有低 Please note: for many domains it's normal to have 1 document fall into few labels at once, with different "weight" for each. */ MeansBuilder meansBuilder = new MeansBuilder(//构建了求质心的类, (InMemoryLookupTable) paragraphVectors .getLookupTable() ,//通过获取 WordVectors实现类 WordVectorsImpl中的 getLookupTable方法获取查询table及 tokenizerFactory构造 MeansBuilder类
tokenizerFactory ) ; LabelSeeker seeker = new LabelSeeker( iterator .getLabelsSource().getLabels() ,//同理通过获取 WordVectors实现类 WordVectorsImpl中的 getLookupTable方法获取查询table及标签列表构造LabelSeeker类 (InMemoryLookupTable) paragraphVectors .getLookupTable()) ; while (unClassifiedIterator.hasNextDocument()) {//遍历未分类文档 LabelledDocument document = unClassifiedIterator.nextDocument() ; INDArray documentAsCentroid = meansBuilder.documentAsVector(document) ;//把文档转成向量 List , Double>> scores = seeker.getScores(documentAsCentroid) ;//获取文档的类别得分 /* please note, document.getLabel() is used just to show which document we're looking at now, as a substitute for printing out the whole document name. So, labels on these two documents are used like titles, just to visualize our classification done properly//注意getLabel是获取当前文档的标签 */ log .info( "Document '" + document.getLabel() + "' falls into the following categories: " ) ; for (Pair , Double> score: scores) {//遍历标签得分 log .info( " " + score.getFirst() + ": " + score.getSecond()) ;//打印元素的第一个第二个元素 } } }
/**
 * Turns a document into a single vector by averaging the trained vectors of its words.
 */
public class MeansBuilder {
    // NOTE(review): VocabCache and InMemoryLookupTable are used as raw types because the
    // type arguments were lost in this listing; confirm the intended parameterisation.

    /** Vocabulary of the lookup table; used to skip tokens that are not in the vocabulary. */
    private final VocabCache vocabCache;
    /** Lookup table that supplies the vector for each known word. */
    private final InMemoryLookupTable lookupTable;
    /** Splits document content into tokens. */
    private final TokenizerFactory tokenizerFactory;

    /**
     * Creates a builder backed by the given lookup table and tokenizer.
     *
     * @param lookupTable      table holding the trained vectors; its vocabulary is read once here
     * @param tokenizerFactory factory used to tokenize document content
     */
    public MeansBuilder(@NonNull InMemoryLookupTable lookupTable,
        @NonNull TokenizerFactory tokenizerFactory) {
        this.lookupTable = lookupTable;
        this.vocabCache = lookupTable.getVocab();
        this.tokenizerFactory = tokenizerFactory;
    }

    /**
     * This method returns centroid (mean vector) for document.
     *
     * <p>Tokens that are not in the vocabulary are ignored. Every occurrence of a known
     * word contributes one row, so repeated words weigh proportionally more in the mean.
     *
     * @param document document whose content is tokenized and averaged
     * @return the mean, taken along dimension 0, of the vectors of all known tokens
     */
    public INDArray documentAsVector(@NonNull LabelledDocument document) {
        List<String> documentAsTokens = tokenizerFactory.create(document.getContent()).getTokens();

        // First pass: count in-vocabulary tokens so the matrix can be allocated in one go.
        int knownTokens = 0;
        for (String word : documentAsTokens) {
            if (vocabCache.containsWord(word)) {
                knownTokens++;
            }
        }

        // One row per known token, one column per vector component.
        // NOTE(review): if no token is in the vocabulary this requests a zero-row matrix;
        // how Nd4j.create / mean behave in that case is not visible here — confirm.
        INDArray allWords = Nd4j.create(knownTokens, lookupTable.layerSize());

        // Second pass: copy each known token's vector into the next free row.
        int row = 0;
        for (String word : documentAsTokens) {
            if (vocabCache.containsWord(word)) {
                allWords.putRow(row++, lookupTable.vector(word));
            }
        }

        // Collapse the rows into a single mean vector (the centroid).
        return allWords.mean(0);
    }
}
 
  
 
  
public class LabelSeeker {//寻找标签类
    private List labelsUsed;//声明标签列表
    private InMemoryLookupTable lookupTable;//声明查询table

    public LabelSeeker(@NonNull List labelsUsed, @NonNull InMemoryLookupTable lookupTable) {//构造器
        if (labelsUsed.isEmpty()) throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors");
        this.lookupTable = lookupTable;
        this.labelsUsed = labelsUsed;
    }

    /**
     * This method accepts vector, that represents any document,//这方法接收表示文档的向量,返回文档的距离,之前训练的类别
     * and returns distances between this document, and previously trained categories
     * @return
     */
    public List, Double>> getScores(@NonNull INDArray vector) {//获取得分的方法
        List, Double>> result = new ArrayList<>();//声明列表,每个元素都是元素
        for (String label: labelsUsed) {遍历标签列表
            INDArray vecLabel = lookupTable.vector(label);//同理根据词表索引,取出对应词权重向量的行

            if (vecLabel == null) throw new IllegalStateException("Label '"+ label+"' has no known vector!");

            double sim = Transforms.cosineSim(vector, vecLabel);//把词权重向量和传入的文档做相似度
            result.add(new Pair, Double>(label, sim));//返回和每个标签的相似度
        }
        return result;
    }
}

}

你可能感兴趣的:(深度学习,deeplearning4j)