A Complete Keras Implementation of a Transformer for Text Classification

Background

At the time of writing, the top-ranked Keras implementation in CSDN search results is:
https://blog.csdn.net/xiaosongshine/article/details/86595847

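However, that article has several problems. Its implementation is incomplete: the LayerNorm + residual parts are missing, and its multi-head attention omits the final fully connected projection $W_o$. Its hyperparameters also differ considerably from the original paper, so the code does not really correspond to the encoder block the paper describes, and it can be misleading if you read it alongside the paper. I therefore collected the missing pieces from various sources and assembled a complete Keras transformer encoder that essentially matches the paper. Along the way, I review how the transformer works by walking through the theory together with the code (with comments kept as clear as possible).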

Keras version

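To stay compatible with the code found on CSDN, this article uses Keras 2.2.4 (standalone Keras, not tf.keras). A newer Keras release or tf.keras may require some changes; CyberZHG's code on GitHub is a useful reference for making them.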

Main references

Theory:

  • https://zhuanlan.zhihu.com/p/44121378
  • https://zhuanlan.zhihu.com/p/44731789
  • https://blog.csdn.net/u012526436/article/details/86295971
  • Original paper: https://arxiv.org/pdf/1706.03762.pdf

Code:

  • https://github.com/CyberZHG/keras-transformer
  • https://blog.csdn.net/xiaosongshine/article/details/86595847
  • https://blog.csdn.net/qq_40742298/article/details/115011147

Overall model structure

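[Figure 1: the Transformer architecture]
Because the task is text classification, we only discuss the encoder on the left side of the figure.
The encoder starts with the input + embedding part, followed by an encoding stack of N blocks; in the paper N is 6. Each block consists of multi-head attention, add & norm, a feed-forward network, and residual connections. We take them apart step by step below.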

Input layer

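The original input layer is the word embedding plus a position embedding, the same as ordinary text input. Assume the input has shape (batch_size, seq_len, embedding_size). Note that embedding_size must stay the same throughout the network so that the residual connections can be added later. In the paper, the embedding and the outputs of all sub-layers share one dimension, $d_{model}=512$.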

Position embedding layer

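Unlike an RNN, the transformer has no built-in notion of word order. To preserve position information, each position is mapped to a position embedding, which is added to the word embedding to form the input to the blocks. Note that "Attention Is All You Need" considers both a sin/cos encoding and a learned position embedding; the authors report nearly identical results and chose the sin/cos version. BERT, by contrast, uses a trainable position embedding.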
The formula is not repeated in detail here; roughly:
[Figure 2: sinusoidal position encoding formula]
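For reference, the sinusoidal encoding from the paper can be written out as follows, where $pos$ is the token position and $i$ indexes pairs of embedding dimensions (even dimensions use sine, odd dimensions use cosine):

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)
\end{aligned}
```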
The implementation, with comments:

#! -*- coding: utf-8 -*-
#%%
from __future__ import print_function
import keras  # used later via keras.activations, keras.initializers, etc.
from keras import backend as K
from keras.engine.topology import Layer
 
class Position_Embedding(Layer):
    def __init__(self, size=None, mode='sum', **kwargs):
        self.size = size  # must be even
        self.mode = mode
        super(Position_Embedding, self).__init__(**kwargs)
 
    def call(self, x):  # x usually comes from the embedding layer: (batch_size, seq_len, model_dim)
        if (self.size is None) or (self.mode == 'sum'):
            self.size = int(x.shape[-1])  # d_model, e.g. 512
        batch_size, seq_len = K.shape(x)[0], K.shape(x)[1]
        # K.arange(self.size // 2) yields 0, 1, ..., 255 (for d_model=512): the i in the formula.
        # 2 * K.arange(self.size // 2) yields 0, 2, 4, ..., 510: the 2i in the formula.
        # Dividing by model_dim and taking the power gives 1 / 10000^(2i/d_model).
        position_j = 1. / K.pow(10000., 2 * K.arange(self.size // 2, dtype='float32') / self.size)
        position_j = K.expand_dims(position_j, 0)  # (1, model_dim/2)
        # Build the sequence of position indices:
        # x[:, :, 0] takes the first component of every embedding -> (bs, seq_len)
        # ones_like  -> (bs, seq_len): [[1, 1, 1, ...], [1, 1, 1, ...], ...]
        # cumsum     -> (bs, seq_len): [[1, 2, 3, ...], [1, 2, 3, ...], ...]
        # cumsum - 1 -> (bs, seq_len): [[0, 1, 2, ...], [0, 1, 2, ...], ...]
        position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1  # K.arange does not support variable length, hence this workaround
        position_i = K.expand_dims(position_i, 2)  # (bs, seq_len, 1)
        position_ij = K.dot(position_i, position_j)  # (bs, seq_len, model_dim/2)
        # After the dot product this is pos / 10000^(2i/d_model).
        # The referenced implementation concatenates the sin half and the cos half directly;
        # to match the formula, the two should be interleaved instead.
        position_ij_2i = K.expand_dims(K.sin(position_ij), -1)  # (bs, seq_len, model_dim/2, 1)
        position_ij_2i_1 = K.expand_dims(K.cos(position_ij), -1)  # (bs, seq_len, model_dim/2, 1)
        position_ij = K.concatenate([position_ij_2i, position_ij_2i_1])  # (bs, seq_len, model_dim/2, 2)
        position_ij = K.reshape(position_ij, (batch_size, seq_len, self.size))  # (bs, seq_len, model_dim)
        # position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)  # non-interleaved variant: first half all cos, second half all sin
        if self.mode == 'sum':
            return position_ij + x
        elif self.mode == 'concat':
            return K.concatenate([position_ij, x], 2)
 
    def compute_output_shape(self, input_shape):
        if self.mode == 'sum':
            return input_shape
        elif self.mode == 'concat':
            return (input_shape[0], input_shape[1], input_shape[2]+self.size)
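
As a cross-check on the interleaving, here is a minimal pure-Python sketch of the same table. The function name `sinusoidal_position_encoding` is made up for this illustration and is not part of the Keras layer above; it only assumes the paper's formula.

```python
import math

def sinusoidal_position_encoding(seq_len, d_model):
    """Return a seq_len x d_model table: even columns are sin, odd columns are cos."""
    assert d_model % 2 == 0, "d_model must be even"
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(d_model // 2):
            angle = pos / (10000.0 ** (2 * i / d_model))
            # Interleave: column 2i gets sin, column 2i+1 gets cos.
            row.extend([math.sin(angle), math.cos(angle)])
        table.append(row)
    return table

pe = sinusoidal_position_encoding(seq_len=4, d_model=8)
print(pe[0])  # at position 0 every sin term is 0 and every cos term is 1
```

Comparing a few rows of this table with the output of the layer (in 'sum' mode on an all-zero input) is a quick way to confirm that the reshape really interleaves sin and cos.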

Implementing each part of a single block

multi-head attention

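We first implement a single attention. If you would rather not build it one attention at a time, the attention layer in https://blog.csdn.net/xiaosongshine/article/details/86595847 computes all heads at once, but it needs an added $W_o$ to match the paper exactly. To stay consistent with the paper and keep the decomposition clear, we start with a single attention.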

scaled dot attention

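The scaled dot attention diagram and formula:

  • Define three matrices Wq, Wk, Wv
  • Multiply the input by each of them to obtain Q, K, V
  • Take the dot product of Q and K to get scores, then apply softmax to get weights
  • Multiply the weights by V to get the weighted V matrix (the H matrix)
  • Note that before the softmax the scores are divided by $\sqrt{d_k}$; for the reason see https://blog.csdn.net/qq_37430422/article/details/105042303
    [Figure 3: scaled dot-product attention]
    Code: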
class ScaledDotProductAttention(Layer):
    r"""The attention layer that takes three inputs representing queries, keys and values.
    \text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}}) V
    See: https://arxiv.org/pdf/1706.03762.pdf
    """
    def __init__(self,
                 return_attention=False,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.
        :param return_attention: Whether to return attention weights.
        :param history_only: Whether to only use history data.
        :param kwargs: Arguments for parent class.
        """
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_attention = return_attention
        self.history_only = history_only
        self.intensity = self.attention = None

    def get_config(self):
        config = {
            'return_attention': self.return_attention,
            'history_only': self.history_only,
        }
        base_config = super(ScaledDotProductAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            query_shape, key_shape, value_shape = input_shape
        else:
            query_shape = key_shape = value_shape = input_shape
        output_shape = query_shape[:-1] + value_shape[-1:]
        if self.return_attention:
            attention_shape = query_shape[:2] + (key_shape[1],)
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    def call(self, inputs, mask=None, **kwargs):
        if isinstance(inputs, list):
            query, key, value = inputs
        else:
            query = key = value = inputs
        if isinstance(mask, list):
            mask = mask[1]
        feature_dim = K.shape(query)[-1]  # d_k: 512 for a single head, 64 per head when called from multi-head attention
        # query = (bs, seq_len, dim)
        # key = (bs, seq_len, dim)
        # after batch_dot: (bs, seq_len, seq_len)
        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))
        if self.history_only:
            query_len, key_len = K.shape(query)[1], K.shape(key)[1]
            indices = K.expand_dims(K.arange(0, key_len), axis=0)
            upper = K.expand_dims(K.arange(0, query_len), axis=-1)
            e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0)
        if mask is not None:
            e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx()))
        self.intensity = e
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        self.attention = e / K.sum(e, axis=-1, keepdims=True)
        # self.attention = (bs, seq_len, seq_len)
        # value = (bs, seq_len, dim)
        # v = (bs, seq_len, dim)
        v = K.batch_dot(self.attention, value)
        if self.return_attention:
            return [v, self.attention]
        return v
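
To make the computation easier to follow outside of Keras, the core of `call` can be sketched in NumPy. This is an illustrative sketch, not the layer itself: the function name and the `key_mask` argument are assumptions made for this example, and the masking reuses the same "subtract 10000 before softmax" trick as the layer.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, key_mask=None):
    """q: (batch, q_len, d_k); k: (batch, k_len, d_k); v: (batch, k_len, d_v).
    key_mask: optional (batch, k_len) array with 1 for real tokens, 0 for padding."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)  # (batch, q_len, k_len)
    if key_mask is not None:
        # Push padded key positions far down so softmax gives them ~0 weight.
        scores = scores - 10000.0 * (1.0 - key_mask[:, None, :])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.RandomState(0)
x = rng.randn(2, 5, 16)  # self-attention: q = k = v = x
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=float)
out, w = scaled_dot_product_attention(x, x, x, key_mask=mask)
print(out.shape, w.shape)
```

Each row of `w` sums to 1, and the columns that correspond to padded keys in the first sample receive essentially zero weight.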

multi-head attention

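This part is fairly simple: project Q, K and V, split each into num_head chunks, run every chunk through the scaled dot attention implemented above, merge the results, and apply one more projection. Taking Q as the example:
(1) Suppose Q (bs=1, seq_len=10, dim=512) has already passed through a projection layer, giving Q_
[Figure 4: illustration of step (1)]
(2) Obtain K_ in the same way, then compute the dot attention matrix from Q_ and K_
[Figure 5: illustration of step (2)]
(3) Obtain V_ in the same way and take the weighted sum to get Outputs
[Figure 6: illustration of step (3)]
(4) Reshape back
[Figure 7: illustration of step (4)]
(5) Finally, apply Wo
[Figure 8: illustration of step (5)]
Code: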

class MultiHeadAttention(Layer):
    """Multi-head attention layer.
    See: https://arxiv.org/pdf/1706.03762.pdf
    """

    def __init__(self,
                 head_num,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.
        :param head_num: Number of heads.
        :param activation: Activations for linear mappings.
        :param use_bias: Whether to use bias term.
        :param kernel_initializer: Initializer for linear mappings.
        :param bias_initializer: Initializer for linear mappings.
        :param kernel_regularizer: Regularizer for linear mappings.
        :param bias_regularizer: Regularizer for linear mappings.
        :param kernel_constraint: Constraints for linear mappings.
        :param bias_constraint: Constraints for linear mappings.
        :param history_only: Whether to only use history in attention layer.
        """
        self.supports_masking = True
        self.head_num = head_num
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.history_only = history_only

        self.Wq = self.Wk = self.Wv = self.Wo = None
        self.bq = self.bk = self.bv = self.bo = None

        self.intensity = self.attention = None
        super(MultiHeadAttention, self).__init__(**kwargs)

    def get_config(self):
        config = {
            'head_num': self.head_num,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'history_only': self.history_only,
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
            return q[:-1] + (v[-1],)
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        if isinstance(input_mask, list):
            return input_mask[0]
        return input_mask

    def build(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        feature_dim = int(v[-1])
        if feature_dim % self.head_num != 0:
            raise IndexError('Invalid head number %d with the given input dim %d' % (self.head_num, feature_dim))
        self.Wq = self.add_weight(
            shape=(int(q[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wq' % self.name,
        )
        if self.use_bias:
            self.bq = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bq' % self.name,
            )
        self.Wk = self.add_weight(
            shape=(int(k[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wk' % self.name,
        )
        if self.use_bias:
            self.bk = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bk' % self.name,
            )
        self.Wv = self.add_weight(
            shape=(int(v[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wv' % self.name,
        )
        if self.use_bias:
            self.bv = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bv' % self.name,
            )
        self.Wo = self.add_weight(
            shape=(feature_dim, feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wo' % self.name,
        )
        if self.use_bias:
            self.bo = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bo' % self.name,
            )
        super(MultiHeadAttention, self).build(input_shape)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        #split to head num
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        # To suit the scaled dot attention (whose input is (bs, seq_len, head_dim)), transpose and reshape here
        x = K.permute_dimensions(x, [0, 2, 1, 3])  # transpose: move the head_num axis, which is computed in parallel, to the front
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))  # reshape: the batch axis takes no part in the scaled dot computation

    @staticmethod
    def _reshape_attention_from_batches(x, head_num):  # restore the attention score matrix to per-sample layout
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        return K.permute_dimensions(x, [0, 2, 1, 3])

    @staticmethod
    def _reshape_from_batches(x, head_num):  # restore the attended vectors to per-sample layout
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]  # (bs*head_num, seq_len, head_dim)
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))  # (bs, head_num, seq_len, head_dim)
        x = K.permute_dimensions(x, [0, 2, 1, 3])  # (bs, seq_len, head_num, head_dim)
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))  # (bs, seq_len, model_dim)

    @staticmethod
    def _reshape_mask(mask, head_num):
        if mask is None:
            return mask   
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))

    def call(self, inputs, mask=None):
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs  # (bs, seq_len, model_dim)
        if isinstance(mask, list):
            q_mask, k_mask, v_mask = mask
        else:
            q_mask = k_mask = v_mask = mask
        q = K.dot(q, self.Wq)  # one 512x512 projection followed by a split into 8 heads has the same parameter count as 8 separate 512x64 per-head projections
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)
        scaled_dot_product_attention = ScaledDotProductAttention(
            history_only=self.history_only,
            name='%s-Attention' % self.name,
        )
        y = scaled_dot_product_attention(
            inputs=[
                self._reshape_to_batches(q, self.head_num),  # query: (bs*head_num, seq_len, head_dim)
                self._reshape_to_batches(k, self.head_num),  # key
                self._reshape_to_batches(v, self.head_num),  # value
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
        )
        # Attention score matrices (uncomment to keep them for inspection):
#         self.intensity = self._reshape_attention_from_batches(scaled_dot_product_attention.intensity, self.head_num)
#         self.attention = self._reshape_attention_from_batches(scaled_dot_product_attention.attention, self.head_num)
        y = self._reshape_from_batches(y, self.head_num)  # merge the heads
        y = K.dot(y, self.Wo)  # final output projection
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)

        # Add shape information to tensor
        input_shape = [K.int_shape(q), K.int_shape(k), K.int_shape(v)]
        output_shape = self.compute_output_shape(input_shape)
        if output_shape[1] is not None:
            output_shape = (-1,) + output_shape[1:]
            y = K.reshape(y, output_shape)
        return y
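
The head split and merge are the least obvious part of this layer, so here is a small NumPy sketch of the same reshape/transpose logic as `_reshape_to_batches` and `_reshape_from_batches`. The names `split_heads` and `merge_heads` are chosen for this illustration only.

```python
import numpy as np

def split_heads(x, head_num):
    """(bs, seq_len, model_dim) -> (bs * head_num, seq_len, head_dim)."""
    bs, seq_len, model_dim = x.shape
    head_dim = model_dim // head_num
    x = x.reshape(bs, seq_len, head_num, head_dim)
    x = x.transpose(0, 2, 1, 3)  # bring the head axis next to the batch axis
    return x.reshape(bs * head_num, seq_len, head_dim)

def merge_heads(x, head_num):
    """(bs * head_num, seq_len, head_dim) -> (bs, seq_len, model_dim)."""
    bh, seq_len, head_dim = x.shape
    x = x.reshape(bh // head_num, head_num, seq_len, head_dim)
    x = x.transpose(0, 2, 1, 3)  # (bs, seq_len, head_num, head_dim)
    return x.reshape(bh // head_num, seq_len, head_num * head_dim)

# Same shapes as the walkthrough above: bs=1, seq_len=10, dim=512, 8 heads.
x = np.arange(1 * 10 * 512, dtype=np.float32).reshape(1, 10, 512)
heads = split_heads(x, head_num=8)
restored = merge_heads(heads, head_num=8)
print(heads.shape, restored.shape)
```

Head 1 of the first token is simply components 64 to 127 of the original vector, and merging after splitting returns the input unchanged, which is why the two helpers can wrap the attention call without mixing up heads.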

LayerNorm

Code:

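class LayerNorm(Layer):
    def __init__(self,
                 center=True,
                 scale=False,
                 epsilon=None,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 gamma_constraint=None,
                 beta_constraint=None,
                 **kwargs
                 ):
        super(LayerNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center
        self.scale = scale  # set scale=True to also learn the gain (gamma)
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma, self.beta = None, None

    def build(self, input_shape):
        # Create gamma/beta as trainable weights; without this step they
        # would stay constants and the layer would learn nothing.
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNorm, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # Normalize over the last (feature) axis of every position.
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs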
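
What the normalization in `call` does can be checked numerically with a few lines of NumPy. This is a sketch for illustration; the function name `layer_norm` and the `eps` default are assumptions of this example, not values taken from the layer.

```python
import numpy as np

def layer_norm(x, gamma=None, beta=None, eps=1e-6):
    """Normalize each vector along the last axis to zero mean and unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    out = (x - mean) / np.sqrt(var + eps)
    if gamma is not None:
        out = out * gamma  # optional learned gain
    if beta is not None:
        out = out + beta   # optional learned bias
    return out

# One sample, two positions, four features with very different scales.
x = np.array([[[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 10.0, 50.0]]])
y = layer_norm(x)
print(y.mean(axis=-1), y.std(axis=-1))
```

Unlike batch normalization, the statistics are computed per position over the feature axis, so the result does not depend on the other samples in the batch or on the sequence length.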

Adding Add and the FFN to form a complete transformer block

def transformer_block(x, prefix):
    model_dim = K.int_shape(x)[-1]  # d_model, e.g. 512
    # MultiHeadAttention defaults to activation='relu'; pass activation=None
    # to get the purely linear projections described in the paper.
    O_seq = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x)  # (bs, words_len, dim)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x, O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1)  # X = LayerNorm(X + multihead(X))
    O_seq_fc1 = Dense(model_dim * 4, activation='relu', name=f'{prefix}_fc1')(O_seq_LN1)  # FFN
    O_seq_fc2 = Dense(model_dim, name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1, O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)  # X = LayerNorm(X + FFN(X))
    return O_seq_LN2

Full model definition


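import numpy as np
import keras
from tqdm import tqdm
from keras.layers import Input, Embedding, Dense, Dropout, Add, GlobalAveragePooling1D
from keras.models import Model

MAX_LEN = 512
MODEL_DIM = 512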

def load_word_embedding(filepath):
    embeddings_index = {}
    f = open(filepath, encoding='utf8')
    for line in tqdm(f):
        values = line.split()
        word = ''.join(values[:-MODEL_DIM])
        coefs = np.asarray(values[-MODEL_DIM:], dtype='float32')
        embeddings_index[word] = coefs
    f.close()
    return embeddings_index

def build_matrix(word_index, path):
    embedding_index = load_word_embedding(path) 
    embedding_matrix = np.zeros((len(word_index) + 1, MODEL_DIM))
    for word, i in word_index.items():
        if word in embedding_index:
            embedding_matrix[i] = embedding_index[word]
            #break
    return embedding_matrix

def transformer_block(x,prefix):
    O_seq = MultiHeadAttention(head_num=8,name=f'{prefix}_att1')(x) #bs,words_len,dim
    O_seq = Dropout(0.1,name=f'{prefix}_do1')(O_seq)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x,O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1) #X = LayerNorm(X + multihead(X))
    O_seq_fc1 = Dense(MODEL_DIM * 4,activation='relu',name=f'{prefix}_fc1')(O_seq_LN1) #FFN
    O_seq_fc2 = Dense(MODEL_DIM,name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_fc2 = Dropout(0.1,name=f'{prefix}_do2')(O_seq_fc2)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1,O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)
    return O_seq_LN2


def build_model(embedding_matrix, num_class = 2):
    words = Input(shape=(MAX_LEN,),name='inputs',dtype='int32')
    embeddings = Embedding(*embedding_matrix.shape, weights=[embedding_matrix], trainable=True)(words)
    embeddings = Position_Embedding()(embeddings)  # adding Position_Embedding slightly improves accuracy
    embeddings = Dropout(0.1)(embeddings)

    # def transformer_block(x,prefix):
    seq_len = K.shape(words)[1]
#     model_dim = K.int_shape(embeddings)[-1]
    
    O_seq1 = transformer_block(embeddings,prefix='t1')
    O_seq2 = transformer_block(O_seq1,prefix='t2')
    O_seq3 = transformer_block(O_seq2,prefix='t3')
    O_seq4 = transformer_block(O_seq3,prefix='t4')
    O_seq5 = transformer_block(O_seq4,prefix='t5')
    O_seq6 = transformer_block(O_seq5,prefix='t6')
#     O_seq7 = transformer_block(O_seq6,prefix='t7')
#     O_seq8 = transformer_block(O_seq7,prefix='t8')
    
    O_seq = Add()([O_seq4,O_seq5,O_seq6])  # everything from here on is my own design, not from the paper
    O_seq = GlobalAveragePooling1D()(O_seq)
    O_seq = Dropout(0.1)(O_seq)
    
    # The paper uses a warmup learning-rate schedule at this point; it is not used here.
    
    result = Dense(num_class, activation='softmax', name='outputs')(O_seq)    
    model = Model(inputs=words, outputs=result)
    opt=keras.optimizers.Adam(lr=5e-5)
    model.compile(loss='categorical_crossentropy',optimizer=opt, metrics=['acc'])
    model.summary()
    return model

A side note

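If you train the model with only the code above, you may find that it converges with difficulty, because there is no learning-rate warmup, which matters a lot here. If the model does not converge, try moving LayerNorm in front of the attention and FFN sub-layers, or first lower the learning rate (5e-5 or below); you can also add a warmup schedule.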

Reference: https://zhuanlan.zhihu.com/p/84614490
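
The suggestion to move LayerNorm in front of attention and FFN is usually called pre-LN, as opposed to the post-LN ordering used in `transformer_block`. The sketch below only shows the two orderings; the sub-layers are passed in as plain callables, and the toy scalar functions in the demo are stand-ins chosen for this illustration, not real attention or normalization.

```python
def post_ln_block(x, attention, ffn, norm1, norm2):
    """Post-LN (as in the paper): normalize after each residual addition."""
    x = norm1(x + attention(x))
    return norm2(x + ffn(x))

def pre_ln_block(x, attention, ffn, norm1, norm2):
    """Pre-LN: normalize the sub-layer input; the residual path stays untouched."""
    x = x + attention(norm1(x))
    return x + ffn(norm2(x))

# Toy scalar stand-ins, just to make the order of operations visible.
toy_attention = lambda t: 2 * t
toy_ffn = lambda t: t + 1
toy_norm = lambda t: t / 10

post = post_ln_block(10.0, toy_attention, toy_ffn, toy_norm, toy_norm)
pre = pre_ln_block(10.0, toy_attention, toy_ffn, toy_norm, toy_norm)
print(post, pre)
```

In the pre-LN variant the input is carried through the block by plain additions, which is one commonly cited reason it tends to be less sensitive to the absence of warmup. When stacking pre-LN blocks, a final LayerNorm is typically applied after the last block.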
A Keras warmup implementation follows; source:
https://gitee.com/yangyin2020/keras_classfication/blob/master/warmup_cosine_decay_scheduler.py

Adapt it as needed:

import numpy as np
import keras  # standalone Keras, consistent with the rest of the article (not tf.keras)
from keras import backend as K

# Cosine learning rate with warm-up

def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
      Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
      ICLR 2017. https://arxiv.org/abs/1608.03983
    In this schedule, the learning rate grows linearly from warmup_learning_rate
    to learning_rate_base for warmup_steps, then transitions to a cosine decay
    schedule.

    Arguments:
        global_step {int} -- global step.
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                    before decaying. (default: {0})
    Returns:
      a float representing learning rate.

    Raises:
      ValueError: if warmup_learning_rate is larger than learning_rate_base,
        or if warmup_steps is larger than total_steps.
    """

    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
        np.pi *
        (global_step - warmup_steps - hold_base_rate_steps
         ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)


class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Cosine decay with warmup learning rate scheduler
    """

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """Constructor for cosine decay with warmup learning rate scheduler.

    Arguments:
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        global_step_init {int} -- initial global step, e.g. from previous checkpoint.
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                    before decaying. (default: {0})
        verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """

        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = cosine_decay_with_warmup(global_step=self.global_step,
                                      learning_rate_base=self.learning_rate_base,
                                      total_steps=self.total_steps,
                                      warmup_learning_rate=self.warmup_learning_rate,
                                      warmup_steps=self.warmup_steps,
                                      hold_base_rate_steps=self.hold_base_rate_steps)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning '
                  'rate to %s.' % (self.global_step + 1, lr))

if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import Dense
    # Create a model.
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

    # Number of training samples.
    # gen1
    sample_count = 12608
    # gen

    # Total epochs to train.
    epochs = 50

    # Number of warmup epochs.
    warmup_epoch = 10

    # Training batch size, set small value here for demonstration purpose.
    batch_size = 16

    # Base learning rate after warmup.
    learning_rate_base = 0.0001

    total_steps = int(epochs * sample_count / batch_size)

    # Compute the number of warmup batches.
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    # Generate dummy data.
    data = np.random.random((sample_count, 100))
    labels = np.random.randint(10, size=(sample_count, 1))

    # Convert labels to categorical one-hot encoding.
    one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

    # Compute the number of warmup batches.
    warmup_batches = warmup_epoch * sample_count / batch_size

    # Create the Learning rate scheduler.
    warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                            total_steps=total_steps,
                                            warmup_learning_rate=4e-06,
                                            warmup_steps=warmup_steps,
                                            hold_base_rate_steps=5,
                                            )

    # Train the model, iterating on the data in batches of batch_size samples
    model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
            verbose=0, callbacks=[warm_up_lr])

    import matplotlib.pyplot as plt
    plt.plot(warm_up_lr.learning_rates)
    plt.xlabel('Step', fontsize=20)
    plt.ylabel('lr', fontsize=20)
    plt.axis([0, total_steps, 0, learning_rate_base*1.1])
    plt.xticks(np.arange(0, total_steps + 1, total_steps // 10))  # the x axis is in steps, not epochs
    plt.grid()
    plt.title('Cosine decay with warmup', fontsize=20)
    plt.show()
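
For comparison, the original paper does not use cosine decay. It increases the learning rate linearly for `warmup_steps` steps and then decays it proportionally to the inverse square root of the step number, with `warmup_steps = 4000`. A minimal sketch of that formula follows; the function name `noam_learning_rate` is an informal label chosen for this example.

```python
def noam_learning_rate(step, d_model=512, warmup_steps=4000):
    """lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 1000, 4000, 16000):
    print(s, noam_learning_rate(s))
```

The rate peaks at `step == warmup_steps`. To use it with the callback pattern above, one option would be to compute this value in `on_batch_begin` and set it with `K.set_value(self.model.optimizer.lr, lr)`, in the same way `WarmUpCosineDecayScheduler` does.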
