vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = keras.models.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim,
                               batch_input_shape=[batch_size, None]),
        keras.layers.SimpleRNN(units=rnn_units,
                               return_sequences=True),
        keras.layers.Dense(vocab_size)
    ])
    return model

model = build_model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=batch_size)
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (64, None, 256) 16640
_________________________________________________________________
simple_rnn (SimpleRNN) (64, None, 1024) 1311744
_________________________________________________________________
dense (Dense) (64, None, 65) 66625
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
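As a sanity check on the summary above, the parameter counts can be reproduced by hand (a quick sketch; the vocabulary size of 65 is taken from the Dense output shape):

# Embedding: one 256-dim vector per character
print(65 * 256)                   # 16640
# SimpleRNN: input weights + recurrent weights + bias, for each of the 1024 units
print((256 + 1024 + 1) * 1024)    # 1311744
# Dense: one logit per character, plus bias
print((1024 + 1) * 65)            # 66625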
for input_example_batch, target_example_batch in seq_dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape)
# 64 is the batch size, 100 is the length of each sequence,
# and 65 is the size of the probability distribution over the vocabulary
(64, 100, 65)
# Randomly sample based on the 65 output values
# Always picking the most probable value is the greedy strategy; sampling at random is the stochastic strategy
# logits: in a classification task, the values before softmax are the logits
sample_indices = tf.random.categorical(
    logits=example_batch_predictions[0], num_samples=1)
# (100, 65) -> (100, 1)
print(sample_indices)
# Squeeze into a 1-D vector
sample_indices = tf.squeeze(sample_indices, axis=-1)
print(sample_indices)
tf.Tensor(
[[50]
[47]
[16]
[41]
[41]
[15]
[ 0]
[58]
[48]
[58]
[62]
[22]
[48]
[36]
[36]
[44]
[45]
[12]
[ 7]
[31]
[22]
[53]
[32]
[44]
[26]
[17]
[ 1]
[ 1]
[31]
[ 5]
[35]
[22]
[64]
[32]
[15]
[25]
[60]
[12]
[ 3]
[28]
[11]
[24]
[28]
[ 7]
[39]
[56]
[18]
[26]
[55]
[39]
[10]
[48]
[28]
[53]
[43]
[17]
[48]
[27]
[23]
[55]
[ 5]
[49]
[64]
[ 6]
[11]
[ 4]
[32]
[ 8]
[23]
[46]
[18]
[ 5]
[64]
[52]
[44]
[26]
[16]
[59]
[37]
[15]
[27]
[41]
[16]
[38]
[ 6]
[20]
[42]
[62]
[24]
[62]
[14]
[42]
[12]
[14]
[12]
[48]
[ 5]
[45]
[42]
[25]], shape=(100, 1), dtype=int64)
tf.Tensor(
[50 47 16 41 41 15 0 58 48 58 62 22 48 36 36 44 45 12 7 31 22 53 32 44
26 17 1 1 31 5 35 22 64 32 15 25 60 12 3 28 11 24 28 7 39 56 18 26
55 39 10 48 28 53 43 17 48 27 23 55 5 49 64 6 11 4 32 8 23 46 18 5
64 52 44 26 16 59 37 15 27 41 16 38 6 20 42 62 24 62 14 42 12 14 12 48
5 45 42 25], shape=(100,), dtype=int64)
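To make the two strategies mentioned above concrete, here is a minimal sketch (reusing example_batch_predictions from the cell above) contrasting greedy argmax decoding with categorical sampling:

logits = example_batch_predictions[0]                 # [100, 65]
# Greedy strategy: always take the most probable character id
greedy_ids = tf.argmax(logits, axis=-1)               # [100]
# Stochastic strategy: sample one id per position from the logits
sampled_ids = tf.squeeze(
    tf.random.categorical(logits, num_samples=1), axis=-1)    # [100]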
print("Input:", repr("".join(idx2char[input_example_batch[0]])))
print()
print("Output:", repr("".join(idx2char[target_example_batch[0]])))
print()
print("Predictions:", repr("".join(idx2char[sample_indices])))
Input: 'ction, sir,\nEven by your own.\n\nAUFIDIUS:\nI cannot help it now,\nUnless, by using means, I lame the fo'
Output: 'tion, sir,\nEven by your own.\n\nAUFIDIUS:\nI cannot help it now,\nUnless, by using means, I lame the foo'
Predictions: "liDccC\ntjtxJjXXfg?-SJoTfNE S'WJzTCMv?$P;LP-arFNqa:jPoeEjOKq'kz,;&T.KhF'znfNDuYCOcDZ,HdxLxBd?B?j'gdM"
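The printed strings show that the target is simply the input shifted left by one character; a quick check using the same variables:

input_text = "".join(idx2char[input_example_batch[0]])
target_text = "".join(idx2char[target_example_batch[0]])
# Every target character is the next character of the input
assert input_text[1:] == target_text[:-1]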
# Custom loss function
def loss(labels, logits):
    return keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)

model.compile(optimizer='adam', loss=loss)
example_loss = loss(target_example_batch, example_batch_predictions)
print(example_loss.shape)
print(example_loss.numpy().mean())
(64, 100)
4.1735864
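The initial loss of roughly 4.17 is what you would expect from an untrained model that spreads probability almost uniformly over the vocabulary; a quick check:

import numpy as np
# Expected cross-entropy when probability is spread uniformly over 65 characters
print(np.log(65))    # ~4.174, close to the 4.1735864 measured above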
output_dir = "./text_generation_checkpoints"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)
epochs = 50
history = model.fit(seq_dataset, epochs=epochs, callbacks=[checkpoint_callback])
Train for 172 steps
Epoch 1/50
172/172 [==============================] - 83s 485ms/step - loss: 1.3785
Epoch 2/50
172/172 [==============================] - 102s 591ms/step - loss: 1.3592
Epoch 3/50
172/172 [==============================] - 97s 565ms/step - loss: 1.3413
Epoch 4/50
172/172 [==============================] - 97s 563ms/step - loss: 1.3255
Epoch 5/50
172/172 [==============================] - 93s 543ms/step - loss: 1.3118
Epoch 6/50
172/172 [==============================] - 97s 562ms/step - loss: 1.2964
Epoch 7/50
172/172 [==============================] - 99s 577ms/step - loss: 1.2826
Epoch 8/50
172/172 [==============================] - 96s 559ms/step - loss: 1.2695
Epoch 9/50
172/172 [==============================] - 95s 550ms/step - loss: 1.2572
Epoch 10/50
172/172 [==============================] - 93s 541ms/step - loss: 1.2468
Epoch 11/50
172/172 [==============================] - 93s 540ms/step - loss: 1.2338
Epoch 12/50
172/172 [==============================] - 91s 531ms/step - loss: 1.2217
Epoch 13/50
172/172 [==============================] - 95s 551ms/step - loss: 1.2104
Epoch 14/50
172/172 [==============================] - 95s 551ms/step - loss: 1.1972
Epoch 15/50
172/172 [==============================] - 95s 554ms/step - loss: 1.1882
Epoch 16/50
172/172 [==============================] - 90s 526ms/step - loss: 1.1761
Epoch 17/50
172/172 [==============================] - 93s 542ms/step - loss: 1.1636
Epoch 18/50
172/172 [==============================] - 97s 564ms/step - loss: 1.1555
Epoch 19/50
172/172 [==============================] - 94s 548ms/step - loss: 1.1408
Epoch 20/50
172/172 [==============================] - 93s 540ms/step - loss: 1.1322
Epoch 21/50
172/172 [==============================] - 94s 549ms/step - loss: 1.1215
Epoch 22/50
172/172 [==============================] - 95s 551ms/step - loss: 1.1115
Epoch 23/50
172/172 [==============================] - 95s 551ms/step - loss: 1.0999
Epoch 24/50
172/172 [==============================] - 96s 555ms/step - loss: 1.0902
Epoch 25/50
172/172 [==============================] - 94s 545ms/step - loss: 1.0794
Epoch 26/50
172/172 [==============================] - 97s 563ms/step - loss: 1.0724
Epoch 27/50
172/172 [==============================] - 94s 548ms/step - loss: 1.0603
Epoch 28/50
172/172 [==============================] - 96s 557ms/step - loss: 1.0528
Epoch 29/50
172/172 [==============================] - 95s 550ms/step - loss: 1.0471
Epoch 30/50
172/172 [==============================] - 99s 576ms/step - loss: 1.0338
Epoch 31/50
172/172 [==============================] - 98s 570ms/step - loss: 1.0278
Epoch 32/50
172/172 [==============================] - 97s 567ms/step - loss: 1.0208
Epoch 33/50
172/172 [==============================] - 94s 547ms/step - loss: 1.0127
Epoch 34/50
172/172 [==============================] - 99s 573ms/step - loss: 1.0064
Epoch 35/50
172/172 [==============================] - 99s 573ms/step - loss: 1.0021
Epoch 36/50
172/172 [==============================] - 97s 565ms/step - loss: 0.9938
Epoch 37/50
172/172 [==============================] - 96s 559ms/step - loss: 0.9892
Epoch 38/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9835
Epoch 39/50
172/172 [==============================] - 96s 557ms/step - loss: 0.9790
Epoch 40/50
172/172 [==============================] - 98s 571ms/step - loss: 0.9725
Epoch 41/50
172/172 [==============================] - 109s 636ms/step - loss: 0.9690
Epoch 42/50
172/172 [==============================] - 102s 592ms/step - loss: 0.9675
Epoch 43/50
172/172 [==============================] - 101s 585ms/step - loss: 0.9615
Epoch 44/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9564
Epoch 45/50
172/172 [==============================] - 101s 585ms/step - loss: 0.9571
Epoch 46/50
172/172 [==============================] - 99s 574ms/step - loss: 0.9521
Epoch 47/50
172/172 [==============================] - 101s 589ms/step - loss: 0.9527
Epoch 48/50
172/172 [==============================] - 100s 581ms/step - loss: 0.9459
Epoch 49/50
172/172 [==============================] - 101s 588ms/step - loss: 0.9439
Epoch 50/50
172/172 [==============================] - 101s 590ms/step - loss: 0.9455
tf.train.latest_checkpoint(output_dir)
'./text_generation_checkpoints\\ckpt_10'
# Load the trained model, rebuilt with batch_size = 1
# because generation feeds in one sequence at a time
model2 = build_model(vocab_size,
                     embedding_dim,
                     rnn_units,
                     batch_size=1)
# Load the trained weights
model2.load_weights(tf.train.latest_checkpoint(output_dir))
# Set the input shape
model2.build(tf.TensorShape([1, None]))
# Text generation workflow:
# start from an initial string, e.g. A
# A -> model -> b
# A.append(b) -> Ab
# Ab -> model -> c
# Ab.append(c) -> Abc
# Abc -> model -> ...
model2.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_3 (Embedding) (1, None, 256) 16640
_________________________________________________________________
simple_rnn_3 (SimpleRNN) (1, None, 1024) 1311744
_________________________________________________________________
dense_3 (Dense) (1, None, 65) 66625
=================================================================
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________
# Text generation
def generate_text(model, start_string, num_generate=1000):
    # Convert the start string to a 1-D list of character ids
    input_eval = [char2idx[ch] for ch in start_string]
    # Expand dims to make it 2-D: [1, start_string_len]
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []
    model.reset_states()
    for _ in range(num_generate):
        # predictions: [batch_size, input_eval_len, vocab_size]
        predictions = model(input_eval)
        # Drop the batch dimension
        # predictions: [input_eval_len, vocab_size]
        predictions = tf.squeeze(predictions, 0)
        # sampled ids: [input_eval_len, 1]; keep only the last time step's sample
        predicted_id = tf.random.categorical(
            predictions, num_samples=1)[-1, 0].numpy()
        text_generated.append(idx2char[predicted_id])
        # Feed the sampled character back in as the next input
        input_eval = tf.expand_dims([predicted_id], 0)
    return start_string + ''.join(text_generated)
new_text = generate_text(model2, 'All: ')
print(new_text)
All: ngs llou,
TINIf,
F t m ndaigome tooure d o ailowe and se?
Forerinenowiliounome! thetould kereriret mo is,
IORDIst, tss teey bante thes g'l
pofamyounge hu,
TE:
Thay yershand ome bar, bo t d bame pat ad l fidshey s ha, he K:
Whe
Cleel? pe bred t d mag our ts bulange y? d.
I t.
Timoure;
IXESTy yo d mol m HAUCHARD oucixan s, me dey tr mery male u scostivis f thy.
FFidir my ik thif My that m theve CII kersor, s agrs,
AUSoustoutithed ga, likimilerd nde uthigan t thatouthe oo bot, t buicithe s m't buche t he se t,
TESAUCHASA:
CO:
IOf hey nghat urow lll allie t! hire? w'ss lourally apr w, ndean le
MENGrff ilayo gh Icade pe he my
Fouser'son yoman;
FI:
Thilllildize: bas.
Thishe.
I ththes e and f ay tesseshicu, teade omowonanouss
ENCRIO:
VI bupootsie, ve HOMes he:
WAs Ay t t e pigulke or coungh ton touncancousir ANAnes banathie cen thy welfanghavere th vifo d s w soure iciee! bjur's,
A wanicomer t w'st be aint omofecofeye myowhe;
LYO:
While: burenowite howhame t,
Ofreer the tay?
BE:
HAs duc
As you can see, the generated text is not very good.