Deep Learning: Writing Poetry

A training course I attended included an example of writing poetry with deep learning. I found it a nice introductory exercise, so I am writing it up here together with some of my own understanding.
Writing poetry with deep learning can be done with a network architecture that has memory, such as an RNN/LSTM. The network memorizes the structure of poems by reading the provided samples, and is then trained to write in a recitation-like fashion: predicting the next character from the previous one.
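As a minimal sketch of this training objective (plain Java, no dl4j; the sample string and class name are purely illustrative), each position in the sample contributes one (current character, next character) training pair:

public class NextCharPairs {
	public static void main(String[] args) {
		// the training sample; any string works here
		char[] sample = "寒随穷律变".toCharArray();
		for (int pos = 0; pos + 1 < sample.length; pos++) {
			// the input is the character at pos, the label is the character after it
			System.out.println("input: " + sample[pos] + " -> label: " + sample[pos + 1]);
		}
	}
}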

Testing the memorization training

We can test memorization with a dl4j example that trains the network to memorize an English string; the test output shows how well the trained net can recite it. Full code: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/basic/BasicRNNExample.java
The training code, slightly modified and with some comments added:

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class BasicRNNExample {

	// define a sentence to learn;
	// a special marker character can be added at the front so the RNN learns
	// where the complete string begins and ends
	private static final char[] LEARNSTRING = "Der Cottbuser Postkutscher putzt den Cottbuser Postkutschkasten.".toCharArray();

	// a list of all possible characters
	private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();

	// RNN dimensions
	private static final int HIDDEN_LAYER_WIDTH = 50; // units per hidden layer
	private static final int HIDDEN_LAYER_CONT = 2; // number of hidden layers
    //private static final Random r = new Random(7894);

	public static void main(String[] args) {

		// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
		// collect the distinct characters occurring in LEARNSTRING,
		// i.e. the characters the net can possibly predict during training
		LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
		for (char c : LEARNSTRING)
			LEARNSTRING_CHARS.add(c);
		LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);

		// some common parameters
		NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
		builder.seed(123);
		builder.biasInit(0); // initialize biases to 0
		builder.miniBatch(false);
		builder.updater(new RmsProp(0.001)); // optimizer and learning rate
		builder.weightInit(WeightInit.XAVIER); // weight initialization

		ListBuilder listBuilder = builder.list();

		// define the hidden layers
		// first difference, for rnns we need to use LSTM.Builder
		for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
			LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
			hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH); // number of input units
			hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH); // number of output units
			// adopted activation function from LSTMCharModellingExample
			// seems to work well with RNNs
			hiddenLayerBuilder.activation(Activation.TANH); // activation function
			listBuilder.layer(i, hiddenLayerBuilder.build());
		}

		//定义输出层
		// we need to use RnnOutputLayer for our RNN
		RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
		// softmax normalizes the output neurons so they sum to 1 and can be read
		// as a probability distribution over the next character (needed if we
		// sample from the distribution instead of taking the argmax)
		outputLayerBuilder.activation(Activation.SOFTMAX);
		outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
		outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
		listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

		// finish builder
		listBuilder.pretrain(false); // no layer-wise pretraining
		listBuilder.backprop(true); // train with backpropagation

		// create network
		MultiLayerConfiguration conf = listBuilder.build();
		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		net.setListeners(new ScoreIterationListener(10));

		/*
		 * CREATE OUR TRAINING DATA
		 */
		// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
		// SEQUENCE_POSITION
		INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
		INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
		// loop through our sample-sentence
		int samplePos = 0;
		for (char currentChar : LEARNSTRING) {
			// small hack: when currentChar is the last character, wrap around and use
			// the first character as nextChar; not strictly required, and it can be
			// avoided by prepending a dedicated start marker character
			char nextChar = LEARNSTRING[(samplePos + 1) % (LEARNSTRING.length)];
			// input neuron for current-char is 1 at "samplePos"
			input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
			// output neuron for next-char is 1 at "samplePos"
			// the label is the character following the current one, i.e. the net is
			// trained to recall the next character from the current one
			labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
			samplePos++;
		}
		DataSet trainingData = new DataSet(input, labels);

		// some epochs
		for (int epoch = 0; epoch < 2000; epoch++) {
			if ((epoch%20)==19)
				System.out.println("Epoch " + (epoch+1));

			// train the data
			net.fit(trainingData);

			// clear the rnn state left over from the last example
			net.rnnClearPreviousState();

			if ((epoch%20)!=19)
				continue;
			// test how well the net can recite the string at this point in training
			System.out.print(LEARNSTRING[0]);
			// put the first character into the rnn as an initialisation
			INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
			testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(LEARNSTRING[0]), 1);

			// run one step -> IMPORTANT: rnnTimeStep() must be called, not
			// output()
			// the output shows what the net thinks should come next
			INDArray output = net.rnnTimeStep(testInit);

			// now let the net recite the remaining LEARNSTRING.length-1 characters
            for (int n=1;n<LEARNSTRING.length;n++) {

                // map the net's last output to a concrete character: with greedy
                // decoding we simply pick the neuron with the highest activation
                int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);

                // print the chosen output
                System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));

                // use the last output as input
                INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
                nextInput.putScalar(sampledCharacterIdx, 1);
                output = net.rnnTimeStep(nextInput);
            }
			System.out.print("\n");
			System.out.println(LEARNSTRING);
		}
	}
}

Now change the English string to Chinese and test again, e.g. "*我们来试试用RNN记忆一段文字。". Adding some special characters in front of the string improves the training effect, and some training parameters need to be adjusted to the string length. You can also test with several passages of text instead of one, as shown below.
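For the single-string version above, the only substantive change is the learn string itself (the leading '*' is the extra marker character just mentioned):

	private static final char[] LEARNSTRING = "*我们来试试用RNN记忆一段文字。".toCharArray();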
Testing memorization of multiple lines:

// imports are the same as in BasicRNNExample above
public class BasicRNNExample3 {

   // define the sentences to learn;
   // a special marker character can be added at the front so the RNN learns
   // where each complete string begins and ends
   private static final char[][] LEARNSTRING = new char[][]{
   		"我们来试试用RNN记忆一段文字。".toCharArray(),
   		"然后我们记忆另一段文字试试。".toCharArray()
   		//, "我们能不能记住第三段文字呢。".toCharArray()//试试添加第三行训练会出现什么情况,重点是本行和第一行“我们”是重复的
   };

   // a list of all possible characters
   private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();

   // RNN dimensions
   private static final int HIDDEN_LAYER_WIDTH = 50;
   private static final int HIDDEN_LAYER_CONT = 2;

   public static void main(String[] args) {

   	// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
   	LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
   	for (char[] str : LEARNSTRING)
   		for (char c: str)
   			LEARNSTRING_CHARS.add(c);
   	LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);

   	// some common parameters
   	NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
   	builder.seed(123);
   	builder.biasInit(0);
   	builder.miniBatch(false);
   	builder.updater(new RmsProp(0.002));
   	builder.weightInit(WeightInit.XAVIER);

   	ListBuilder listBuilder = builder.list();

   	// first difference, for rnns we need to use LSTM.Builder
   	for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
   		LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
   		hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
   		hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
   		// adopted activation function from LSTMCharModellingExample
   		// seems to work well with RNNs
   		hiddenLayerBuilder.activation(Activation.TANH);
   		listBuilder.layer(i, hiddenLayerBuilder.build());
   	}

   	// we need to use RnnOutputLayer for our RNN
   	RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
   	// softmax normalizes the output neurons so they sum to 1 and can be read
   	// as a probability distribution over the next character (needed if we
   	// sample from the distribution instead of taking the argmax)
   	outputLayerBuilder.activation(Activation.SOFTMAX);
   	outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
   	outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
   	listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

   	// finish builder
   	listBuilder.pretrain(false);
   	listBuilder.backprop(true);

   	// create network
   	MultiLayerConfiguration conf = listBuilder.build();
   	MultiLayerNetwork net = new MultiLayerNetwork(conf);
   	net.init();
   	net.setListeners(new ScoreIterationListener(5));

   	/*
   	 * CREATE OUR TRAINING DATA
   	 */
   	// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
   	// SEQUENCE_POSITION
   	List<DataSet> inputs=new ArrayList<>();
   	for (char[] str: LEARNSTRING)
   	{
   		INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
   		INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
   		// loop through our sample-sentence
   		int samplePos = 0;
   		for (char currentChar : str) {
   			// small hack: when currentChar is the last character, wrap around and use
   			// the first character as nextChar; not strictly required, and it can be
   			// avoided by prepending a dedicated start marker character
   			char nextChar = str[(samplePos + 1) % (str.length)];
   			// input neuron for current-char is 1 at "samplePos"
   			input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
   			// output neuron for next-char is 1 at "samplePos"
   			labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
   			samplePos++;
   		}
   		DataSet trainingData = new DataSet(input, labels);
   		inputs.add(trainingData);
   	}

   	// some epochs
   	for (int epoch = 0; epoch < 200; epoch++) {

   		//System.out.println("Epoch " + epoch);

   		for (DataSet trainingData: inputs)
   		{
   			// train the data
   			net.fit(trainingData);
   
   			// clear the rnn state left over from the last example
   			net.rnnClearPreviousState();
   		}

   		if ((epoch&3)==3)
   		{
   			System.out.println("Epoch " + epoch);
   			int right=0, total=0;
   			for (char[] str: LEARNSTRING)
   			{
   				total+=str.length-1;
   				System.out.print(str[0]);
   				// put the first character into the rnn as an initialisation
   				INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
   				testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(str[0]), 1);
   	
   				// run one step -> IMPORTANT: rnnTimeStep() must be called, not
   				// output()
   				// the output shows what the net thinks should come next
   				INDArray output = net.rnnTimeStep(testInit);
   	
   				// now let the net recite the remaining str.length-1 characters
   	            for (int i=0;i<str.length-1;i++) {
   	
   	                // map the net's last output to a concrete character: with greedy
   	                // decoding we simply pick the neuron with the highest activation
   	                int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
   	
   	                // print the chosen output
   	                System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
   	                if (str[(i+1)%str.length]==LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx))
   	                	right++;
   	
   	                // use the last output as input
   	                INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
   	                nextInput.putScalar(sampledCharacterIdx, 1);
   	                output = net.rnnTimeStep(nextInput);
   	
   	            }
   	            System.out.println();
   	            net.rnnClearPreviousState();
   			}
   			System.out.println(right+"/"+total);
   		}
   	}
   }
}

A first attempt at training to write poems

Here we take five-character (五言) poems as the example.
1. Prepare the five-character poems used for training and save them in /rnn/poems0.txt:

寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
一朝春夏改,隔夜鸟花迁。阴阳深浅叶,晓夕重轻烟。哢莺犹响殿,横丝正网天。珮高兰影接,绶细草纹连。
碧鳞惊棹侧,玄燕舞檐前。何必汾阳处,始复有山泉。
夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
爽气浮丹阙,秋光澹紫宫。衣碎荷疏影,花明菊点丛。袍轻低草露,盖侧舞松风。散岫飘云叶,迷路飞烟鸿。
砌冷兰凋佩,闺寒树陨桐。别鹤栖琴里,离猿啼峡中。落野飞星箭,弦虚半月弓。芳菲夕雾起,暮色满房栊。
山亭秋色满,岩牖凉风度。疏兰尚染烟,残菊犹承露。古石衣新苔,新巢封古树。历览情无极,咫尺轮光暮。
碧原开雾隰,绮岭峻霞城。烟峰高下翠,日浪浅深明。斑红妆蕊树,圆青压溜荆。迹岩劳傅想,窥野访莘情。巨川何以济,舟楫伫时英。
韶光开令序,淑气动芳年。驻辇华林侧,高宴柏梁前。紫庭文珮满,丹墀衮绂连。九夷簉瑶席,五狄列琼筵。
娱宾歌湛露,广乐奏钧天。清尊浮绿醑,雅曲韵朱弦。粤余君万国,还惭抚八埏。庶几保贞固,虚己厉求贤。
烈烈寒风起,惨惨飞云浮。霜浓凝广隰,冰厚结清流。金鞍移上苑,玉勒骋平畴。旌旗四望合,罝罗一面求。
楚踣争兕殪,秦亡角鹿愁。兽忙投密树,鸿惊起砾洲。骑敛原尘静,戈回岭日收。心非洛汭逸,意在渭滨游。禽荒非所乐,抚辔更招忧。
披襟眺沧海,凭轼玩春芳。积流横地纪,疏派引天潢。仙气凝三岭,和风扇八荒。拂潮云布色,穿浪日舒光。
照岸花分彩,迷云雁断行。怀卑运深广,持满守灵长。有形非易测,无源讵可量。洪涛经变野,翠岛屡成桑。
冻云宵遍岭,素雪晓凝华。入牖千重碎,迎风一半斜。不妆空散粉,无树独飘花。萦空惭夕照,破彩谢晨霞。
暮景斜芳殿,年华丽绮宫。寒辞去冬雪,暖带入春风。阶馥舒梅素,盘花卷烛红。共欢新故岁,迎送一宵中。
岁阴穷暮纪,献节启新芳。冬尽今宵促,年开明日长。冰消出镜水,梅散入风香。对此欢终宴,倾壶待曙光。
和气吹绿野,梅雨洒芳田。新流添旧涧,宿雾足朝烟。雁湿行无次,花沾色更鲜。对此欣登岁,披襟弄五弦。
翠楼含晓雾,莲峰带晚云。玉叶依岩聚,金枝触石分。横天结阵影,逐吹起罗文。非得阳台下,空将惑楚君。

2. Modify the training code above to train the poem writer:

// imports as in BasicRNNExample above, plus java.io.BufferedReader, java.io.IOException,
// java.io.InputStreamReader and java.nio.charset.StandardCharsets for reading the poem file
public class BasicRNNExample5 {

	// the lines to learn, loaded from the poem file below;
	// a special marker character can be added at the front of each line so the
	// RNN learns where a complete string begins and ends
	private static final List<char[]> LEARNSTRING = new ArrayList<>();

	// a list of all possible characters
	private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();

	// RNN dimensions
	private static final int HIDDEN_LAYER_WIDTH = 50;
	private static final int HIDDEN_LAYER_CONT = 2;

	public static void main(String[] args) throws IOException {
        // read the training poems, one per line, from the resource file
		try(BufferedReader reader=new BufferedReader(new InputStreamReader(
				BasicRNNExample5.class.getResourceAsStream("/rnn/poems0.txt"), StandardCharsets.UTF_8)))
		{
			while(true)
			{
				String line=reader.readLine();
				if (line==null)
					break;
				LEARNSTRING.add(line.toCharArray());
			}
		}

		// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
		LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
		for (char[] str : LEARNSTRING)
			for (char c: str)
				LEARNSTRING_CHARS.add(c);
		LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
		System.out.println("vocabulary size: "+LEARNSTRING_CHARS_LIST.size());
		System.out.println("sentences: "+LEARNSTRING.size());

		// some common parameters
		NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
		builder.seed(123);
		builder.biasInit(0);
		builder.miniBatch(false);
		builder.updater(new RmsProp(0.002));
		builder.weightInit(WeightInit.XAVIER);

		ListBuilder listBuilder = builder.list();

		// first difference, for rnns we need to use LSTM.Builder
		for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
			LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
			hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
			hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
			// adopted activation function from LSTMCharModellingExample
			// seems to work well with RNNs
			hiddenLayerBuilder.activation(Activation.TANH);
			listBuilder.layer(i, hiddenLayerBuilder.build());
		}

		// we need to use RnnOutputLayer for our RNN
		RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
		// softmax normalizes the output neurons so they sum to 1 and can be read
		// as a probability distribution over the next character (needed if we
		// sample from the distribution instead of taking the argmax)
		outputLayerBuilder.activation(Activation.SOFTMAX);
		outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
		outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
		listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

		// finish builder
		listBuilder.pretrain(false);
		listBuilder.backprop(true);

		// create network
		MultiLayerConfiguration conf = listBuilder.build();
		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		net.setListeners(new ScoreIterationListener(LEARNSTRING.size()));

		/*
		 * CREATE OUR TRAINING DATA
		 */
		// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
		// SEQUENCE_POSITION
		List<DataSet> inputs=new ArrayList<>();
		for (char[] str: LEARNSTRING)
		{
			INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
			INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
			// loop through our sample-sentence
			int samplePos = 0;
			for (char currentChar : str) {
				// small hack: when currentChar is the last character, wrap around and use
				// the first character as nextChar; not strictly required, and it can be
				// avoided by prepending a dedicated start marker character
				char nextChar = str[(samplePos + 1) % (str.length)];
				// input neuron for current-char is 1 at "samplePos"
				input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
				// output neuron for next-char is 1 at "samplePos"
				labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
				samplePos++;
			}
			DataSet trainingData = new DataSet(input, labels);
			inputs.add(trainingData);
		}

		// some epochs
		for (int epoch = 0; epoch < 200; epoch++) {

			for (DataSet trainingData: inputs)
			{
				// train the data
				net.fit(trainingData);
	
				// clear the rnn state left over from the last example
				net.rnnClearPreviousState();
			}

			if ((epoch&3)!=3)
				continue;
			
			System.out.println("Epoch " + epoch);
			// write a poem starting with each of '春','夏','秋','冬','花','月', 24 characters per line
			for (char str: new char[]{'春','夏','秋','冬','花','月'})
			{
				System.out.print(str);
				// put the first character into the rnn as an initialisation
				INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
				testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(str), 1);
	
				// run one step -> IMPORTANT: rnnTimeStep() must be called, not
				// output()
				// the output shows what the net thinks should come next
				INDArray output = net.rnnTimeStep(testInit);
	
				// now let the net generate 23 more characters to complete the line
	            for (int i=0;i<23;i++) {
	
	                // map the net's last output to a concrete character: with greedy
	                // decoding we simply pick the neuron with the highest activation
	                int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
	
	                // print the chosen output
	                System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
	
	                // use the last output as input
	                INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
	                nextInput.putScalar(sampledCharacterIdx, 1);
	                output = net.rnnTimeStep(nextInput);
	
	            }
	            System.out.println();
	            net.rnnClearPreviousState();
			}
		}
	}
}

Judging from the runs, more training does not mean better poems: part-way through training some unexpectedly good lines appear, but once the net is sufficiently trained the memorization effect kicks in and generation turns into recitation.
Some techniques can be introduced to prevent recitation, for example:
WORD2VEC + a fully connected layer: did not work well;
random WORD2VEC: improved things somewhat; reference code below:

// imports as in BasicRNNExample5 above, plus java.util.HashMap, java.util.Map,
// org.deeplearning4j.nn.conf.layers.DenseLayer,
// org.deeplearning4j.nn.conf.distribution.Distribution,
// org.deeplearning4j.nn.conf.distribution.GaussianDistribution,
// org.nd4j.linalg.indexing.INDArrayIndex and org.nd4j.linalg.indexing.NDArrayIndex
public class BasicRNNExample8 {

	// the lines to learn, loaded from the poem file below;
	// a special marker character can be added at the front of each line so the
	// RNN learns where a complete string begins and ends
	private static final List<char[]> LEARNSTRING = new ArrayList<>();

	// a list of all possible characters
	private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();

	// RNN dimensions
	private static final int HIDDEN_LAYER_WIDTH = 50;
	private static final int HIDDEN_LAYER_CONT = 3;
	private static final double nonZeroBias = 1;
	private static final double dropOut = 0.5;
	
    private static DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
        return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
    }
    
	private static final int VECTOR_SIZE = 200;
	private static Map<String, INDArray> vocabulary;

	public static void main(String[] args) throws IOException {
        // read the training poems, one per line, from the resource file
		try(BufferedReader reader=new BufferedReader(new InputStreamReader(
				BasicRNNExample8.class.getResourceAsStream("/rnn/poems0.txt"), StandardCharsets.UTF_8)))
		{
			while(true)
			{
				String line=reader.readLine();
				if (line==null)
					break;
				LEARNSTRING.add(line.toCharArray());
			}
		}
		
		// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
		LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
		for (char[] str : LEARNSTRING)
			for (char c: str)
				LEARNSTRING_CHARS.add(c);
		LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
		System.out.println("vocabulary size: "+LEARNSTRING_CHARS_LIST.size());
		System.out.println("sentences: "+LEARNSTRING.size());

		vocabulary=new HashMap<>(LEARNSTRING_CHARS_LIST.size());
        List<String> words=new ArrayList<>(LEARNSTRING_CHARS_LIST.size());
        // assign every character a random vector (a random "word2vec")
        LEARNSTRING_CHARS_LIST.forEach((ch)->{
			words.add(String.valueOf(ch));
			vocabulary.put(String.valueOf(ch), Nd4j.rand(Nd4j.create(VECTOR_SIZE)));
		});

		// some common parameters
		NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
		builder.seed(123);
		builder.biasInit(0);
		builder.miniBatch(false);
		builder.updater(new RmsProp(0.002));
		builder.weightInit(WeightInit.XAVIER);

		ListBuilder listBuilder = builder.list();

		// first difference, for rnns we need to use LSTM.Builder
		for (int i = 0; i < HIDDEN_LAYER_CONT-1; i++) {
			LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
			hiddenLayerBuilder.nIn(i == 0 ? VECTOR_SIZE : HIDDEN_LAYER_WIDTH);
			hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
			// adopted activation function from LSTMCharModellingExample
			// seems to work well with RNNs
			hiddenLayerBuilder.activation(Activation.TANH);
			listBuilder.layer(i, hiddenLayerBuilder.build());
		}
		listBuilder.layer(HIDDEN_LAYER_CONT-1, fullyConnected("ffn1", HIDDEN_LAYER_WIDTH, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)));

		// we need to use RnnOutputLayer for our RNN
		RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
		// softmax normalizes the output neurons so they sum to 1 and can be read
		// as a probability distribution over the next character (needed if we
		// sample from the distribution instead of taking the argmax)
		outputLayerBuilder.activation(Activation.SOFTMAX);
		outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
		outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
		listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());

		// finish builder
		listBuilder.pretrain(false);
		listBuilder.backprop(true);

		// create network
		MultiLayerConfiguration conf = listBuilder.build();
		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();
		net.setListeners(new ScoreIterationListener(LEARNSTRING.size()));

		/*
		 * CREATE OUR TRAINING DATA
		 */
		// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
		// SEQUENCE_POSITION
		List<DataSet> inputs=new ArrayList<>();
		for (char[] str: LEARNSTRING)
		{
			INDArray input = Nd4j.zeros(1, VECTOR_SIZE, str.length);
			INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
			// loop through our sample-sentence
			int samplePos = 0;
			for (char currentChar : str) {
				// small hack: when currentChar is the last character, wrap around and use
				// the first character as nextChar; not strictly required, and it can be
				// avoided by prepending a dedicated start marker character
				char nextChar = str[(samplePos + 1) % (str.length)];
				// the input at "samplePos" is the current character's random word vector
				INDArrayIndex[] indArrayIndexs = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(samplePos)};
				INDArray element = vocabulary.get(String.valueOf(currentChar));
				input.put(indArrayIndexs, element);
				// output neuron for next-char is 1 at "samplePos"
				labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
				samplePos++;
			}
			DataSet trainingData = new DataSet(input, labels);
			inputs.add(trainingData);
		}

		// some epochs
		for (int epoch = 0; epoch < 500; epoch++) {

			for (DataSet trainingData: inputs)
			{
				// train the data
				net.fit(trainingData);
	
				// clear the rnn state left over from the last example
				net.rnnClearPreviousState();
			}

			if ((epoch&3)!=3)
				continue;
			System.out.println("Epoch " + epoch);

			for (char str: new char[]{'春','夏','秋','冬','花','月'})
			{
				System.out.print(str);
				// put the first character's word vector into the rnn as an initialisation
				INDArray testInit = vocabulary.get(String.valueOf(str));
	
				// run one step -> IMPORTANT: rnnTimeStep() must be called, not
				// output()
				// the output shows what the net thinks should come next
				INDArray output = net.rnnTimeStep(testInit);
	
				// now let the net generate 23 more characters to complete the line
	            for (int i=0;i<23;i++) {
	
	                // map the net's last output to a concrete character: with greedy
	                // decoding we simply pick the neuron with the highest activation
	                int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
	
	                // print the chosen output
	                Character ch = LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx);
					System.out.print(ch);
	
	                // use the last output as input
	                INDArray nextInput = vocabulary.get(String.valueOf(ch));
	                //if (nextInput==null)
	                //	nextInput=Nd4j.zeros(VECTOR_SIZE);
	                output = net.rnnTimeStep(nextInput);
	
	            }
	            System.out.println();
	            net.rnnClearPreviousState();
			}
		}
	}
}
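Besides random word vectors, another simple way to fight recitation, hinted at by the sampleFromDistribution comment inherited from the dl4j example, is to sample the next character from the softmax output instead of always taking the argmax. Below is a minimal sketch of such a sampler; the method itself and its temperature parameter are my own additions, assuming output is the distribution returned by rnnTimeStep:

	// Sample a character index from the softmax output instead of greedily
	// taking the argmax; temperature < 1 sharpens the distribution, > 1
	// flattens it. (Illustrative helper, not part of the original examples.)
	private static int sampleFromDistribution(INDArray output, double temperature, Random rng) {
		int n = (int) output.length();
		double[] weights = new double[n];
		double sum = 0;
		for (int i = 0; i < n; i++) {
			weights[i] = Math.pow(output.getDouble(i), 1.0 / temperature);
			sum += weights[i];
		}
		// draw a random point in [0, sum) and find the index it falls on
		double r = rng.nextDouble() * sum;
		double acc = 0;
		for (int i = 0; i < n; i++) {
			acc += weights[i];
			if (r <= acc)
				return i;
		}
		return n - 1; // numerical safety net
	}

In the generation loop, the IMax call would then become something like int sampledCharacterIdx = sampleFromDistribution(output, 0.8, rng);, so that even a fully trained net keeps some variation in what it writes.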
