要使用 TensorFlow Datasets (TFDS) 来训练一个文本摘要模型,可以选择一个包含文章和摘要的数据集,例如 CNN/DailyMail 数据集。
这个数据集通常用于训练和评估文本摘要模型。
以下是使用 TFDS 加载数据集并训练一个简单的序列到序列 (seq2seq) 模型的过程。
首先,确保安装了 TensorFlow Datasets:
pip install tensorflow tensorflow-datasets
然后,以下是训练文本摘要模型的完整代码:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import TextVectorization, Embedding, LSTM, Dense
# 加载 CNN/DailyMail 数据集
data, info = tfds.load('cnn_dailymail', with_info=True, as_supervised=True)
train_data, val_data = data['train'], data['validation']
# 为了加快演示,我们将只使用一小部分数据
train_data = train_data.take(5000)
val_data = val_data.take(1000)
# 定义文本向量化和序列长度
sequence_length = 512
vocab_size = 20000
vectorize_layer = TextVectorization(max_tokens=vocab_size, output_mode='int', output_sequence_length=sequence_length)
# 准备数据集
def prepare_dataset(data):
articles = data.map(lambda article, summary: article)
summaries = data.map(lambda article, summary: summary)
vectorize_layer.adapt(articles)
vectorized_articles = articles.map(lambda x: vectorize_layer(x))
vectorized_summaries = summaries.map(lambda x: vectorize_layer(x))
dataset = tf.data.Dataset.zip((vectorized_articles, vectorized_summaries)).batch(32).prefetch(tf.data.AUTOTUNE)
return dataset
train_dataset = prepare_dataset(train_data)
val_dataset = prepare_dataset(val_data)
# 构建一个简单的 seq2seq 模型
embedding_dim = 128
lstm_units = 256
# 编码器
encoder_inputs = tf.keras.Input(shape=(None,), dtype='int64')
encoder_embedding = Embedding(vocab_size, embedding_dim)(encoder_inputs)
_, state_h, state_c = LSTM(lstm_units, return_state=True)(encoder_embedding)
encoder_states = [state_h, state_c]
# 解码器
decoder_inputs = tf.keras.Input(shape=(None,), dtype='int64')
decoder_embedding = Embedding(vocab_size, embedding_dim)(decoder_inputs)
decoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)
decoder_dense = Dense(vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
# 定义模型
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# 编译模型
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
# 使用模型进行文本摘要
def summarize_text(text, model, vectorize_layer):
vectorized_text = vectorize_layer(tf.convert_to_tensor([text]))
summary = tf.constant([vectorize_layer.vocab_size - 1], dtype=tf.int64) # 使用序列结束标记开始
for _ in range(sequence_length):
predictions = model.predict([vectorized_text, tf.expand_dims(summary, 0)])
predicted_id = tf.argmax(predictions[0, -1, :])
if predicted_id == 0:
break
summary = tf.concat([summary, [predicted_id]], axis=0)
return vectorize_layer.get_vocabulary()[summary.numpy()]
# 测试摘要
for article, summary in val_data.take(1):
print('原始文章:', article.numpy().decode('utf-8'))
print('真实摘要:', summary.numpy().decode('utf-8'))
predicted_summary = summarize_text(article.numpy().decode('utf-8'), model, vectorize_layer)
print('预测摘要:', ' '.join(predicted_summary))
这段代码做了如下几件事情:
summarize_text
,用于生成文本摘要。针对 CNN/DailyMail 数据集的预处理部分的代码,以及每一步的解释:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import TextVectorization
# 加载 CNN/DailyMail 数据集
data = tfds.load('cnn_dailymail', as_supervised=True)
train_data, val_data = data['train'], data['validation']
# 定义文本向量化的参数
sequence_length = 512
vocab_size = 20000
# 创建一个文本向量化层
vectorize_layer = TextVectorization(
max_tokens=vocab_size, # 设置最大的词汇量
output_mode='int', # 设置输出模式为整数索引
output_sequence_length=sequence_length # 设置输出的序列长度
)
# 准备数据集的函数
def prepare_dataset(data):
# 将数据集分为文章和摘要
articles = data.map(lambda article, summary: article)
summaries = data.map(lambda article, summary: summary)
# 适应文本向量化层,只对文章进行适应以构建词汇表
vectorize_layer.adapt(articles)
# 将文章和摘要映射到整数序列
vectorized_articles = articles.map(lambda x: vectorize_layer(x))
vectorized_summaries = summaries.map(lambda x: vectorize_layer(x))
# 将文章和摘要的整数序列打包成一个新的数据集,并进行批处理和预取
dataset = tf.data.Dataset.zip((vectorized_articles, vectorized_summaries))
dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)
return dataset
# 应用预处理函数到训练和验证数据集
train_dataset = prepare_dataset(train_data)
val_dataset = prepare_dataset(val_data)
解释:
加载数据集:
使用 tfds.load
函数加载 CNN/DailyMail 数据集。参数 as_supervised=True
表示我们希望以监督学习的格式加载数据集,即每个数据点都包含输入数据(文章)和标签数据(摘要)。
定义文本向量化参数:
设置序列长度和词汇量的大小。这些参数对于模型处理文本数据非常重要。sequence_length
确定了模型可以处理的最大文章和摘要长度。vocab_size
决定了词汇表的大小,即模型可以识别的不同单词的最大数量。
创建文本向量化层:TextVectorization
层用于将文本转换为整数序列。每个整数都对应词汇表中的一个单词。这一步是将自然语言转换为机器学习模型可以处理的格式的关键步骤。
预处理数据集的函数:prepare_dataset
函数负责将原始文本数据集转换为模型可以使用的格式。它首先将数据集分为文章和摘要,然后使用 vectorize_layer.adapt
方法来适应(即构建)词汇表。随后,它将文章和摘要映射到整数序列。
批处理和预取:batch(32)
方法将数据集划分为大小为 32 的批次,这意味着模型将一次处理 32 篇文章及其相应的摘要。prefetch(tf.data.AUTOTUNE)
方法用于提前准备好接下来的数据批次,这样在模型训练时可以减少 I/O 阻塞,提高训练效率。
应用预处理:prepare_dataset
函数被应用到训练和验证数据集上,这样我们就得到了可以直接用于模型训练和评估的数据集。
这个预处理过程是为了简化示例而设定的,并且假设模型是一个基础的 seq2seq 模型。
在实际应用中,您可能需要更复杂的预处理步骤,例如对文本进行清洗、使用子词分词(subword tokenization)等。
为了使模型训练过程中自动保存最佳模型,我们可以使用 ModelCheckpoint
回调。这个回调会在每个训练周期(epoch)结束时运行,并根据我们指定的条件(如验证集上的损失或准确率)保存模型。下面是如何设置 ModelCheckpoint
回调并将其添加到模型训练中的示例:
首先,导入所需的库并设置 ModelCheckpoint
回调:
from tensorflow.keras.callbacks import ModelCheckpoint
# 设置模型检查点回调,保存最佳模型
checkpoint_path = "seq2seq_checkpoint.ckpt"
checkpoint_callback = ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
save_best_only=True,
monitor='val_loss', # 也可以是 'val_accuracy',取决于你想监控的指标
mode='min', # 如果监控的是 'val_loss',则模式是 'min',即越小越好
verbose=1 # 打印保存模型的信息
)
接着,将 checkpoint_callback
添加到 fit
方法的 callbacks
参数中:
# 训练模型
model.fit(
train_dataset,
epochs=10,
validation_data=val_dataset,
callbacks=[checkpoint_callback] # 添加回调函数
)
现在,模型会在每个训练周期结束时自动检查验证损失,并在出现更低的验证损失时保存权重。save_weights_only=True
表示只保存模型的权重,而不是整个模型。这样可以节省存储空间,但需要在加载权重时重建模型结构。
如果你想在训练后恢复模型的权重,你可以使用以下代码:
# 假设模型结构已经定义并编译,然后加载权重
model.load_weights(checkpoint_path)
完成这些之后,你可以使用 summarize_text
函数对新文章进行摘要,或者评估模型在验证集上的表现。这里是完整的代码,包含了自动保存回调的设置:
# 定义模型检查点回调
checkpoint_path = "seq2seq_checkpoint.ckpt"
checkpoint_callback = ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
save_best_only=True,
monitor='val_loss',
mode='min',
verbose=1
)
# 训练模型
model.fit(
train_dataset,
epochs=10,
validation_data=val_dataset,
callbacks=[checkpoint_callback] # 添加回调函数
)
# 在需要时加载最佳模型权重
model.load_weights(checkpoint_path)
# 使用模型进行文本摘要
# ...