前面介绍过用卷积网络训练文本分类模型,但是算法是在 cpu 上跑的,涉及到大数据时 cpu 是跑不动的,代码在之前的博客里面可以看到,本博客主要记录在 gpu 上跑时碰到的坑。
root@image-ubuntu:~# nvidia-smi
Fri Jul 14 01:21:46 2017
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.51 Driver Version: 375.51 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 Tesla M60 Off | 0000:00:02.0 Off | Off |
| N/A 52C P0 46W / 150W | 3448MiB / 8123MiB | 21% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 53395 C java 3438MiB |
+-----------------------------------------------------------------------------+
Exception in thread "main" java.lang.UnsupportedClassVersionError: org/deeplearning4j/parallelism/ParallelWrapper$Builder : Unsupported major.minor version 52.0
at java.lang.ClassLoader.defineClass1(Native Method)
at java.lang.ClassLoader.defineClass(ClassLoader.java:800)
at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
at java.net.URLClassLoader.defineClass(URLClassLoader.java:449)
at java.net.URLClassLoader.access$100(URLClassLoader.java:71)
at java.net.URLClassLoader$1.run(URLClassLoader.java:361)
at java.net.URLClassLoader$1.run(URLClassLoader.java:355)
at java.security.AccessController.doPrivileged(Native Method)
at java.net.URLClassLoader.findClass(URLClassLoader.java:354)
at java.lang.ClassLoader.loadClass(ClassLoader.java:425)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:308)
at java.lang.ClassLoader.loadClass(ClassLoader.java:358)
at com.dianping.deeplearning.test.TestWithGPU.main(TestWithGPU.java:115)
INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:22.643 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:28.993 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:35.097 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
训练过程中 score 一直是 NaN,应该与数据类型/精度设置相关,于是在初始化时加入了如下设置:
DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
Exception in thread "main" java.lang.RuntimeException: Exception thrown in base iterator
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:247)
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:33)
at org.deeplearning4j.parallelism.ParallelWrapper.fit(ParallelWrapper.java:379)
at com.dianping.deeplearning.cnn.TrainAdxCnnModelWithGPU.main(TrainAdxCnnModelWithGPU.java:170)
Caused by: org.nd4j.linalg.exception.ND4JIllegalStateException: Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4776)
at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3997)
at org.nd4j.linalg.api.ndarray.BaseNDArray.create(BaseNDArray.java:1906)
at org.nd4j.linalg.api.ndarray.BaseNDArray.subArray(BaseNDArray.java:2064)
at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4015)
at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:222)
at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:155)
at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:25)
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$IteratorRunnable.run(AsyncDataSetIterator.java:322)
最后附上训练gpu的代码:
package com.dianping.deeplearning.cnn;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class TrainAdxCnnModelWithGPU {

    /**
     * Entry point. Configures the CUDA backend, builds a sentence-classification
     * CNN with three parallel convolution branches (kernel widths 3/4/5), trains
     * it across the available GPUs via ParallelWrapper, evaluates after every
     * epoch, persists the model, reloads it and prints the label probabilities
     * for one sample sentence.
     */
    public static void main(String[] args) throws FileNotFoundException,
            UnsupportedEncodingException {
        /*
         * GPU training setup
         */
        System.out.println("。。。。。。。gpu初始化即将开始。。。。。。。。。");
        // PLEASE NOTE: for CUDA, FP16 precision support is available; FLOAT is used here
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
        // temp workaround for backend initialization
        CudaEnvironment.getInstance().getConfiguration()
                // key option enabled
                .allowMultiGPU(true)
                // we're allowing larger memory caches
                .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
                // cross-device access is used for faster model averaging over PCIe
                .allowCrossDeviceAccess(true);
        System.out.println("。。。。。。。。。gpu初始化即将结束。。。。。。。。。。");

        String WORD_VECTORS_PATH = "/home/zhoumeixu/model/word2vec.model";

        // Basic configuration
        int batchSize = 128;
        // Word-vector dimensionality; must match the word2vec model.
        // NOTE: the original comment claimed 100, but the actual value is 15.
        int vectorSize = 15;
        int nEpochs = 15000; // number of training epochs
        int iterator = 1; // optimizer iterations per minibatch
        int truncateReviewsToLength = 256; // sentences longer than this are truncated
        int cnnLayerFeatureMaps = 100; // feature maps (channels) per convolution branch
        PoolingType globalPoolingType = PoolingType.MAX;
        Random rng = new Random(100); // fixed seed for reproducible sampling

        // Network configuration: three convolution layers with filter widths
        // 3, 4 and 5, merged by depth concatenation, then globally pooled.
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU)
                .updater(Updater.ADAM)
                // ConvolutionMode.Same keeps branch output sizes equal so the
                // results can be 'stacked' (depth-concatenated) later
                .convolutionMode(ConvolutionMode.Same)
                .regularization(true)
                .l2(0.0001)
                .iterations(iterator)
                .learningRate(0.01)
                .graphBuilder()
                .addInputs("input")
                .addLayer(
                        "cnn3",
                        new ConvolutionLayer.Builder()
                                .kernelSize(3, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                .addLayer(
                        "cnn4",
                        new ConvolutionLayer.Builder()
                                .kernelSize(4, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                .addLayer(
                        "cnn5",
                        new ConvolutionLayer.Builder()
                                .kernelSize(5, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                // Perform depth concatenation of the three branches
                .addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5")
                .addLayer(
                        "globalPool",
                        new GlobalPoolingLayer.Builder().poolingType(
                                globalPoolingType).build(), "merge")
                .addLayer(
                        "out",
                        new OutputLayer.Builder()
                                .lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX)
                                .nIn(3 * cnnLayerFeatureMaps).nOut(2).build(),
                        "globalPool").setOutputs("out").build();

        ComputationGraph net = new ComputationGraph(config);
        net.init();

        // ParallelWrapper will take care of load balancing between GPUs.
        ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
                // DataSets prefetching options. Set this value with respect to number of actual devices
                .prefetchBuffer(24)
                // set number of workers equal or higher than number of available devices. x1-x2 are good values to start with
                .workers(4)
                // rare averaging improves performance, but might reduce model accuracy
                .averagingFrequency(3)
                // if set to TRUE, on every averaging model score will be reported
                .reportScoreAfterAveraging(true)
                // optional parameter, set to false ONLY if your system has support for P2P memory access across PCIe (hint: AWS does not support P2P)
                .useLegacyAveraging(true)
                .build();

        net.setListeners(new ScoreIterationListener(100));

        // Load the word vectors and build the train/test DataSetIterators
        System.out
                .println("Loading word vectors and creating DataSetIterators");
        WordVectors wordVectors = WordVectorSerializer
                .readWord2VecModel(WORD_VECTORS_PATH);
        DataSetIterator trainIter = getDataSetIterator(true, wordVectors,
                batchSize, truncateReviewsToLength, rng);
        DataSetIterator testIter = getDataSetIterator(false, wordVectors,
                batchSize, truncateReviewsToLength, rng);

        System.out.println("Starting training");
        for (int i = 0; i < nEpochs; i++) {
            wrapper.fit(trainIter);
            trainIter.reset();
            // Evaluate on the held-out set after each epoch
            Evaluation evaluation = net.evaluate(testIter);
            testIter.reset();
            System.out.println(evaluation.stats());
            System.out.println("。。。。。。。第" + i + "。。。。。。。。步已经完成。。。。。。。。。。");
        }

        /*
         * Save the model
         */
        saveNet("/home/zhoumeixu/model/cnn.model", net);

        /*
         * Reload the model
         */
        ComputationGraph netload = loadNet("/home/zhoumeixu/model/cnn.model");

        // After training: feed one sentence through the reloaded model
        String contentsFirstPas = "我的 手机 是 手机号码";
        INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter)
                .loadSingleSentence(contentsFirstPas);
        INDArray predictionsFirstNegative = netload
                .outputSingle(featuresFirstNegative);
        // FIX: parameterized List<String> instead of the raw List type
        List<String> labels = testIter.getLabels();
        System.out.println("\n\nPredictions for first negative review:");
        for (int i = 0; i < labels.size(); i++) {
            System.out.println("P(" + labels.get(i) + ") = "
                    + predictionsFirstNegative.getDouble(i));
        }
    }

    /**
     * Builds a CnnSentenceDataSetIterator over the train or test sentence file.
     *
     * @param isTraining        true selects the training file, false the test file
     * @param wordVectors       pre-trained word2vec vectors
     * @param minibatchSize     examples per minibatch
     * @param maxSentenceLength sentences longer than this are truncated
     * @param rng               randomness source for the sentence provider
     * @return an iterator producing CNN-shaped sentence minibatches
     */
    private static DataSetIterator getDataSetIterator(boolean isTraining,
            WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
            Random rng) {
        String path = isTraining ? "/home/zhoumeixu/model/rnnsenec.txt"
                : "/home/zhoumeixu/model/rnnsenectest.txt";
        LabeledSentenceProvider sentenceProvider = new LabeledSentence(path,
                rng);
        return new CnnSentenceDataSetIterator.Builder()
                .sentenceProvider(sentenceProvider).wordVectors(wordVectors)
                .minibatchSize(minibatchSize)
                .maxSentenceLength(maxSentenceLength)
                .useNormalizedWordVectors(false).build();
    }

    /**
     * Serializes the network to {@code path} using plain Java serialization.
     * Best-effort: failures are printed, not rethrown (matches original behavior).
     * FIX: try-with-resources guarantees the stream is closed even when
     * writeObject throws — the original leaked the stream on exception.
     *
     * @param path destination file
     * @param net  the trained graph to persist
     */
    public static void saveNet(String path, ComputationGraph net) {
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(
                new FileOutputStream(path))) {
            objectOutputStream.writeObject(net);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Deserializes a network previously written by {@link #saveNet}.
     * FIX: the original swallowed every exception in an empty catch and leaked
     * the stream on failure; now the stream is closed via try-with-resources
     * and errors are at least reported. Still returns null on failure so the
     * caller-visible contract is unchanged.
     *
     * @param path file written by {@link #saveNet}
     * @return the restored graph, or null if loading failed
     */
    public static ComputationGraph loadNet(String path) {
        try (ObjectInputStream objectInputStream = new ObjectInputStream(
                new FileInputStream(path))) {
            return (ComputationGraph) objectInputStream.readObject();
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }
}