A training course I attended included a nice deep-learning poetry-writing example. It is a good fit for beginners, so I am recording it here together with some of my own understanding.
Writing poetry with deep learning can be implemented with network architectures that have memory, such as RNNs/LSTMs. The network memorizes the structure of poems by reading the provided poem samples, and then acquires its writing ability through training that resembles recitation from memory: predicting the next character from the previous one.
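The next-character idea needs no framework to illustrate: every character's training label is simply the character that follows it, with the last character wrapping around to the first. A minimal sketch (the class and method names here are illustrative, not part of the dl4j example):

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative helper: build the (current char -> next char) "dictation"
// pairs that the network is trained on.
public class NextCharPairs {
    // For each position, the label is the following character;
    // the last character wraps around to the first, as in the example code.
    public static List<char[]> pairs(String text) {
        List<char[]> out = new ArrayList<>();
        char[] chars = text.toCharArray();
        for (int i = 0; i < chars.length; i++) {
            char next = chars[(i + 1) % chars.length];
            out.add(new char[]{chars[i], next});
        }
        return out;
    }

    public static void main(String[] args) {
        for (char[] p : pairs("abc"))
            System.out.println(p[0] + " -> " + p[1]);
        // a -> b
        // b -> c
        // c -> a   (wrap-around)
    }
}
```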
We start from a dl4j example that trains the network to memorize an English string; the test output shows how well it recites the string after training. Full code: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/basic/BasicRNNExample.java
The training code, slightly modified and with some comments added:
public class BasicRNNExample {
// define a sentence to learn.
// Add a special character at the beginning so the RNN learns the complete string and ends with the marker.
private static final char[] LEARNSTRING = "Der Cottbuser Postkutscher putzt den Cottbuser Postkutschkasten.".toCharArray();
// a list of all possible characters
private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
// RNN dimensions
private static final int HIDDEN_LAYER_WIDTH = 50; // number of nodes in each hidden layer
private static final int HIDDEN_LAYER_CONT = 2; // number of hidden layers
//private static final Random r = new Random(7894);
public static void main(String[] args) {
// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
// collect the distinct characters appearing in LEARNSTRING, i.e. every character the net may need to predict
LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char c : LEARNSTRING)
LEARNSTRING_CHARS.add(c);
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.seed(123);
builder.biasInit(0); // initialize biases to 0
builder.miniBatch(false);
builder.updater(new RmsProp(0.001)); // RMSProp updater with learning rate 0.001
builder.weightInit(WeightInit.XAVIER); // Xavier weight initialization
ListBuilder listBuilder = builder.list();
// define the hidden layers
// first difference, for rnns we need to use LSTM.Builder
for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH); // number of input nodes
hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH); // number of output nodes
// adopted activation function from LSTMCharModellingExample
// seems to work well with RNNs
hiddenLayerBuilder.activation(Activation.TANH); // activation function
listBuilder.layer(i, hiddenLayerBuilder.build());
}
// define the output layer
// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons, the sum of all outputs is 1
// this is required for our sampleFromDistribution-function
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
// finish builder
listBuilder.pretrain(false); // no layer-wise pretraining
listBuilder.backprop(true); // train with backpropagation
// create network
MultiLayerConfiguration conf = listBuilder.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(10));
/*
* CREATE OUR TRAINING DATA
*/
// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
// SEQUENCE_POSITION
INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length);
// loop through our sample-sentence
int samplePos = 0;
for (char currentChar : LEARNSTRING) {
// small hack: when currentChar is the last character, take the first
// character as nextChar - not strictly required; a starter character at the beginning makes the wrap-around harmless.
char nextChar = LEARNSTRING[(samplePos + 1) % (LEARNSTRING.length)];
// input neuron for current-char is 1 at "samplePos"
input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
// output neuron for next-char is 1 at "samplePos"
// the label is the character following the current one, i.e. the net is trained to predict the next character from the current one
labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
samplePos++;
}
DataSet trainingData = new DataSet(input, labels);
// some epochs
for (int epoch = 0; epoch < 2000; epoch++) {
if ((epoch%20)==19)
System.out.println("Epoch " + (epoch+1));
// train the data
net.fit(trainingData);
// clear current stance from the last example
net.rnnClearPreviousState();
if ((epoch%20)!=19)
continue;
// test how well the net recites the string after training
System.out.print(LEARNSTRING[0]);
// put the first character into the RNN as an initialisation
INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(LEARNSTRING[0]), 1);
// run one step -> IMPORTANT: rnnTimeStep() must be called, not
// output()
// the output shows what the net thinks what should come next
INDArray output = net.rnnTimeStep(testInit);
// now the net should guess the remaining LEARNSTRING.length-1 characters
for (int n=1;n<LEARNSTRING.length;n++) {
// first process the last output of the network to a concrete
// neuron, the neuron with the highest output has the highest
// chance to get chosen
int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
// print the chosen output
System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
// use the last output as input
INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
nextInput.putScalar(sampledCharacterIdx, 1);
output = net.rnnTimeStep(nextInput);
}
System.out.print("\n");
System.out.println(LEARNSTRING);
}
}
}
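Stripped of ND4J, the decode step in the test loop above is just an argmax over the softmax output (that is what the IMax op computes), mapped back through the character list. A standalone sketch (class and method names are illustrative):

```java
import java.util.Arrays;
import java.util.List;

// Illustrative greedy decode step: pick the index of the largest
// probability in the softmax output, then look up its character.
public class GreedyDecode {
    // index of the largest value, like IMax over the output row
    public static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++)
            if (probs[i] > probs[best]) best = i;
        return best;
    }

    public static void main(String[] args) {
        List<Character> chars = Arrays.asList('a', 'b', 'c');
        double[] softmaxOut = {0.1, 0.7, 0.2};  // pretend network output
        System.out.println(chars.get(argmax(softmaxOut)));  // prints "b"
    }
}
```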
Now try replacing the English string with a Chinese one, for example:
"*我们来试试用RNN记忆一段文字。". Prepending a few special characters to the string improves training; depending on the string length, some training parameters need adjusting. You can also try training on several passages instead of one.
Training the net to memorize multiple lines:
public class BasicRNNExample3 {
// define a sentence to learn.
// Add a special character at the beginning so the RNN learns the complete string and ends with the marker.
private static final char[][] LEARNSTRING = new char[][]{
"我们来试试用RNN记忆一段文字。".toCharArray(),
"然后我们记忆另一段文字试试。".toCharArray()
//, "我们能不能记住第三段文字呢。".toCharArray() // try adding this third line and see what happens; note it starts with "我们", just like the first line
};
// a list of all possible characters
private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
// RNN dimensions
private static final int HIDDEN_LAYER_WIDTH = 50;
private static final int HIDDEN_LAYER_CONT = 2;
public static void main(String[] args) {
// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char[] str : LEARNSTRING)
for (char c: str)
LEARNSTRING_CHARS.add(c);
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.seed(123);
builder.biasInit(0);
builder.miniBatch(false);
builder.updater(new RmsProp(0.002));
builder.weightInit(WeightInit.XAVIER);
ListBuilder listBuilder = builder.list();
// first difference, for rnns we need to use LSTM.Builder
for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
// adopted activation function from LSTMCharModellingExample
// seems to work well with RNNs
hiddenLayerBuilder.activation(Activation.TANH);
listBuilder.layer(i, hiddenLayerBuilder.build());
}
// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons, the sum of all outputs is 1
// this is required for our sampleFromDistribution-function
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
// finish builder
listBuilder.pretrain(false);
listBuilder.backprop(true);
// create network
MultiLayerConfiguration conf = listBuilder.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(5));
/*
* CREATE OUR TRAINING DATA
*/
// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
// SEQUENCE_POSITION
List<DataSet> inputs=new ArrayList<>();
for (char[] str: LEARNSTRING)
{
INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
// loop through our sample-sentence
int samplePos = 0;
for (char currentChar : str) {
// small hack: when currentChar is the last character, take the first
// character as nextChar - not strictly required; a starter character at the beginning makes the wrap-around harmless.
char nextChar = str[(samplePos + 1) % (str.length)];
// input neuron for current-char is 1 at "samplePos"
input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
// output neuron for next-char is 1 at "samplePos"
labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
samplePos++;
}
DataSet trainingData = new DataSet(input, labels);
inputs.add(trainingData);
}
// some epochs
for (int epoch = 0; epoch < 200; epoch++) {
//System.out.println("Epoch " + epoch);
for (DataSet trainingData: inputs)
{
// train the data
net.fit(trainingData);
// clear current stance from the last example
net.rnnClearPreviousState();
}
if ((epoch&3)==3)
{
System.out.println("Epoch " + epoch);
int right=0, total=0;
for (char[] str: LEARNSTRING)
{
total+=str.length-1;
System.out.print(str[0]);
// put the first character into the RNN as an initialisation
INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(str[0]), 1);
// run one step -> IMPORTANT: rnnTimeStep() must be called, not
// output()
// the output shows what the net thinks what should come next
INDArray output = net.rnnTimeStep(testInit);
// now the net should guess the remaining str.length-1 characters
for (int i=0;i<str.length-1;i++) {
// first process the last output of the network to a concrete
// neuron, the neuron with the highest output has the highest
// chance to get chosen
int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
// print the chosen output
System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
if (str[(i+1)%str.length]==LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx))
right++;
// use the last output as input
INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
nextInput.putScalar(sampledCharacterIdx, 1);
output = net.rnnTimeStep(nextInput);
}
System.out.println();
net.rnnClearPreviousState();
}
System.out.println(right+"/"+total);
}
}
}
}
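The data preparation in both examples follows the same two steps: deduplicate the characters into a vocabulary (a LinkedHashSet keeps first-seen order) and one-hot encode each character against it. A minimal framework-free sketch (class and method names are illustrative):

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

// Illustrative sketch of the vocabulary build and one-hot encoding
// used in the examples above.
public class OneHotVocab {
    // distinct characters across all lines, in first-seen order
    public static List<Character> vocabulary(String... lines) {
        LinkedHashSet<Character> chars = new LinkedHashSet<>();
        for (String line : lines)
            for (char c : line.toCharArray())
                chars.add(c);
        return new ArrayList<>(chars);
    }

    // a 1 at the character's vocabulary index, 0 everywhere else
    public static double[] oneHot(List<Character> vocab, char c) {
        double[] v = new double[vocab.size()];
        v[vocab.indexOf(c)] = 1;
        return v;
    }

    public static void main(String[] args) {
        List<Character> vocab = vocabulary("abca", "bd");
        System.out.println(vocab);                  // [a, b, c, d]
        System.out.println(oneHot(vocab, 'c')[2]);  // 1.0
    }
}
```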
Now let's train the net to write five-character poems.
1. Prepare five-character poems for training and save them as /rnn/poems0.txt:
寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
一朝春夏改,隔夜鸟花迁。阴阳深浅叶,晓夕重轻烟。哢莺犹响殿,横丝正网天。珮高兰影接,绶细草纹连。
碧鳞惊棹侧,玄燕舞檐前。何必汾阳处,始复有山泉。
夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
爽气浮丹阙,秋光澹紫宫。衣碎荷疏影,花明菊点丛。袍轻低草露,盖侧舞松风。散岫飘云叶,迷路飞烟鸿。
砌冷兰凋佩,闺寒树陨桐。别鹤栖琴里,离猿啼峡中。落野飞星箭,弦虚半月弓。芳菲夕雾起,暮色满房栊。
山亭秋色满,岩牖凉风度。疏兰尚染烟,残菊犹承露。古石衣新苔,新巢封古树。历览情无极,咫尺轮光暮。
碧原开雾隰,绮岭峻霞城。烟峰高下翠,日浪浅深明。斑红妆蕊树,圆青压溜荆。迹岩劳傅想,窥野访莘情。巨川何以济,舟楫伫时英。
韶光开令序,淑气动芳年。驻辇华林侧,高宴柏梁前。紫庭文珮满,丹墀衮绂连。九夷簉瑶席,五狄列琼筵。
娱宾歌湛露,广乐奏钧天。清尊浮绿醑,雅曲韵朱弦。粤余君万国,还惭抚八埏。庶几保贞固,虚己厉求贤。
烈烈寒风起,惨惨飞云浮。霜浓凝广隰,冰厚结清流。金鞍移上苑,玉勒骋平畴。旌旗四望合,罝罗一面求。
楚踣争兕殪,秦亡角鹿愁。兽忙投密树,鸿惊起砾洲。骑敛原尘静,戈回岭日收。心非洛汭逸,意在渭滨游。禽荒非所乐,抚辔更招忧。
披襟眺沧海,凭轼玩春芳。积流横地纪,疏派引天潢。仙气凝三岭,和风扇八荒。拂潮云布色,穿浪日舒光。
照岸花分彩,迷云雁断行。怀卑运深广,持满守灵长。有形非易测,无源讵可量。洪涛经变野,翠岛屡成桑。
冻云宵遍岭,素雪晓凝华。入牖千重碎,迎风一半斜。不妆空散粉,无树独飘花。萦空惭夕照,破彩谢晨霞。
暮景斜芳殿,年华丽绮宫。寒辞去冬雪,暖带入春风。阶馥舒梅素,盘花卷烛红。共欢新故岁,迎送一宵中。
岁阴穷暮纪,献节启新芳。冬尽今宵促,年开明日长。冰消出镜水,梅散入风香。对此欢终宴,倾壶待曙光。
和气吹绿野,梅雨洒芳田。新流添旧涧,宿雾足朝烟。雁湿行无次,花沾色更鲜。对此欣登岁,披襟弄五弦。
翠楼含晓雾,莲峰带晚云。玉叶依岩聚,金枝触石分。横天结阵影,逐吹起罗文。非得阳台下,空将惑楚君。
2. Modify the training code above so it learns to write poems:
public class BasicRNNExample5 {
// define a sentence to learn.
// Add a special character at the beginning so the RNN learns the complete string and ends with the marker.
private static final List<char[]> LEARNSTRING = new ArrayList<>();
// a list of all possible characters
private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
// RNN dimensions
private static final int HIDDEN_LAYER_WIDTH = 50;
private static final int HIDDEN_LAYER_CONT = 2;
public static void main(String[] args) throws IOException {
// Gets Path to Text file
try(BufferedReader reader=new BufferedReader(new InputStreamReader(
BasicRNNExample5.class.getResourceAsStream("/rnn/poems0.txt"), StandardCharsets.UTF_8)))
{
while(true)
{
String line=reader.readLine();
if (line==null)
break;
LEARNSTRING.add(line.toCharArray());
}
}
// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char[] str : LEARNSTRING)
for (char c: str)
LEARNSTRING_CHARS.add(c);
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
System.out.println("vocabulary size: "+LEARNSTRING_CHARS_LIST.size());
System.out.println("sentences: "+LEARNSTRING.size());
// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.seed(123);
builder.biasInit(0);
builder.miniBatch(false);
builder.updater(new RmsProp(0.002));
builder.weightInit(WeightInit.XAVIER);
ListBuilder listBuilder = builder.list();
// first difference, for rnns we need to use LSTM.Builder
for (int i = 0; i < HIDDEN_LAYER_CONT; i++) {
LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? LEARNSTRING_CHARS.size() : HIDDEN_LAYER_WIDTH);
hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
// adopted activation function from LSTMCharModellingExample
// seems to work well with RNNs
hiddenLayerBuilder.activation(Activation.TANH);
listBuilder.layer(i, hiddenLayerBuilder.build());
}
// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons, the sum of all outputs is 1
// this is required for our sampleFromDistribution-function
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
// finish builder
listBuilder.pretrain(false);
listBuilder.backprop(true);
// create network
MultiLayerConfiguration conf = listBuilder.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(LEARNSTRING.size()));
/*
* CREATE OUR TRAINING DATA
*/
// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
// SEQUENCE_POSITION
List<DataSet> inputs=new ArrayList<>();
for (char[] str: LEARNSTRING)
{
INDArray input = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
// loop through our sample-sentence
int samplePos = 0;
for (char currentChar : str) {
// small hack: when currentChar is the last character, take the first
// character as nextChar - not strictly required; a starter character at the beginning makes the wrap-around harmless.
char nextChar = str[(samplePos + 1) % (str.length)];
// input neuron for current-char is 1 at "samplePos"
input.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(currentChar), samplePos }, 1);
// output neuron for next-char is 1 at "samplePos"
labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
samplePos++;
}
DataSet trainingData = new DataSet(input, labels);
inputs.add(trainingData);
}
// some epochs
for (int epoch = 0; epoch < 200; epoch++) {
for (DataSet trainingData: inputs)
{
// train the data
net.fit(trainingData);
// clear current stance from the last example
net.rnnClearPreviousState();
}
if ((epoch&3)!=3)
continue;
System.out.println("Epoch " + epoch);
// generate a poem starting with each of '春','夏','秋','冬','花','月', 24 characters per line
for (char str: new char[]{'春','夏','秋','冬','花','月'})
{
System.out.print(str);
// put the first character into the RNN as an initialisation
INDArray testInit = Nd4j.zeros(1,LEARNSTRING_CHARS_LIST.size(), 1);
testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(str), 1);
// run one step -> IMPORTANT: rnnTimeStep() must be called, not
// output()
// the output shows what the net thinks what should come next
INDArray output = net.rnnTimeStep(testInit);
// now let the net generate the remaining 23 characters
for (int i=0;i<23;i++) {
// first process the last output of the network to a concrete
// neuron, the neuron with the highest output has the highest
// chance to get chosen
int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
// print the chosen output
System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx));
// use the last output as input
INDArray nextInput = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), 1);
nextInput.putScalar(sampledCharacterIdx, 1);
output = net.rnnTimeStep(nextInput);
}
System.out.println();
net.rnnClearPreviousState();
}
}
}
}
Judging from the runs, more training does not necessarily produce better poems. Midway through training some unexpectedly good lines appear, but once the net is fully trained the memorization takes over and the output degenerates into recitation of the training data.
Some techniques can be introduced to fight this recitation, for example:
WORD2VEC plus a fully connected layer, which did not work well;
randomly initialized WORD2VEC vectors, which improved things somewhat. Reference code:
public class BasicRNNExample8 {
// define a sentence to learn.
// Add a special character at the beginning so the RNN learns the complete string and ends with the marker.
private static final List<char[]> LEARNSTRING = new ArrayList<>();
// a list of all possible characters
private static final List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
// RNN dimensions
private static final int HIDDEN_LAYER_WIDTH = 50;
private static final int HIDDEN_LAYER_CONT = 3;
private static final double nonZeroBias = 1;
private static final double dropOut = 0.5;
private static DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
}
private static final int VECTOR_SIZE = 200;
private static Map<String, INDArray> vocabulary;
public static void main(String[] args) throws IOException {
// Gets Path to Text file
try(BufferedReader reader=new BufferedReader(new InputStreamReader(
BasicRNNExample8.class.getResourceAsStream("/rnn/poems0.txt"), StandardCharsets.UTF_8)))
{
while(true)
{
String line=reader.readLine();
if (line==null)
break;
LEARNSTRING.add(line.toCharArray());
}
}
// create a dedicated list of possible chars in LEARNSTRING_CHARS_LIST
LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char[] str : LEARNSTRING)
for (char c: str)
LEARNSTRING_CHARS.add(c);
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);
System.out.println("vocabulary size: "+LEARNSTRING_CHARS_LIST.size());
System.out.println("sentences: "+LEARNSTRING.size());
vocabulary=new HashMap<>(LEARNSTRING_CHARS_LIST.size());
List<String> words=new ArrayList<>(LEARNSTRING_CHARS_LIST.size());
// generate random word vectors ("random WORD2VEC")
LEARNSTRING_CHARS_LIST.forEach((ch)->{
words.add(String.valueOf(ch));
vocabulary.put(String.valueOf(ch), Nd4j.rand(Nd4j.create(VECTOR_SIZE)));
});
// some common parameters
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.seed(123);
builder.biasInit(0);
builder.miniBatch(false);
builder.updater(new RmsProp(0.002));
builder.weightInit(WeightInit.XAVIER);
ListBuilder listBuilder = builder.list();
// first difference, for rnns we need to use LSTM.Builder
for (int i = 0; i < HIDDEN_LAYER_CONT-1; i++) {
LSTM.Builder hiddenLayerBuilder = new LSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? VECTOR_SIZE : HIDDEN_LAYER_WIDTH);
hiddenLayerBuilder.nOut(HIDDEN_LAYER_WIDTH);
// adopted activation function from LSTMCharModellingExample
// seems to work well with RNNs
hiddenLayerBuilder.activation(Activation.TANH);
listBuilder.layer(i, hiddenLayerBuilder.build());
}
listBuilder.layer(HIDDEN_LAYER_CONT-1, fullyConnected("ffn1", HIDDEN_LAYER_WIDTH, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)));
// we need to use RnnOutputLayer for our RNN
RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);
// softmax normalizes the output neurons, the sum of all outputs is 1
// this is required for our sampleFromDistribution-function
outputLayerBuilder.activation(Activation.SOFTMAX);
outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
// finish builder
listBuilder.pretrain(false);
listBuilder.backprop(true);
// create network
MultiLayerConfiguration conf = listBuilder.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(LEARNSTRING.size()));
/*
* CREATE OUR TRAINING DATA
*/
// create input and output arrays: SAMPLE_INDEX, INPUT_NEURON,
// SEQUENCE_POSITION
List<DataSet> inputs=new ArrayList<>();
for (char[] str: LEARNSTRING)
{
INDArray input = Nd4j.zeros(1, VECTOR_SIZE, str.length);
INDArray labels = Nd4j.zeros(1, LEARNSTRING_CHARS_LIST.size(), str.length);
// loop through our sample-sentence
int samplePos = 0;
for (char currentChar : str) {
// small hack: when currentChar is the last character, take the first
// character as nextChar - not strictly required; a starter character at the beginning makes the wrap-around harmless.
char nextChar = str[(samplePos + 1) % (str.length)];
// input neuron for current-char is 1 at "samplePos"
INDArrayIndex[] indArrayIndexs = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(samplePos)};
INDArray element = vocabulary.get(String.valueOf((char)currentChar));
input.put(indArrayIndexs,
element);
// output neuron for next-char is 1 at "samplePos"
labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
samplePos++;
}
DataSet trainingData = new DataSet(input, labels);
inputs.add(trainingData);
}
// some epochs
for (int epoch = 0; epoch < 500; epoch++) {
for (DataSet trainingData: inputs)
{
// train the data
net.fit(trainingData);
// clear current stance from the last example
net.rnnClearPreviousState();
}
if ((epoch&3)!=3)
continue;
System.out.println("Epoch " + epoch);
for (char str: new char[]{'春','夏','秋','冬','花','月'})
{
System.out.print(str);
// put the first character into the RNN as an initialisation
INDArray testInit = vocabulary.get(String.valueOf(str));
// run one step -> IMPORTANT: rnnTimeStep() must be called, not
// output()
// the output shows what the net thinks what should come next
INDArray output = net.rnnTimeStep(testInit);
// now let the net generate the remaining 23 characters
for (int i=0;i<23;i++) {
// first process the last output of the network to a concrete
// neuron, the neuron with the highest output has the highest
// chance to get chosen
int sampledCharacterIdx = Nd4j.getExecutioner().exec(new IMax(output), 1).getInt(0);
// print the chosen output
Character ch = LEARNSTRING_CHARS_LIST.get(sampledCharacterIdx);
System.out.print(ch);
// use the last output as input
INDArray nextInput = vocabulary.get(String.valueOf(ch));
//if (nextInput==null)
// nextInput=Nd4j.zeros(VECTOR_SIZE);
output = net.rnnTimeStep(nextInput);
}
System.out.println();
net.rnnClearPreviousState();
}
}
}
}
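One further way to reduce recitation, not used in any of the code above, would be to sample the next character from the softmax distribution instead of always taking the argmax; a temperature parameter controls how much variety this adds. A sketch under that assumption (all names illustrative):

```java
import java.util.Random;

// Illustrative temperature sampling over a softmax output.
public class TemperatureSampling {
    // Re-weight the probabilities at the given temperature and sample an
    // index. temperature = 1 keeps the distribution as-is; higher
    // temperatures flatten it and add variety, lower ones sharpen it
    // toward plain argmax.
    public static int sample(double[] probs, double temperature, Random rnd) {
        double[] adjusted = new double[probs.length];
        double sum = 0;
        for (int i = 0; i < probs.length; i++) {
            adjusted[i] = Math.pow(probs[i], 1.0 / temperature);
            sum += adjusted[i];
        }
        double r = rnd.nextDouble() * sum;
        for (int i = 0; i < adjusted.length; i++) {
            r -= adjusted[i];
            if (r <= 0) return i;
        }
        return adjusted.length - 1;  // guard against rounding
    }

    public static void main(String[] args) {
        double[] probs = {0.05, 0.9, 0.05};
        System.out.println(sample(probs, 1.0, new Random(42)));  // prints 1 with this seed
    }
}
```

At temperatures close to 0 this degenerates into the greedy argmax decode used in the examples; raising the temperature trades recitation accuracy for variety.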