seq2seq:LSTM+attention的生成式文本概要

seq2seq:LSTM+attention的生成式文本概要

最近做的利用seq2seq模型的生成式文本概要,参考了这位大佬的源码:
https://spaces.ac.cn/archives/5861/comment-page-1

数据集准备及预处理

我直接拿的新闻数据集的内容(content)和标题(title),根据内容概括标题。
一般想要达到比较能看的结果的话需要10w左右的数据集,跑50次迭代左右。
这种数据集网上很多,自己去找然后处理一下就好了。
数据集的预处理我是只保留了中文,去空格,最后所有文本都是连在一起的:

#正则表达式去除非中文字符
delCop = re.compile("[^\u4e00-\u9fa5]")
changeCop=re.compile("[^\u4e00-\u9fa5]")
for i in range(0, len(trainSet)):
    trainSet.iloc[i,1] = changeCop.sub(' ', delCop.sub('', trainSet.iloc[i,1]))
    trainSet.iloc[i,2] = changeCop.sub(' ', delCop.sub('', trainSet.iloc[i,2]))

生成式文本摘要与seq2seq

sequence2sequence就是利用一个encoder与一个decoder,将需要处理的原始文本投进encoder生成一个理论上的“中间码”,再有decoder解码输出为结果:
seq2seq:LSTM+attention的生成式文本概要_第1张图片

# encoder,双层双向LSTM
x = LayerNormalization()(x)
x = OurBidirectional(CuDNNLSTM(z_dim // 2, return_sequences=True))([x, x_mask])
x = LayerNormalization()(x)
x = OurBidirectional(CuDNNLSTM(z_dim // 2, return_sequences=True))([x, x_mask])
x_max = Lambda(seq_maxpool)([x, x_mask])

# decoder,双层单向LSTM
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])
y = CuDNNLSTM(z_dim, return_sequences=True)(y)
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])
y = CuDNNLSTM(z_dim, return_sequences=True)(y)
y = SelfModulatedLayerNormalization(z_dim // 4)([y, x_max])

最后也是在评价器中放了两个句子进行的调用解码进行的输出:

s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医 。'
s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。'

class Evaluate(Callback):
    def __init__(self):
        self.lowest = 1e10
    def on_epoch_end(self, epoch, logs=None):
        # 训练过程中观察一两个例子,显示标题质量提高的过程
        resStr=s1+'\n输出:'+gen_sent(s1)+'\n'+s2+'\n输出:'+gen_sent(s2)+'\n'
        with open('output', 'a',encoding='utf-8') as file_obj:
            file_obj.write(resStr)
        print(resStr)
        # 保存最优结果
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model.weights')

attention

attention是一种编码机制,用于形容词与词之间的注意力关系,比如下面这句话:
The animal didn’t cross the street because it was too tired

这句话中的"it"指的是什么?它指的是“animal”还是“street”?对于人来说,这其实是一个很简单的问题,但是对于一个算法来说,处理这个问题其实并不容易。self attention的出现就是为了解决这个问题,通过self attention,我们能将“it”与“animal”联系起来。
由于有时候一个词可能与多个词有较大关联,所以我们采用了一种叫做“多头”的策略。
比如上面的句子,it的注意力会集中在animal和tired身上。
具体可以参照这篇博文:https://blog.csdn.net/qq_43012160/article/details/100782291
著名的transformer和bert的词编码就是基于attention机制的

fit_generator

为什么那么多人训练模型的时候不用fit用fit_generator?
fit_generator中传入的不是数据集,而是一个数据生成器,如果你数据量非常大无法读入内存,fit就用不了了,但用fit_generator就只要传一个生成器(一个函数)进去。一般生成器每次选取batch_size个数据进行处理,处理完抛进模型训练,再处理后batch_size个数据。
生成器的数据生成与模型的训练还是并行的。

def data_generator():
    # 数据生成器
    X,Y = [],[]
    i=0
    while True:
        sentence=data.loc[i%dataLen]
        X.append(str2id(sentence['content']))
        Y.append(str2id(sentence['title'], start_end=True))
        i=i+1
        if len(X) == batch_size:
            X = np.array(padding(X))
            Y = np.array(padding(Y))
            yield [X,Y], None
            X,Y = [],[]
#模型训练
evaluator = Evaluate()

model.fit_generator(data_generator(),
                    steps_per_epoch=int(dataLen/batch_size),
                    epochs=epochs,
                    callbacks=[evaluator])

dataLen是数据的总长度,steps_per_epoch就是指每次迭代fit_generator执行的步数,dataLen=batch_size* steps_per_epoch。就是每步执行batch_size条数据。

放两条比较好的结果:
1.夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医 。
输出1:紫外线照射成长期伤害长期伤害
输出2:夏天来临天后护肤品要慎重要

2.8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。
输出1:连锁酒店用户数据泄露泄露
输出1:客户数据泄密牌酒店用户数据泄密店用户数据泄露

现在这个模型的vocab是单字的,后面打算用jieba分一下词,分完词就可以判断词之间的相似性,能判断相似性这种“数据泄密牌酒店用户数据泄密店用户数据泄露”的情况我就能把他检测出来然后做处理了。

你可能感兴趣的:(NLP,深度学习)