This post follows on from the previous one with a Spark-based LSTM character-level model. The data source is passages from Shakespeare; the LSTM learns the text one character at a time and then writes new passages in the same style. The code is below, with commentary.
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;

public class SparkLSTMCharacterExample {
    private static final Logger log = LoggerFactory.getLogger(SparkLSTMCharacterExample.class);

    private static Map<Integer, Character> INT_TO_CHAR = getIntToChar(); // index -> character map (you can jump ahead to that helper first)
    private static Map<Character, Integer> CHAR_TO_INT = getCharToInt(); // character -> index map (you can jump ahead to that helper first)
    private static final int N_CHARS = INT_TO_CHAR.size(); // number of indices
    private static int nOut = CHAR_TO_INT.size();          // number of distinct characters
    private static int exampleLength = 1000;               // length of each training example sequence: 1000 characters

    // Command-line parameters, parsed by JCommander; they need no further explanation
    @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")
    private int batchSizePerWorker = 8; // how many examples each worker (executor) uses per fit

    @Parameter(names = "-numEpochs", description = "Number of epochs for training")
    private int numEpochs = 1;

    public static void main(String[] args) throws Exception {
        new SparkLSTMCharacterExample().entryPoint(args); // delegate to the entry point
    }

    protected void entryPoint(String[] args) throws Exception {
        // Handle command line arguments with JCommander
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            // User provided invalid input -> print the usage info
            jcmdr.usage();
            try { Thread.sleep(500); } catch (Exception e2) { }
            throw e;
        }

        Random rng = new Random(12345);         // random number generator
        int lstmLayerSize = 200;                // number of units in each GravesLSTM layer
        int tbpttLength = 50;                   // truncated BPTT length, i.e. do parameter updates every 50 characters
        int nSamplesToGenerate = 4;             // number of samples to generate after each training epoch (we are imitating writing, hence this parameter)
        int nCharactersToSample = 300;          // length of each generated sample: 300 characters
        String generationInitialization = null; // optional priming characters; a random character is used if null
        // The initialization is used to 'prime' the LSTM with a character sequence to continue/complete.
        // Initialization characters must all be in CharacterIterator.getMinimalCharacterSet() by default.

        // Set up the network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(0.1)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            // Layer 0 is an LSTM: input size is the number of distinct characters, output size is 200.
            // Note how the layer widens the representation considerably: CNN stacks tend to shrink layer sizes, while RNN stacks like this one widen them.
            .layer(0, new GravesLSTM.Builder().nIn(CHAR_TO_INT.size()).nOut(lstmLayerSize)
                .activation("tanh").build())
            // Layer 1 is another LSTM, with 200 inputs and 200 outputs
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .activation("tanh").build())
            // The output layer is an RNN output layer; since this is classification, we use MCXENT + softmax.
            // Input size is 200; output size matches the number of distinct characters in the original input.
            .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax")
                .nIn(lstmLayerSize).nOut(nOut).build())
            // Truncated BPTT with length 50: both the forward and backward passes used for parameter updates look at 50 steps
            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .pretrain(false).backprop(true)
            .build();
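Before handing conf to Spark, it can be worth a quick local sanity check of the layer sizes. The following is my own sketch rather than part of the original example; it assumes the conf just built and could be dropped in at this point:

        // Hedged sketch (not in the original): initialize the network locally and inspect its shape.
        MultiLayerNetwork shapeCheck = new MultiLayerNetwork(conf);
        shapeCheck.init();
        // With the 77-character vocabulary defined by getValidCharacters() below,
        // the layer sizes are 77 -> 200 -> 200 -> 77.
        log.info("Total parameters: " + shapeCheck.numParams());
        for (int l = 0; l < shapeCheck.getnLayers(); l++) {
            log.info("Layer " + l + ": " + shapeCheck.getLayer(l).numParams() + " parameters");
        }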
        //-------------------------------------------------------------
        // Set up the Spark-specific configuration.
        /* How frequently should we average parameters (in number of minibatches)?
           Averaging too frequently can be slow (synchronization + serialization costs), whereas
           averaging too infrequently can cause learning difficulties (i.e., the network may not converge). */
        int averagingFrequency = 3; // average the parameters once every 3 minibatches

        // Set up the Spark configuration and context (local mode here)
        SparkConf sparkConf = new SparkConf();
        if (useSparkLocal) {
            sparkConf.setMaster("local[*]");
        }
        sparkConf.setAppName("LSTM Character Example");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        JavaRDD<DataSet> trainingData = getTrainingData(sc); // load the data as a training RDD; jump to that method below

        // Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark.
        // Here, we are using standard parameter averaging.
        // For details on these configuration options, see: https://deeplearning4j.org/spark#configuring
        int examplesPerDataSetObject = 1; // each DataSet object holds one example
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
            .workerPrefetchNumBatches(2)            // asynchronously prefetch up to 2 batches
            .averagingFrequency(averagingFrequency) // average parameters every 3 minibatches
            .batchSizePerWorker(batchSizePerWorker) // each worker fits batches of 8 examples
            .build();
        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, conf, tm); // wire the configuration and TrainingMaster into the Spark network
        // Set a listener; Collections.singletonList returns an immutable list containing the given object
        sparkNetwork.setListeners(Collections.<IterationListener>singletonList(new ScoreIterationListener(1)));

        // Do training, and then generate and print samples from the network.
        // At the end of each epoch, we are returned a copy of the trained network.
        for (int i = 0; i < numEpochs; i++) {
            // Perform one epoch of training
            MultiLayerNetwork net = sparkNetwork.fit(trainingData);

            // Sample some characters from the network (done locally)
            log.info("Sampling characters from network given initialization \""
                + (generationInitialization == null ? "" : generationInitialization) + "\"");
            String[] samples = sampleCharactersFromNetwork(generationInitialization, net, rng, INT_TO_CHAR,
                nCharactersToSample, nSamplesToGenerate); // generate new samples from what was learned; see that method below
            for (int j = 0; j < samples.length; j++) { // print the generated samples
                log.info("----- Sample " + j + " -----");
                log.info(samples[j]);
            }
        }

        // Delete the temp training files, now that we are done with them
        tm.deleteTempFiles(sc);

        log.info("\n\nExample complete");
    }

    /**
     * Get the training data - a JavaRDD<DataSet>.
     * Note that this approach for getting training data is a special case for this example (modelling characters),
     * and should not be taken as best practice for loading data (like CSV etc.) in general.
     */
    public static JavaRDD<DataSet> getTrainingData(JavaSparkContext sc) throws IOException {
        // Get data. For the sake of this example, we are doing the following operations:
        // File -> String -> List<String> (split into substrings of length "exampleLength") -> JavaRDD<String> -> JavaRDD<DataSet>
        List<String> list = getShakespeareAsList(exampleLength); // list of strings, each 1000 characters long
        JavaRDD<String> rawStrings = sc.parallelize(list);       // parallelize the data
        Broadcast<Map<Character, Integer>> bcCharToInt = sc.broadcast(CHAR_TO_INT); // broadcast the char -> index map to the workers
        // ... (the remainder of this method, which maps each string to a DataSet, is cut off in the original post)
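Since the listing cuts off before the string-to-DataSet conversion, here is a hedged sketch of what that step has to produce; the helper name toCharacterDataSet is my invention, not the original code. Each string of length n becomes a DataSet whose features one-hot encode characters 0..n-2 and whose labels one-hot encode characters 1..n-1, so at every time step the target is simply the next character. In the Spark pipeline, this logic would run inside a Function<String, DataSet> passed to rawStrings.map(...).

    // Hedged sketch (an assumed helper, not the original code): convert one training string into a DataSet.
    private static DataSet toCharacterDataSet(String s, Map<Character, Integer> charToInt) {
        int nChars = charToInt.size();
        int length = s.length();
        INDArray features = Nd4j.zeros(1, nChars, length - 1); // shape [miniBatch=1, nIn, timeSteps]
        INDArray labels = Nd4j.zeros(1, nChars, length - 1);   // same shape as the features
        for (int i = 0; i < length - 1; i++) {
            int cur = charToInt.get(s.charAt(i));      // input: the character at step i
            int next = charToInt.get(s.charAt(i + 1)); // label: the character that follows it
            features.putScalar(new int[]{0, cur, i}, 1.0);
            labels.putScalar(new int[]{0, next, i}, 1.0);
        }
        return new DataSet(features, labels);
    }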
A note on the parameters of the sampling method below: charactersToSample is the number of characters each sample should contain, i.e. how many characters to keep writing beyond the initialization characters, and numSamples is how many samples to write after each training epoch. Here that means 4 samples of 300 characters each.
    private static String[] sampleCharactersFromNetwork(String initialization, MultiLayerNetwork net, Random rng,
            Map<Integer, Character> intToChar, int charactersToSample, int numSamples) {
        // Set up the initialization. If no initialization was given, use a random character.
        if (initialization == null) { // generate the first character
            int randomCharIdx = rng.nextInt(intToChar.size());
            initialization = String.valueOf(intToChar.get(randomCharIdx));
        }

        // Create the input for the initialization: a 3d array whose first dimension is the number of samples
        // to write, second dimension is the size of the character vocabulary, and third dimension is the length
        // of the initialization string (a single character here). Note this shape differs from the training examples.
        INDArray initializationInput = Nd4j.zeros(numSamples, intToChar.size(), initialization.length());
        char[] init = initialization.toCharArray(); // the initialization string as a char array
        for (int i = 0; i < init.length; i++) { // with the default single random character, this loop runs once
            int idx = CHAR_TO_INT.get(init[i]); // index of the initialization character
            for (int j = 0; j < numSamples; j++) { // prime each of the (4) samples
                initializationInput.putScalar(new int[]{j, idx, i}, 1.0f); // one-hot encode by setting the index position to 1
            }
        }

        StringBuilder[] sb = new StringBuilder[numSamples]; // one mutable buffer per sample
        for (int i = 0; i < numSamples; i++)
            sb[i] = new StringBuilder(initialization); // every sample starts with the same initialization; the continuations will differ

        // Sample from the network (and feed samples back into the input) one character at a time, for all
        // samples in parallel.
        net.rnnClearPreviousState(); // clear any previous rnn state
        // initializationInput is the network input: [numSamples, nIn, 1], i.e. batch size, input (vocabulary) size,
        // and a single time step. The returned output activations have matching dimensions; this is effectively the
        // network's prediction mode, where the output format mirrors the input format.
        INDArray output = net.rnnTimeStep(initializationInput);
        // Gets the last time step's output: tensorAlongDimension(output.size(2) - 1, 1, 0) takes the slice at
        // index size(2) - 1, the last time step, yielding a 2d matrix indexed by [sample, character].
        output = output.tensorAlongDimension(output.size(2) - 1, 1, 0);

        for (int i = 0; i < charactersToSample; i++) { // write the 300 characters
            // Set up the next input (single time step) by sampling from the previous output
            INDArray nextInput = Nd4j.zeros(numSamples, intToChar.size()); // 2d array: 4 rows, one column per character index
            // The output is a probability distribution over characters. Sample from it for each sample we want to
            // generate, and add the result to the next input; see the code below.
            for (int s = 0; s < numSamples; s++) { // for each sample
                double[] outputProbDistribution = new double[intToChar.size()]; // probability distribution over the vocabulary
                for (int j = 0; j < outputProbDistribution.length; j++) // for each probability
                    outputProbDistribution[j] = output.getDouble(s, j); // probability of character j for sample s
                int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution, rng); // draw an index from the distribution; see that method below
                nextInput.putScalar(new int[]{s, sampledCharacterIdx}, 1.0f); // prepare the next time step input
                sb[s].append(intToChar.get(sampledCharacterIdx)); // add the sampled character to the sample's StringBuilder (human readable output)
            }
            output = net.rnnTimeStep(nextInput); // do one time step of the forward pass
        }

        String[] out = new String[numSamples]; // array of the 4 sample strings
        for (int i = 0; i < numSamples; i++)
            out[i] = sb[i].toString(); // convert each sample to a String and return them
        return out;
    }
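One detail worth calling out: the loop above only works because rnnTimeStep() is stateful. DL4J stores the LSTM's internal state between calls, so each single-character input continues the sequence begun by the previous call. A minimal sketch of that contract (step1 and step2 are assumed one-hot inputs, not variables from the original code):

        // Hedged sketch of rnnTimeStep's stateful contract (inputs step1/step2 are assumptions).
        net.rnnClearPreviousState();            // forget any earlier sequence
        INDArray out1 = net.rnnTimeStep(step1); // a single time step; the hidden state is kept internally
        INDArray out2 = net.rnnTimeStep(step2); // continues from that state, as if step1 and step2
                                                // had been fed as one two-step sequence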
d=" + d + ", sum=" + sum); } /** * A minimal character set, with a-z, A-Z, 0-9 and common punctuation etc */ private static char[] getValidCharacters() { ListvalidChars = new LinkedList<>();//搞一个字符list for (char c = 'a'; c <= 'z'; c++) validChars.add(c);//用a-z,A-Z,0-9和特殊字符填充list for (char c = 'A'; c <= 'Z'; c++) validChars.add(c); for (char c = '0'; c <= '9'; c++) validChars.add(c); char[] temp = {'!', '&', '(', ')', '?', '-', '\'', '"', ',', '.', ':', ';', ' ', '\n', '\t'}; for (char c : temp) validChars.add(c); char[] out = new char[validChars.size()];//搞一个新的字符数组 int i = 0; for (Character c : validChars) out[i++] = c;//把list的内容放到数组里 return out; } public static Map , Character> getIntToChar() { Map , Character> map = new HashMap<>();//搞一个map char[] chars = getValidCharacters();//获取有效字符,调用函数,可以先看函数 for (int i = 0; i < chars.length; i++) { map.put(i, chars[i]);//以索引为主键,字符为value填充map } return map; } public static Map , Integer> getCharToInt() { Map , Integer> map = new HashMap<>();//搞一个map char[] chars = getValidCharacters();//获取有效字符,调用函数,可以先看函数 for (int i = 0; i < chars.length; i++) {//遍历字符数组 map.put(chars[i], i);//以字符为主键,索引为value填充map } return map; } }