Doc2Vec for Text Classification: A DeepLearning4J Implementation

1. Overview

The previous post gave a brief introduction to doc2vec and its implementation. This post looks at how doc2vec performs when used for text classification.

2. Data Format

The input format is the same as that used for the CNN and LSTM models: each line holds a label ("1" or "0"), a tab, and the text, which has already been segmented into space-separated words.

    1 看头像 加微信
    1 专业 办理 二手房交易 公积金贷款 商业贷款 租房 需要 请来 店咨询
    1 奥森 健身 砸金蛋 捂脸 砸金蛋 不是 李咏 专利 活动 当天 人人 都是 李咏 奥森 健身 小季 特邀 6 18 咱们 砸金蛋 百分百 中奖 百分百 有礼 只有 想不到 没有 我们 不到 优惠力度 我们 老总 知道 害怕 偷偷 告诉 捂脸 咨询电话 手机号码
    1 建议 下次 麻花 另外 袋子 起来 盒子 太软 长肉 美女 座机号码
    0 两张 情侣卡 还有 10 3 k
    0 专业 甲醛检测 甲醛 治理 清除 装修 异味 健康 呼吸 保驾护航
    1 不错 儿子 喜欢吃 宝妈 照顾好 家的 同时 月入 上万 加我 其他数字
    1 可以 微信同号

3. Code Implementation

The main work is implementing the LabelAwareIterator interface. The implementation is shown below:
    package com.dianping.deeplearning.paragraphvectors;

    import java.io.BufferedReader;
    import java.io.FileInputStream;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Random;

    import org.datavec.api.util.RandomUtils;
    import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
    import org.deeplearning4j.text.documentiterator.LabelledDocument;
    import org.deeplearning4j.text.documentiterator.LabelsSource;
    import org.nd4j.linalg.collection.CompactHeapStringList;

    public class TxtLabelAwareIterator implements LabelAwareIterator {

        private int totalCount;
        private Map<String, List<String>> filesByLabel;
        private List<String> normList;
        private List<String> negList;
        private final List<String> sentenceList;
        private final int[] labelIndexes;
        private final Random rng;
        private final int[] order;
        private final List<String> allLabels;
        private LabelsSource source;
        private int cursor = 0;

        public TxtLabelAwareIterator(String path) {
            this(path, new Random());
        }

        public TxtLabelAwareIterator(String path, Random rng) {
            totalCount = 0;
            filesByLabel = new HashMap<String, List<String>>();
            normList = new ArrayList<String>();
            negList = new ArrayList<>();
            BufferedReader buffered = null;
            try {
                // Each line: label, a tab, then the pre-segmented text; read as UTF-8 since the corpus is Chinese
                buffered = new BufferedReader(new InputStreamReader(
                        new FileInputStream(path), "UTF-8"));
                String line = buffered.readLine();
                while (line != null) {
                    String[] lines = line.split("\t");
                    String label = lines[0];
                    String content = lines[1];
                    if ("1".equalsIgnoreCase(label)) {
                        normList.add(content);
                    } else if ("0".equalsIgnoreCase(label)) {
                        negList.add(content);
                    }
                    totalCount++;
                    line = buffered.readLine();
                }
                buffered.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
            filesByLabel.put("1", normList);
            filesByLabel.put("0", negList);

            this.rng = rng;
            if (rng == null) {
                order = null;
            } else {
                // Shuffled index array so documents are returned in random order
                order = new int[totalCount];
                for (int i = 0; i < totalCount; i++) {
                    order[i] = i;
                }
                RandomUtils.shuffleInPlace(order, rng);
            }

            allLabels = new ArrayList<>(filesByLabel.keySet());
            Collections.sort(allLabels);
            source = new LabelsSource(allLabels);

            Map<String, Integer> labelsToIdx = new HashMap<>();
            for (int i = 0; i < allLabels.size(); i++) {
                labelsToIdx.put(allLabels.get(i), i);
            }

            // Flatten all sentences into one list and remember each sentence's label index
            sentenceList = new CompactHeapStringList();
            labelIndexes = new int[totalCount];
            int position = 0;
            for (Map.Entry<String, List<String>> entry : filesByLabel.entrySet()) {
                int labelIdx = labelsToIdx.get(entry.getKey());
                for (String f : entry.getValue()) {
                    sentenceList.add(f);
                    labelIndexes[position] = labelIdx;
                    position++;
                }
            }
        }

        @Override
        public boolean hasNext() {
            return cursor < totalCount;
        }

        @Override
        public LabelledDocument next() {
            return nextDocument();
        }

        @Override
        public boolean hasNextDocument() {
            // The original version called itself recursively; delegate to hasNext() instead
            return hasNext();
        }

        @Override
        public LabelledDocument nextDocument() {
            LabelledDocument document = new LabelledDocument();
            int idx;
            if (rng == null) {
                idx = cursor++;
            } else {
                idx = order[cursor++];
            }
            String label = allLabels.get(labelIndexes[idx]);
            String sentence = sentenceList.get(idx);
            document.setContent(sentence);
            document.addLabel(label);
            return document;
        }

        @Override
        public void reset() {
            cursor = 0;
            if (rng != null) {
                RandomUtils.shuffleInPlace(order, rng);
            }
        }

        @Override
        public LabelsSource getLabelsSource() {
            return source;
        }

        @Override
        public void shutdown() {
        }
    }
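
A minimal usage sketch of this iterator, assuming the usual getLabels()/getContent() accessors on LabelledDocument (the file path matches the one used by the classifier below):

    // Read the labelled corpus and print the first few documents
    LabelAwareIterator it = new TxtLabelAwareIterator("adx/rnnsenec.txt");
    int shown = 0;
    while (it.hasNextDocument() && shown < 3) {
        LabelledDocument doc = it.nextDocument();
        System.out.println(doc.getLabels() + " -> " + doc.getContent());
        shown++;
    }
    it.reset(); // rewind (and reshuffle) before handing the iterator to ParagraphVectors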

Classifier implementation:
    package com.dianping.deeplearning.paragraphvectors;

    import java.util.List;

    import org.deeplearning4j.berkeley.Pair;
    import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
    import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
    import org.deeplearning4j.models.word2vec.VocabWord;
    import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
    import org.deeplearning4j.text.documentiterator.LabelledDocument;
    import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class Doc2VecAdxClassify {

        private String path = "adx/rnnsenec.txt";

        ParagraphVectors paragraphVectors;
        LabelAwareIterator iterator;
        TokenizerFactory tokenizerFactory;

        public static void main(String[] args) {
            Doc2VecAdxClassify doc2vec = new Doc2VecAdxClassify();
            doc2vec.makeParagraphVectors();

            // Predict the label directly with the built-in predict()
            System.out.println(doc2vec.paragraphVectors
                    .predict("专业 甲醛检测 甲醛 治理 清除 装修 异味 给 您 健康 呼吸"));

            // Alternatively, score the document against every label vector.
            // MeansBuilder and LabelSeeker are the helper classes from the DL4J ParagraphVectors example.
            MeansBuilder meansBuilder = new MeansBuilder(
                    (InMemoryLookupTable<VocabWord>) doc2vec.paragraphVectors.getLookupTable(),
                    doc2vec.tokenizerFactory);
            LabelSeeker seeker = new LabelSeeker(
                    doc2vec.iterator.getLabelsSource().getLabels(),
                    (InMemoryLookupTable<VocabWord>) doc2vec.paragraphVectors.getLookupTable());

            LabelledDocument document = new LabelledDocument();
            document.setContent("专业 甲醛检测 甲醛 治理 清除 装修 异味 给 您 健康 呼吸");
            document.addLabel("0");

            INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
            List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);
            for (Pair<String, Double> score : scores) {
                System.out.println(" " + score.getFirst() + ": " + score.getSecond());
            }
        }

        public void makeParagraphVectors() {
            System.out.println("path is :" + path);
            iterator = new TxtLabelAwareIterator(path);
            System.out.println(iterator.getLabelsSource().getLabels());

            tokenizerFactory = new DefaultTokenizerFactory();
            tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

            paragraphVectors = new ParagraphVectors.Builder()
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .batchSize(1000)
                    .epochs(20)
                    .iterate(iterator)
                    .trainWordVectors(true)
                    .tokenizerFactory(tokenizerFactory)
                    .build();

            // Start model training
            paragraphVectors.fit();
        }
    }
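
MeansBuilder and LabelSeeker are not defined in this post; they are the small helper classes used in the DL4J ParagraphVectors classification example. Roughly speaking, MeansBuilder turns a document into the mean of its word vectors, and LabelSeeker computes the cosine similarity between that centroid and the vector learned for each label. The sketch below is an approximation of those helpers written against the same API, not the exact example code:

    import java.util.ArrayList;
    import java.util.List;

    import org.deeplearning4j.berkeley.Pair;
    import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
    import org.deeplearning4j.models.word2vec.VocabWord;
    import org.deeplearning4j.text.documentiterator.LabelledDocument;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    // Approximate re-implementation of the helpers used in main() above.
    class MeansBuilder {
        private final InMemoryLookupTable<VocabWord> lookupTable;
        private final TokenizerFactory tokenizerFactory;

        MeansBuilder(InMemoryLookupTable<VocabWord> lookupTable,
                     TokenizerFactory tokenizerFactory) {
            this.lookupTable = lookupTable;
            this.tokenizerFactory = tokenizerFactory;
        }

        // Mean of the word vectors of all in-vocabulary tokens in the document
        INDArray documentAsVector(LabelledDocument document) {
            List<String> tokens = tokenizerFactory.create(document.getContent()).getTokens();
            INDArray sum = Nd4j.zeros(lookupTable.layerSize());
            int count = 0;
            for (String token : tokens) {
                INDArray vector = lookupTable.vector(token);
                if (vector != null) { // skip out-of-vocabulary tokens
                    sum.addi(vector);
                    count++;
                }
            }
            return count > 0 ? sum.divi(count) : sum;
        }
    }

    class LabelSeeker {
        private final List<String> labels;
        private final InMemoryLookupTable<VocabWord> lookupTable;

        LabelSeeker(List<String> labels, InMemoryLookupTable<VocabWord> lookupTable) {
            this.labels = labels;
            this.lookupTable = lookupTable;
        }

        // Cosine similarity between the document centroid and each label's vector
        List<Pair<String, Double>> getScores(INDArray documentVector) {
            List<Pair<String, Double>> result = new ArrayList<>();
            for (String label : labels) {
                INDArray labelVector = lookupTable.vector(label);
                result.add(new Pair<>(label, Transforms.cosineSim(documentVector, labelVector)));
            }
            return result;
        }
    }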

4. Results

    1
     0: -0.2978013753890991
     1: 0.17002613842487335

The first line is the label returned by predict(); the two lines below it are the cosine similarities between the input text and each label's vector, with the most similar label winning. The classification turns out to be reasonably accurate.
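
To turn the similarities into a hard decision, simply take the label with the highest score; a minimal sketch using the scores list from main() above:

    // Pick the label whose vector is most similar to the document centroid
    Pair<String, Double> best = scores.get(0);
    for (Pair<String, Double> score : scores) {
        if (score.getSecond() > best.getSecond()) {
            best = score;
        }
    }
    System.out.println("predicted label: " + best.getFirst());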
