接上一章节
上一章节调用完了adam优化器,这一章节重点介绍AutoSummary类别seq2seq解码器的操作
class AutoSummary(AutoRegressiveDecoder):
继承了之前的AutoRegressiveDecoder的类别
首先我们去查看AutoRegressiveDecoder类的示例
from bert4keras.snippets import AutoRegressiveDecoder
这里自定义的时候调用
autosummary = AutoSummary(
start_id = tokenizer._token_start_id,
end_id = tokenizer._token_end_id,
maxlen = maxlen // 2
)
进入初始化AutoRegressiveDecoder的类别之中
class AutoRegressiveDecoder(object):
"""通用自回归生成模型解码基类
包含beam search和random sample两种策略
"""
def __init__(self, start_id, end_id, maxlen, minlen=1):
self.start_id = start_id
self.end_id = end_id
self.maxlen = maxlen
self.minlen = minlen
self.models = {
}
if start_id is None:
self.first_output_ids = np.empty((1, 0), dtype=int)
else:
self.first_output_ids = np.array([[self.start_id]])
这里的
start_id = 2
end_id = 3
self.maxlen = 512
(不明白这里为什么self.maxlen要除以2)。
由于AutoSummary(AutoRegressiveDecoder)类之中包含的函数众多,所以这里我们需要先看看有哪些函数被调用过
仔细打印对应的输出内容发现,调用了相应的AutoRegressiveDecoder wraps函数内容
AutoRegressiveDecoder warps前面加上了静态函数
使用__init__查看AutoSummary的初始化过程之后,发现了先调用的AutoRegressiveDecoder.wraps函数,然后才调用的初始化过程
调用AutoRegressiveDecoder.wraps函数在定义类别的过程中直接就进行调用了
@AutoRegressiveDecoder.wraps(default_rtype='logits',use_states=True)
这里直接调用AutoRegressiveDecoder中的wraps,因为wraps
之前AutoRegressiveDecoder中的wraps函数就被定义为@staticmethod的静态方法,所以这里可以通过类直接调用
@AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True)
然后再进行初始化
autosummary = AutoSummary(
start_id = tokenizer._token_start_id,
end_id = tokenizer._token_end_id,
maxlen = maxlen//2
)
后面这些初始化的内容也没有用上,直接开始训练了2个epochs的训练集合数据
train_generator = data_generator(train_data,batch_size)
train_model.fit_generator(
train_generator.forfit(),
steps_per_epoch = len(train_generator),
epochs = epochs,
callbacks = [evaluator]
)
这里感觉不完善,应该是训练完每一个epoch之后就进行预测,保存在测试集合上最高的权重。