This afternoon I noticed that deeplearning4j ships several text-classification examples that combine word2vec with an LSTM. I adapted that code to my own data format, got it running, and am writing the process down here. Each line of my input is a label (1 or 0) followed by a tab and the pre-segmented (space-separated) text:
1 你 可以 兰容网 问下 咨询师 里面 不仅 可以 咨询 任何 整形 问题 并且 全国各地 可以 帮你 查询 推荐 最好 医院 专家 公立医院 三甲医院 专科 整形医院 可以 查询 我 之前 就是 上面 问
1 小额贷款 要 吗
1 商家 @ 回应 十里洋场 江景会 所 订房 经理 电话 ️ 五星级 餐饮 加 五星级 k歌 包房 体验 打 8折 送 酒
1 可以 啊 微信同号
1 看头像 加微信
1 看头像 加微信
1 您 可以 加研欧 书院 微信公众号 里面 有 具体 收费 标 淮 本月 25日 书院 开展 试听 公开课 如果 您 有时间 可以来 试听
0 强烈推荐 4号 4号 店长 非常 满意 朋友 发型 合适 我的 开心 因为 自然 卷 每年 得 做 头发 去过 不少 理发店 接触 不少 发型师 到 目前为止 觉得 4号 ( 帅哥 一枚 ) 最 称心如意 会 根据 发质 情况 自己的 想法 给你 设计 适合 发型 不仅 专业 而且 非常 负责 细心 我 最喜欢 细心 人 因为 只有 这样 会 做出 好的 效果 这次 先 把 发型 做出来 4号 店长 建议 染色 那样 效果 会 更好 我 不用 那么 辛苦 坐 太久 贴心 是不是 总之 满意 染色 后的 效果 更 重要 4号 亲自动手 帮 我 做 发型 强烈推荐 大家 过来 找 4号 店长 下次 来 找 你 哦 4号 店长 支持 你 djehfjdushdjd
0 非常感谢 老师 给 小高 5 星 点评 想起 我 老师 一起 游览 情景 历历在目 时间 有时 像 一个 小偷 不知不觉 四月 已经 过完 2017年 已经 过去 三分之一 像 老话 连雨 不知 春 一 晴 方知夏 深 还没 来得及 好好 享受 温熏 春光 夏日 风 已 从 远处 徐徐 而来 夏天 青岛 最 美的 青岛 避暑 理想 之 地 欢迎 老师 有机会 夏天 再来 青岛 我的 手机号 即 微信号 到时 记得 微信 小高 我 来 帮 您 订 酒店 青岛 小高 祝 老师 身体健康 阖家欢乐 万事如意
0 我 来 时候 他家 根本 拒绝 兑换 也是 遇 得到 气死人
0 商家回应 尊敬 客人 感谢您 抽出 宝贵 时间 给 我们 评价 心瑞 国际 月子会所 来源于 台湾 拥有 26年 台式 母婴护理 经验 聚集 最 专业 最 精湛 台湾 护理 技术 团队 精心 定制 专属 护理服务 秉承 规范 操作 细致入微 精益求精 服务理念 为 产 后妈 咪 提供 从 饮食 护理 康复 早教 心理健康 等 全方位 贴身 服务模式 给 孕产 期间 家庭 全方位 专业 照护 舒适 体验 会所 紧邻 国内 最好 医院 协和医院 产后 妈咪 宝宝 都有 坚实 医疗保障 我们 免费提供 健身会所 游泳馆 给 入住 客人 家属 使用 , 我们 会 不定期 举办 丰富多彩 活动 让 更多 孕妈咪 们 了解 孕期 保健知识 新生儿 喂养 知识 哦 非常 期待 下次 妈咪 见面 哦 心瑞 国际 月子会所 全体员工 祝 服务热线 座机号码 请关注 微信号 微信账号
0 商家回应 亲爱 贵宾 感谢您 襄阳 巴厘岛 休闲 度假酒店 肯定 支持 酒店 休闲会所 主要 提供 休闲 洗浴 游泳 桑拿 干湿 蒸 足疗 按摩 等 项目 如此 次 体验 没能 让 您 满意 我们 表示 深深 歉意 我们 足疗 专业 休闲 手法 所有 技师 素质 高 够 专业 各 具 魅力 服务 更好 巴厘岛 一切 以 您 最 舒适 休闲 方式 为先 您 满意 我们 继续 进步 动力 酒店 全体员工 期待 您 再次光临
0 人 直接 不行 必须 每人 点 一份 主食 我 那 钱 我 给你 东西 不用 上了 吃 不掉 浪费 服务员 b 那 再 来个 茶 吧 饭 毕 服务员 c 过来 我们 要 续 杯茶 他 可能 听懂 把 bill 拿走 一会 送来 一个新 上面 加 一杯 茶 我们 收费 那 不要 他 直接 气呼呼 拿走 新 bill 回来 时候 把 新 打印 去掉 杯茶 bill 直接 摔 桌子 然后 回头 走 我 擦 你 真 牛逼 我 走 这么多 英联邦 国家 地区 别人 一看 中国 人 客客气气 何况 还是 顾客 你 妹的
0 维权 群 怎么 加 我 被 套路
The Maven setup is a standard pom.xml; the properties and dependencies I used are:
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<nd4j.version>0.8.0</nd4j.version>
<dl4j.version>0.8.0</dl4j.version>
<datavec.version>0.8.0</datavec.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>com.meituan</groupId>
<artifactId>nlp-utils</artifactId>
<version>0.0.1-SNAPSHOT</version>
</dependency>
</dependencies>
The first class cleans the raw lines and writes two files: a plain corpus for word2vec training and a labeled file for the RNN.
package com.dianping.recurrent.adx;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.dianping.recurrent.util.PathUtils;
import com.meituan.nlp.util.TextUtil;
import com.meituan.nlp.util.WordUtil;
public class PrepareWordVector {
private static Logger log = LoggerFactory.getLogger(PrepareWordVector.class);
private static String datapath=PathUtils.INPUT_ADX;
public static void transtomodel(String input, String outputword22vec,String outputrnn) throws IOException {
BufferedReader reader = null;
BufferedWriter writerword2vec = null;
BufferedWriter writerrnn = null;
reader = new BufferedReader(new InputStreamReader(new FileInputStream(
input)));
writerword2vec = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(outputword22vec)));
writerrnn = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(outputrnn)));
String line = reader.readLine();
while (line != null) {
String[] parts = line.split("\t", 2); // split once: label TAB content
String label = parts[0];
String content = parts.length > 1 ? parts[1] : "";
if (StringUtils.isNotBlank(content)) {
// Normalization pipeline (nlp-utils helpers), applied in this order:
// lower-case -> replaceAdxAll -> converToDigitStr -> fan2Jian (traditional
// to simplified) -> replaceAllADXSynonyms -> getAdSegmentNotURL (segmentation
// with URLs dropped), then join the tokens with single spaces
String normalized = WordUtil.replaceAdxAll(content.toLowerCase());
normalized = WordUtil.converToDigitStr(normalized);
normalized = TextUtil.fan2Jian(normalized);
normalized = WordUtil.replaceAllADXSynonyms(normalized);
String result = WordUtil.getAdSegmentNotURL(normalized)
.stream().collect(Collectors.joining(" "));
writerrnn.write(label + "\t" + result + "\n");
writerword2vec.write(result + "\n");
}
line = reader.readLine();
}
reader.close();
writerrnn.close();
writerword2vec.close();
}
public static void trainword2vec(String inputpath, String outputpath)
throws IOException {
SentenceIterator iter = new BasicLineIterator(inputpath);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
log.info("build word2vec will start");
Word2Vec vec = new Word2Vec.Builder().minWordFrequency(1).iterations(5)
.layerSize(100).seed(42).windowSize(20).iterate(iter)
.tokenizerFactory(t).build();
log.info("Fitting Word2Vec model....");
vec.fit();
log.info("Writing word vectors to text file....");
// Write word vectors to file
WordVectorSerializer.writeWordVectors(vec, outputpath);
}
public static void main(String[] args) throws IOException {
// Step 1 (run once): clean the raw data and emit both output files
//transtomodel(datapath, "adx/wordvecsence.txt", "adx/rnnsenec.txt");
// Step 2: train word2vec on the cleaned corpus
trainword2vec("adx/wordvecsence.txt", "adx/word2vec.model");
}
}
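To sanity-check the vectors before training the classifier, you can reload the text-format model and look at nearest neighbours. This is a sketch of mine, not part of the original pipeline; the query words are placeholders, so substitute tokens that actually occur in your corpus.
package com.dianping.recurrent.adx;
import java.io.File;
import java.io.IOException;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
public class InspectWordVectors {
    public static void main(String[] args) throws IOException {
        // Reload the vectors written by PrepareWordVector.trainword2vec
        WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("adx/word2vec.model"));
        // Ten closest tokens by cosine similarity; "微信" is just an example query
        System.out.println(vectors.wordsNearest("微信", 10));
        System.out.println("similarity: " + vectors.similarity("微信", "电话"));
    }
}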
Next comes the DataSetIterator that feeds the labeled examples to the network, adapted from the NewsIterator in the dl4j examples.
package com.dianping.recurrent.adx;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
public class ADXIterator implements DataSetIterator {
private final WordVectors wordVectors;
private final int batchSize;
private final int vectorSize;
private final int truncateLength;
private int maxLength;
private final String dataDirectory;
private final List<Pair<String, List<String>>> categoryData = new ArrayList<>();
private int cursor = 0;
private int totalNews = 0;
private final TokenizerFactory tokenizerFactory;
private int newsPosition = 0;
private final List<String> labels;
private int currCategory = 0;
private ADXIterator(String dataDirectory, WordVectors wordVectors,
int batchSize, int truncateLength, boolean train,
TokenizerFactory tokenizerFactory) {
this.dataDirectory = dataDirectory;
this.batchSize = batchSize;
this.vectorSize = wordVectors.getWordVector(wordVectors.vocab()
.wordAtIndex(0)).length;
this.wordVectors = wordVectors;
this.truncateLength = truncateLength;
this.tokenizerFactory = tokenizerFactory;
this.populateData(train);
this.labels = new ArrayList<>();
for (int i = 0; i < 2; i++) {
this.labels.add(String.valueOf(i));
}
}
public static Builder Builder() {
return new Builder();
}
@Override
public DataSet next(int num) {
if (cursor >= this.totalNews)
throw new NoSuchElementException();
try {
return nextDataSet(num);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private DataSet nextDataSet(int num) throws IOException {
// First: pull up to 'num' texts (and their label indices) out of categoryData,
// alternating between the categories so each minibatch stays mixed
List<String> news = new ArrayList<>(num);
int[] category = new int[num];
for (int i = 0; i < num && cursor < totalExamples(); i++) {
if (currCategory < categoryData.size()) {
news.add(this.categoryData.get(currCategory).getValue()
.get(newsPosition));
category[i] = Integer.parseInt(this.categoryData.get(
currCategory).getKey());
currCategory++;
cursor++;
} else {
currCategory = 0;
newsPosition++;
i--;
}
}
// Second: tokenize news and filter out unknown words
List<List<String>> allTokens = new ArrayList<>(news.size());
maxLength = 0;
for (String s : news) {
List<String> tokens = tokenizerFactory.create(s).getTokens();
List<String> tokensFiltered = new ArrayList<>();
for (String t : tokens) {
if (wordVectors.hasWord(t))
tokensFiltered.add(t);
}
allTokens.add(tokensFiltered);
maxLength = Math.max(maxLength, tokensFiltered.size());
}
// If longest news exceeds 'truncateLength': only take the first
// 'truncateLength' words
// System.out.println("maxLength : " + maxLength);
if (maxLength > truncateLength)
maxLength = truncateLength;
// Create data for training
// Here: we have news.size() examples of varying lengths
INDArray features = Nd4j.create(news.size(), vectorSize, maxLength);
INDArray labels = Nd4j.create(news.size(), this.categoryData.size(),
maxLength); // two labels in this task: "1" and "0"
// Because we are dealing with news of different lengths and only one
// output at the final time step: use padding arrays
// Mask arrays contain 1 if data is present at that time step for that
// example, or 0 if data is just padding
INDArray featuresMask = Nd4j.zeros(news.size(), maxLength);
INDArray labelsMask = Nd4j.zeros(news.size(), maxLength);
int[] temp = new int[2];
for (int i = 0; i < news.size(); i++) {
List<String> tokens = allTokens.get(i);
temp[0] = i;
// Get word vectors for each word in news, and put them in the
// training data
for (int j = 0; j < tokens.size() && j < maxLength; j++) {
String token = tokens.get(j);
INDArray vector = wordVectors.getWordVectorMatrix(token);
features.put(new INDArrayIndex[] { point(i), all(), point(j) },
vector);
temp[1] = j;
featuresMask.putScalar(temp, 1.0);
}
int idx = category[i];
int lastIdx = Math.min(tokens.size(), maxLength);
labels.putScalar(new int[] { i, idx, lastIdx - 1 }, 1.0);
labelsMask.putScalar(new int[] { i, lastIdx - 1 }, 1.0);
}
DataSet ds = new DataSet(features, labels, featuresMask, labelsMask);
return ds;
}
public INDArray loadFeaturesFromFile(File file, int maxLength)
throws IOException {
String news = FileUtils.readFileToString(file);
return loadFeaturesFromString(news, maxLength);
}
public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
List<String> tokens = tokenizerFactory.create(reviewContents)
.getTokens();
List<String> tokensFiltered = new ArrayList<>();
for (String t : tokens) {
if (wordVectors.hasWord(t))
tokensFiltered.add(t);
}
int outputLength = Math.max(maxLength, tokensFiltered.size());
INDArray features = Nd4j.create(1, vectorSize, outputLength);
// Iterate the filtered tokens (not the raw ones), so every token has a vector
for (int j = 0; j < tokensFiltered.size() && j < maxLength; j++) {
String token = tokensFiltered.get(j);
INDArray vector = wordVectors.getWordVectorMatrix(token);
features.put(new INDArrayIndex[] { point(0), all(), point(j) },
vector);
}
return features;
}
/*
 * Loads the labeled lines from rnnsenec.txt (train) or rnnsenectest.txt
 * (test) under dataDirectory into the categoryData list, one entry per label.
 */
private void populateData(boolean train) {
String name = train ? "rnnsenec.txt"
: "rnnsenectest.txt";
String curFileName=this.dataDirectory+name;
BufferedReader currBR = null;
File currFile = new File(curFileName);
try {
currBR = new BufferedReader((new FileReader(currFile)));
String tempCurrLine = "";
List<String> tempListnorme = new ArrayList<>();
List<String> tempListneg = new ArrayList<>();
while ((tempCurrLine = currBR.readLine()) != null) {
String[] lines = tempCurrLine.split("\t");
String label = lines[0];
if ("1".equalsIgnoreCase(label)) {
tempListnorme.add(lines[1]);
} else if("0".equalsIgnoreCase(label)) {
tempListneg.add(lines[1]);
}
this.totalNews++;
}
currBR.close();
Pair<String, List<String>> tempPairnore = Pair.of("1",
tempListnorme);
this.categoryData.add(tempPairnore);
Pair<String, List<String>> tempPair = Pair.of("0", tempListneg);
this.categoryData.add(tempPair);
} catch (Exception e) {
e.printStackTrace();
}
}
@Override
public int totalExamples() {
return this.totalNews;
}
@Override
public int inputColumns() {
return vectorSize;
}
@Override
public int totalOutcomes() {
return this.categoryData.size();
}
@Override
public void reset() {
cursor = 0;
newsPosition = 0;
currCategory = 0;
}
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public int batch() {
return batchSize;
}
@Override
public int cursor() {
return cursor;
}
@Override
public int numExamples() {
return totalExamples();
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException();
}
@Override
public List<String> getLabels() {
return this.labels;
}
@Override
public boolean hasNext() {
return cursor < numExamples();
}
@Override
public DataSet next() {
return next(batchSize);
}
@Override
public void remove() {
}
@Override
public DataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException("Not implemented");
}
public int getMaxLength() {
return this.maxLength;
}
public static class Builder {
private String dataDirectory;
private WordVectors wordVectors;
private int batchSize;
private int truncateLength;
TokenizerFactory tokenizerFactory;
private boolean train;
Builder() {
}
public ADXIterator.Builder dataDirectory(String dataDirectory) {
this.dataDirectory = dataDirectory;
return this;
}
public ADXIterator.Builder wordVectors(WordVectors wordVectors) {
this.wordVectors = wordVectors;
return this;
}
public ADXIterator.Builder batchSize(int batchSize) {
this.batchSize = batchSize;
return this;
}
public ADXIterator.Builder truncateLength(int truncateLength) {
this.truncateLength = truncateLength;
return this;
}
public ADXIterator.Builder train(boolean train) {
this.train = train;
return this;
}
public ADXIterator.Builder tokenizerFactory(
TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
public ADXIterator build() {
return new ADXIterator(dataDirectory, wordVectors, batchSize,
truncateLength, train, tokenizerFactory);
}
public String toString() {
return "ADXIterator.Builder(dataDirectory=" + this.dataDirectory
+ ", wordVectors=" + this.wordVectors
+ ", batchSize=" + this.batchSize
+ ", truncateLength=" + this.truncateLength
+ ", train=" + this.train + ")";
}
}
}
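Before wiring the iterator into training, it helps to confirm the tensor shapes it emits. A quick sketch under the same assumptions as above (adx/rnnsenec.txt and the word2vec model already exist); the class name is mine:
package com.dianping.recurrent.adx;
import java.io.File;
import java.util.Arrays;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.dataset.DataSet;
public class CheckADXIterator {
    public static void main(String[] args) throws Exception {
        WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("adx/word2vec.model"));
        DefaultTokenizerFactory tf = new DefaultTokenizerFactory();
        tf.setTokenPreProcessor(new CommonPreprocessor());
        ADXIterator it = new ADXIterator.Builder().dataDirectory("adx/")
                .wordVectors(vectors).batchSize(6).truncateLength(300)
                .tokenizerFactory(tf).train(true).build();
        DataSet ds = it.next();
        // Expect features [batchSize, vectorSize, maxLength] and labels [batchSize, 2, maxLength]
        System.out.println("features shape: " + Arrays.toString(ds.getFeatures().shape()));
        System.out.println("labels shape:   " + Arrays.toString(ds.getLabels().shape()));
    }
}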
Finally, the training class: it loads the word2vec model, builds the train and test iterators, configures a GravesLSTM network, and evaluates after each epoch.
package com.dianping.recurrent.adx;
import java.io.File;
import java.io.IOException;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class TrainAdxRnnModel {
public static String userDirectory = "";
public static String DATA_PATH = "";
public static String WORD_VECTORS_PATH = "";
public static WordVectors wordVectors;
public static void main(String[] args) throws IOException {
DATA_PATH = "adx/";
WORD_VECTORS_PATH = "adx/word2vec.model";
int batchSize = 6; // Number of examples in each minibatch
int nEpochs = 10; // number of training epochs
int truncateReviewsToLength = 300; // maximum text length; longer texts are truncated
wordVectors = WordVectorSerializer.fromPair(WordVectorSerializer.loadTxt(new File(WORD_VECTORS_PATH)));
TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
ADXIterator iTrain = new ADXIterator.Builder().dataDirectory(DATA_PATH)
.wordVectors(wordVectors).batchSize(batchSize)
.truncateLength(truncateReviewsToLength)
.tokenizerFactory(tokenizerFactory).train(true).build();
ADXIterator iTest = new ADXIterator.Builder().dataDirectory(DATA_PATH)
.wordVectors(wordVectors).batchSize(batchSize)
.truncateLength(truncateReviewsToLength)
.tokenizerFactory(tokenizerFactory).train(false).build();
int inputNeurons = wordVectors.getWordVector(wordVectors.vocab()
.wordAtIndex(0)).length; // 100 in our case
int outputs = iTrain.getLabels().size();
System.out.println("inputNeurons is :" + inputNeurons);
// Set up network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(
OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.updater(Updater.RMSPROP)
.regularization(true)
.l2(1e-5)
.weightInit(WeightInit.XAVIER)
.gradientNormalization(
GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0)
.learningRate(0.0018)
.list()
.layer(0,
new GravesLSTM.Builder().nIn(inputNeurons).nOut(200)
.activation(Activation.SOFTSIGN).build())
.layer(1,
new RnnOutputLayer.Builder()
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.nIn(200).nOut(outputs).build())
.pretrain(false).backprop(true).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// Print the score every 200 iterations
net.setListeners(new ScoreIterationListener(200));
System.out.println("Starting training");
for (int i = 0; i < nEpochs; i++) {
net.fit(iTrain);
iTrain.reset();
System.out.println("Epoch " + i + " complete. Starting evaluation:");
// Run evaluation on the held-out test file
Evaluation evaluation = new Evaluation();
while (iTest.hasNext()) {
DataSet t = iTest.next();
INDArray features = t.getFeatureMatrix();
INDArray labels = t.getLabels();
INDArray inMask = t.getFeaturesMaskArray();
INDArray outMask = t.getLabelsMaskArray();
// Pass both masks to output() so padded time steps are ignored
INDArray predicted = net.output(features, false, inMask, outMask);
evaluation.evalTimeSeries(labels, predicted, outMask);
}
iTest.reset();
System.out.println(evaluation.stats());
}
ModelSerializer.writeModel(net, "adx/" + "NewsModel.net", true);
System.out.println("----- Example complete -----");
}
}
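After training, the serialized model can be reloaded and applied to new text. A minimal inference sketch, with my own class name and a placeholder input sentence; it assumes the input has already been segmented the same way as the training data, and that every token is in the word2vec vocabulary (unknown tokens are dropped by loadFeaturesFromString, which would shift the last meaningful time step):
package com.dianping.recurrent.adx;
import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;
public class PredictAdx {
    public static void main(String[] args) throws Exception {
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(new File("adx/NewsModel.net"));
        WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("adx/word2vec.model"));
        DefaultTokenizerFactory tf = new DefaultTokenizerFactory();
        tf.setTokenPreProcessor(new CommonPreprocessor());
        // The iterator is only built here to reuse its loadFeaturesFromString helper
        ADXIterator it = new ADXIterator.Builder().dataDirectory("adx/")
                .wordVectors(vectors).batchSize(1).truncateLength(300)
                .tokenizerFactory(tf).train(false).build();
        String text = "看头像 加微信"; // placeholder, already segmented
        int nTokens = tf.create(text).getTokens().size();
        INDArray features = it.loadFeaturesFromString(text, 300);
        INDArray output = net.output(features);
        // Class scores at the last real time step; column 0 is label "0", column 1 is label "1"
        INDArray probs = output.get(NDArrayIndex.point(0), NDArrayIndex.all(),
                NDArrayIndex.point(nTokens - 1));
        System.out.println("score for label 0, label 1: " + probs);
    }
}
The log below is from my run. Note that the test file is tiny (the confusion counts sum to 18 examples), which is why the metrics saturate at 1.0 from the second epoch on.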
Starting training
Epoch 0 complete. Starting evaluation:
Examples labeled as 0 classified by model as 0: 8 times
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 1 classified by model as 1: 9 times
==========================Scores========================================
Accuracy: 0.9444
Precision: 0.95
Recall: 0.9444
F1 Score: 0.9472
========================================================================
Epoch 1 complete. Starting evaluation:
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 1: 9 times
==========================Scores========================================
Accuracy: 1
Precision: 1
Recall: 1
F1 Score: 1
========================================================================
Epoch 2 complete. Starting evaluation:
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 1: 9 times
==========================Scores========================================
Accuracy: 1
Precision: 1
Recall: 1
F1 Score: 1
========================================================================
Epoch 3 complete. Starting evaluation:
Examples labeled as 0 classified by model as 0: 9 times
Examples labeled as 1 classified by model as 1: 9 times
==========================Scores========================================
Accuracy: 1
Precision: 1
Recall: 1
F1 Score: 1
========================================================================