python3 seq2seq_model.py 对应代码解读抽取式提取+生成式提取摘要代码解读------摘要代码解读5------第二章

接上一章节
上一章节调用完了adam优化器,这一章节重点介绍AutoSummary类别seq2seq解码器的操作

class AutoSummary(AutoRegressiveDecoder):

继承了之前的AutoRegressiveDecoder的类别
首先我们去查看AutoRegressiveDecoder类的示例

from bert4keras.snippets import AutoRegressiveDecoder

这里自定义的时候调用

autosummary = AutoSummary(
	start_id = tokenizer._token_start_id,
	end_id = tokenizer._token_end_id,
	maxlen = maxlen // 2
)

进入初始化AutoRegressiveDecoder的类别之中

class AutoRegressiveDecoder(object):
    """通用自回归生成模型解码基类
    包含beam search和random sample两种策略
    """
    def __init__(self, start_id, end_id, maxlen, minlen=1):
        self.start_id = start_id
        self.end_id = end_id
        self.maxlen = maxlen
        self.minlen = minlen
        self.models = {
     }
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])

这里的
start_id = 2
end_id = 3
self.maxlen = 512
(不明白这里为什么self.maxlen要除以2)。
由于AutoSummary(AutoRegressiveDecoder)类之中包含的函数众多,所以这里我们需要先看看有哪些函数被调用过
仔细打印对应的输出内容发现,调用了相应的AutoRegressiveDecoder wraps函数内容
AutoRegressiveDecoder warps前面加上了静态函数
使用__init__查看AutoSummary的初始化过程之后,发现了先调用的AutoRegressiveDecoder.wraps函数,然后才调用的初始化过程
调用AutoRegressiveDecoder.wraps函数在定义类别的过程中直接就进行调用了

@AutoRegressiveDecoder.wraps(default_rtype='logits',use_states=True)

这里直接调用AutoRegressiveDecoder中的wraps,因为wraps
之前AutoRegressiveDecoder中的wraps函数就被定义为@staticmethod的静态方法,所以这里可以通过类直接调用

@AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True)

然后再进行初始化

autosummary = AutoSummary(
	start_id = tokenizer._token_start_id,
	end_id = tokenizer._token_end_id,
	maxlen = maxlen//2
)

后面这些初始化的内容也没有用上,直接开始训练了2个epochs的训练集合数据

train_generator = data_generator(train_data,batch_size)
train_model.fit_generator(
	train_generator.forfit(),
	steps_per_epoch = len(train_generator),
	epochs = epochs,
	callbacks = [evaluator]
)

这里感觉不完善,应该是训练完每一个epoch之后就进行预测,保存在测试集合上最高的权重。

你可能感兴趣的:(文本摘要抽取代码解读,python,开发语言,后端)