LSTM for Text Classification on GPU — a DeepLearning4J Example

This post trains a GravesLSTM binary text classifier with DeepLearning4J, running on GPU. First, check the GPU with nvidia-smi:

    root@image-ubuntu:/home/zhoumeixu/credit-textclassify-deeplearning# nvidia-smi
    Fri Jul 14 08:00:24 2017
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 375.51                 Driver Version: 375.51                    |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |===============================+======================+======================|
    |   0  Tesla M60           Off  | 0000:00:02.0     Off |                  Off |
    | N/A   44C    P0    39W / 150W |   2452MiB /  8123MiB |     17%      Default |
    +-------------------------------+----------------------+----------------------+

    +-----------------------------------------------------------------------------+
    | Processes:                                                       GPU Memory |
    |  GPU       PID  Type  Process name                               Usage      |
    |=============================================================================|
    |    0     76510    C   java                                         2444MiB  |
    +-----------------------------------------------------------------------------+

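A single Tesla M60 is visible, and the Java training process is already holding about 2.4 GB of device memory. Note that ND4J only executes on the GPU when a CUDA backend is on the classpath; for DL4J releases of this era that means depending on an artifact such as nd4j-cuda-8.0-platform instead of nd4j-native-platform (the exact artifact name and version is an assumption — check the matching DL4J release notes).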

Data format:

    1	预定 h 其他数字 免费
    1	属于 派对 不是 特别 商务 就是 位置 是个 非常 麻烦 基本 一天 都是 人满为患 是个 放松 好地方 过来 可以 找我 预定 联系方式 名字
    1	长宁 v show 很多 亮点 可能 清楚 1 不收 一分 计时 房费 2 生日 免费 布置 生日蛋糕 7 22 活动内容 3 专业 接待 经理 全程 接待 4 周一 会员卡 折优惠 5 房间 各种 主题 可供 选择 6 周五 12 以后 消减 一半 7 5 星级 厨师 推出 99 小吃 8 商务区 房间 享受 管家 服务 9 评出 场所 5 优点 鸡尾酒 10 大厅 走道 设有 拍照 区域
    1	企鹅 0 死死 一六 0六三

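Each line holds one example: the label, a tab, then the text already segmented into space-separated tokens ("1" is positive, "0" is negative). Because the Chinese text is pre-tokenized, the whitespace-based DefaultTokenizerFactory used by the iterator below is sufficient; tokens missing from the word2vec vocabulary are simply dropped.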

DataSetIterator implementation:

    package com.dianping.deeplearning.rnn;

    import java.io.BufferedReader;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.NoSuchElementException;
    import java.util.Random;

    import org.apache.commons.lang.StringUtils;
    import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
    import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.INDArrayIndex;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class SentimentIterator implements DataSetIterator {

        private final WordVectors wordVectors;
        private final int batchSize;
        private final int truncateLength;
        private final int vectorSize;
        private final TokenizerFactory tokenizerFactory;
        private final Random rng = new Random();

        private int cursor = 0;
        private List<String> positiveList = new ArrayList<>();
        private List<String> negativeList = new ArrayList<>();

        public SentimentIterator(String path, WordVectors wordVectors,
                int batchSize, int truncateLength) {
            // Each input line is "label<TAB>text"; "1" is positive, "0" is negative.
            try (BufferedReader bufferedReader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(path)))) {
                String line;
                while ((line = bufferedReader.readLine()) != null) {
                    String[] parts = line.split("\t");
                    if (parts.length > 1) {
                        String label = parts[0];
                        String content = parts[1];
                        if (StringUtils.isNotBlank(content)) {
                            if ("1".equalsIgnoreCase(label)) {
                                positiveList.add(content);
                            } else if ("0".equalsIgnoreCase(label)) {
                                negativeList.add(content);
                            }
                        }
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
            this.batchSize = batchSize;
            this.wordVectors = wordVectors;
            this.vectorSize = wordVectors.getWordVector(
                    wordVectors.vocab().wordAtIndex(0)).length;
            this.truncateLength = truncateLength;
            tokenizerFactory = new DefaultTokenizerFactory();
            tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
        }

        private DataSet nextDataSet(int num) throws IOException {
            // Interleave positive and negative examples so each minibatch is balanced.
            List<String> reviews = new ArrayList<>(num);
            boolean[] positive = new boolean[num];
            for (int i = 0; i < num && cursor < totalExamples(); i++) {
                if (cursor % 2 == 0) {
                    int posReviewNumber = cursor / 2;
                    String review;
                    if (posReviewNumber < positiveList.size()) {
                        review = positiveList.get(posReviewNumber);
                    } else {
                        // This class is exhausted: resample a random example instead.
                        review = positiveList.get(rng.nextInt(positiveList.size()));
                    }
                    positive[i] = true;
                    reviews.add(review);
                } else {
                    int negReviewNumber = cursor / 2;
                    String review;
                    if (negReviewNumber < negativeList.size()) {
                        review = negativeList.get(negReviewNumber);
                    } else {
                        review = negativeList.get(rng.nextInt(negativeList.size()));
                    }
                    positive[i] = false;
                    reviews.add(review);
                }
                cursor++;
            }

            // Tokenize, keeping only words present in the word2vec vocabulary.
            List<List<String>> allTokens = new ArrayList<>(reviews.size());
            int maxLength = 0;
            for (String s : reviews) {
                List<String> tokens = tokenizerFactory.create(s).getTokens();
                List<String> tokensFiltered = new ArrayList<>();
                for (String t : tokens) {
                    if (wordVectors.hasWord(t))
                        tokensFiltered.add(t);
                }
                allTokens.add(tokensFiltered);
                maxLength = Math.max(maxLength, tokensFiltered.size());
            }
            if (maxLength > truncateLength)
                maxLength = truncateLength;

            // Features: [miniBatchSize, vectorSize, maxLength]; labels: two classes
            // (positive or negative).
            INDArray features = Nd4j.create(reviews.size(), vectorSize, maxLength);
            INDArray labels = Nd4j.create(reviews.size(), 2, maxLength);
            // Sequences have different lengths and there is only one output at the
            // final time step, so use mask arrays: an entry is 1 if data is present
            // at that time step for that example, 0 if it is just padding.
            INDArray featuresMask = Nd4j.zeros(reviews.size(), maxLength);
            INDArray labelsMask = Nd4j.zeros(reviews.size(), maxLength);

            int[] temp = new int[2];
            for (int i = 0; i < reviews.size(); i++) {
                List<String> tokens = allTokens.get(i);
                temp[0] = i;
                // Put the word vector for each token into the feature array.
                for (int j = 0; j < tokens.size() && j < maxLength; j++) {
                    String token = tokens.get(j);
                    INDArray vector = wordVectors.getWordVectorMatrix(token);
                    features.put(new INDArrayIndex[] { NDArrayIndex.point(i),
                            NDArrayIndex.all(), NDArrayIndex.point(j) }, vector);
                    temp[1] = j;
                    featuresMask.putScalar(temp, 1.0); // word present at this step
                }
                // One-hot label at the last time step: [1,0] = positive, [0,1] = negative.
                int idx = positive[i] ? 0 : 1;
                int lastIdx = Math.min(tokens.size(), maxLength);
                labels.putScalar(new int[] { i, idx, lastIdx - 1 }, 1.0);
                // Mark that an output exists at the final time step for this example.
                labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0);
            }
            return new DataSet(features, labels, featuresMask, labelsMask);
        }

        @Override
        public DataSet next(int num) {
            if (cursor >= totalExamples())
                throw new NoSuchElementException();
            try {
                return nextDataSet(num);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public boolean hasNext() {
            return cursor < numExamples();
        }

        @Override
        public DataSet next() {
            return next(batchSize);
        }

        @Override
        public void remove() {
            // not supported; no-op
        }

        @Override
        public int totalExamples() {
            return positiveList.size() + negativeList.size();
        }

        @Override
        public int inputColumns() {
            return vectorSize;
        }

        @Override
        public int totalOutcomes() {
            return 2;
        }

        @Override
        public boolean resetSupported() {
            return true;
        }

        @Override
        public boolean asyncSupported() {
            return true;
        }

        @Override
        public void reset() {
            cursor = 0;
            Collections.shuffle(positiveList);
            Collections.shuffle(negativeList);
        }

        @Override
        public int batch() {
            return batchSize;
        }

        @Override
        public int cursor() {
            return cursor;
        }

        @Override
        public int numExamples() {
            return totalExamples();
        }

        @Override
        public void setPreProcessor(DataSetPreProcessor preProcessor) {
            // no pre-processor used
        }

        @Override
        public DataSetPreProcessor getPreProcessor() {
            throw new UnsupportedOperationException("Not implemented");
        }

        @Override
        public List<String> getLabels() {
            return Arrays.asList("1", "0");
        }
    }

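Two design points worth noting: nextDataSet() interleaves positive and negative examples so that every minibatch is roughly class-balanced, falling back to random resampling once the smaller class runs out, and reset() reshuffles both lists between epochs. Since only the final time step of each sequence carries a label, the labels mask marks exactly one output position per example; shorter sequences in a minibatch are zero-padded and masked out in the features mask.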

Main entry point:

    package com.dianping.deeplearning.rnn;

    import org.deeplearning4j.eval.Evaluation;
    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.GradientNormalization;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.layers.GravesLSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.deeplearning4j.parallelism.ParallelWrapper;
    import org.nd4j.jita.conf.CudaEnvironment;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataBuffer;
    import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class TrainAdxRnnModelWithGPU {

        private static String basepath = "/home/zhoumeixu/model/";
        // private static String basepath = "adx/";

        public static void main(String[] args) {
            int batchSize = 48;        // number of examples in each minibatch
            int nEpochs = 35;          // number of training epochs
            int truncateLength = 256;  // maximum sequence length

            WordVectors wordVectors = WordVectorSerializer
                    .readWord2VecModel(basepath + "word2vec.model");
            DataSetIterator train = getDataSetIterator(basepath + "rnnsenec.txt",
                    wordVectors, batchSize, truncateLength);
            DataSetIterator test = getDataSetIterator(basepath + "rnnsenectest.txt",
                    wordVectors, batchSize, truncateLength);

            System.out.println("....... GPU initialization starting .......");
            // Use single precision on the GPU and allow multi-GPU data parallelism.
            DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
            CudaEnvironment.getInstance().getConfiguration()
                    .allowMultiGPU(true)
                    .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
                    .allowCrossDeviceAccess(true);
            System.out.println("....... GPU initialization finished .......");

            int outputs = train.getLabels().size();
            int inputNeurons = wordVectors.getWordVector(
                    wordVectors.vocab().wordAtIndex(0)).length; // word vector dimension

            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1)
                    .updater(Updater.RMSPROP)
                    .regularization(true)
                    .l2(1e-5)
                    .weightInit(WeightInit.XAVIER)
                    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                    .gradientNormalizationThreshold(1.0)
                    .learningRate(0.0018)
                    .list()
                    .layer(0, new GravesLSTM.Builder().nIn(inputNeurons).nOut(200)
                            .activation(Activation.SOFTSIGN).build())
                    .layer(1, new RnnOutputLayer.Builder()
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .nIn(200).nOut(outputs).build())
                    .pretrain(false).backprop(true).build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();

            /*
            ParallelWrapper pw = new ParallelWrapper.Builder<>(net)
                    .prefetchBuffer(16 * Nd4j.getAffinityManager().getNumberOfDevices())
                    .reportScoreAfterAveraging(true)
                    .averagingFrequency(10)
                    .useLegacyAveraging(false)
                    .useMQ(true)
                    .workers(Nd4j.getAffinityManager().getNumberOfDevices())
                    .build();
            */
            // Data-parallel training: each worker holds a copy of the model and
            // parameters are averaged every 3 minibatches.
            ParallelWrapper pw = new ParallelWrapper.Builder(net)
                    .prefetchBuffer(24)
                    .workers(4)
                    .averagingFrequency(3)
                    .reportScoreAfterAveraging(true)
                    .useLegacyAveraging(true)
                    .build();

            // Print the score every 200 iterations.
            net.setListeners(new ScoreIterationListener(200));

            System.out.println("Starting training");
            for (int i = 0; i < nEpochs; i++) {
                pw.fit(train);
                train.reset();
                System.out.println("Epoch " + i + " complete. Starting evaluation:");
                Evaluation evaluation = new Evaluation();
                while (test.hasNext()) {
                    DataSet t = test.next();
                    INDArray features = t.getFeatureMatrix();
                    INDArray labels = t.getLabels();
                    INDArray inMask = t.getFeaturesMaskArray();
                    INDArray outMask = t.getLabelsMaskArray();
                    // Pass both masks so padded time steps are ignored.
                    INDArray predicted = net.output(features, false, inMask, outMask);
                    evaluation.evalTimeSeries(labels, predicted, outMask);
                }
                test.reset();
                System.out.println(evaluation.stats());
            }
        }

        public static DataSetIterator getDataSetIterator(String path,
                WordVectors wordVectors, int batchSize, int truncateLength) {
            return new SentimentIterator(path, wordVectors, batchSize, truncateLength);
        }
    }

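The example above never persists the trained network. A minimal sketch of saving and reloading it with DL4J's ModelSerializer, which could be appended after the epoch loop (the file name rnn.model is an assumption):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.util.ModelSerializer;

    // Save the trained network; "true" also stores the updater state,
    // so training can later be resumed from this checkpoint.
    ModelSerializer.writeModel(net, basepath + "rnn.model", true);

    // Reload for inference in another process:
    MultiLayerNetwork restored =
            ModelSerializer.restoreMultiLayerNetwork(basepath + "rnn.model");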


Training output:

    07:58:12.231 [ParallelWrapper trainer 3] DEBUG o.n.j.c.CudaAffinityManager - Mapping thread [1869] to device [0], out of [1] devices...
    07:58:12.232 [ParallelWrapper trainer 0] DEBUG o.n.j.c.CudaAffinityManager - Mapping thread [1866] to device [0], out of [1] devices...
    07:58:12.910 [ParallelWrapper trainer 0] INFO org.nd4j.nativeblas.Nd4jBlas - Number of threads used for BLAS: 0
    07:58:16.646 [ParallelWrapper trainer 0] INFO o.d.o.l.ScoreIterationListener - Score at iteration 0 is 0.6893714558848892
    07:58:27.803 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6904921923723948
    07:58:37.580 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6918042188376753
    07:58:47.585 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6902696776475821
    07:58:58.407 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6904408029170671
    07:59:09.717 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6896152988017
    07:59:22.807 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6888589759877908
    07:59:40.347 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6892489896929257
    08:00:00.643 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6895330264029644
    08:00:21.219 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6878439507341991
    08:00:45.554 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6887204976666248
    08:01:07.273 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6873620491523212
    08:01:29.412 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6879257711235788
    08:01:50.624 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6893160108520883
    08:02:16.107 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: 0.6871896758883935

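Note that the averaged scores in this first stretch hover around 0.69 ≈ ln 2, which is exactly the cross-entropy a two-class softmax classifier incurs when it predicts 50/50. In other words, the network has not yet begun to separate the classes at this point; the per-epoch Evaluation stats over the full 35 epochs are what show whether it actually learns.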
