循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network)。
它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种’记忆’功能.
RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。
对循环神经网络的研究始于二十世纪80-90年代,并在二十一世纪初发展为深度学习(deep learning)算法之一,其中双向循环神经网络(Bidirectional RNN, Bi-RNN)和长短期记忆网络(Long Short-Term Memory networks,LSTM)是常见的循环神经网络。
循环神经网络具有记忆性、参数共享并且图灵完备(Turing completeness),因此在对序列的非线性特征进行学习时具有一定优势。循环神经网络在自然语言处理(Natural Language Processing, NLP),例如语音识别、语言建模、机器翻译等领域有应用,也被用于各类时间序列预报。引入了卷积神经网络(Convoutional Neural Network,CNN)构筑的循环神经网络可以处理包含序列输入的计算机视觉问题。
其展开可以表示为:
那么数学表示的公式为:
h ∗ t = W h x x t + W h h h t − 1 + b h h t = σ ( h ∗ t ) o ∗ t = W o h h t + b o o t = θ ( o ∗ t ) h^{t}_{*} = W_{hx}x^{t} + W_{hh}h^{t-1} + b_{h} \\ h^{t} = \sigma(h^{t}_{*}) \\ o^{t}_{*} = W_{oh} h^{t} + b_{o}\\ o^{t} = \theta (o^{t}_{*}) h∗t=Whxxt+Whhht−1+bhht=σ(h∗t)o∗t=Wohht+boot=θ(o∗t)
其中, x t x^{t} xt表示t时刻的输入, o t o^{t} ot表示t时刻的输出, h t h^{t} ht表示t时刻隐藏层的状态。
由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.
较为严重的是容易出现梯度消失(时间过长而造成记忆值较小)或者梯度爆炸的问题(BP算法和长时间依赖造成的)
因此, 就出现了一系列的改进的算法, 最基础的两种算法是LSTM 和 GRU.
这两种方法在面对梯度消失或者梯度爆炸的问题时,由于有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题;而针对梯度爆炸则设置阈值,超过阈值直接限制梯度。
LSTM(Long short-term memory,长短期记忆)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失问题。
LSTM是有4个全连接层进行计算的,LSTM的内部结构如下图所示。
LSTM的核心是细胞状态——最上层的横穿整个细胞的水平线,它通过门来控制信息的增加或者删除。
STM共有三个门,分别是遗忘门,输入门和输出门。
GRU是LSTM的变种,它也是一种RNN,因此是循环结构,相比LSTM而言,它的计算要简单一些,计算量也降低。
GRU 有两个有两个门,即一个重置门(reset gate)和一个更新门(update gate)。从直观上来说,重置门决定了如何将新的输入信息与前面的记忆相结合,更新门定义了前面记忆保存到当前时间步的量。如果我们将重置门设置为 1,更新门设置为 0,那么我们将再次获得标准 RNN 模型。使用门控机制学习长期依赖关系的基本思想和 LSTM 一致,但还是有一些关键区别:
GRU 有两个门(重置门与更新门),而 LSTM 有三个门(输入门、遗忘门和输出门)。
GRU 并不会控制并保留内部记忆(c_t),且没有 LSTM 中的输出门。
重置门:用来决定需要丢弃哪些上一个神经元细胞的信息,它的计算过程是将Ct-1与当前输入向量xt进行连接后,输入sigmoid层进行计算,结果为S1,再将S1与Ct-1进行点乘计算,则结果为保存的上个神经元细胞信息,用C’t-1表示。公式表示为:C’t-1 = Ct-1 · S1,S1 = sigmoid(concat(Ct-1,Xt))
更新门:更新门类似于LSTM的遗忘门和输入门,它决定哪些信息会丢弃,以及哪些新信息会增加。
本次使用的分类数据是从新浪微博不实信息举报平台抓取的中文谣言数据,数据集中共包含1538条谣言和1849条非谣言。 更多数据集介绍请参考https://github.com/thunlp/Chinese_Rumor_Dataset
import pandas as pd
all_data = pd.read_csv("data/data69671/all_data.tsv", sep="\t")
all_data.head()
label | text | |
---|---|---|
0 | 0 | #广州#【广州游行打砸抢罪犯资料公布!居然是日本间谍!】 |
1 | 0 | 【政协委员提议恢复大清王朝】康熙十世孙、广州政协委员金复新表示,他准备走遍中国收集100万人... |
2 | 1 | 有木有人和我一样。睡觉时头总爱靠在枕头的一角。据说这样的孩纸,都没安全感。 |
3 | 1 | 据说,看到这张图的人,许个愿,在十秒内转发的,就能美梦成真!!我们也试试!!! |
4 | 0 | 【老小子走了!李登辉今天凌晨心脏病复发身亡】台北消息:原国民党、台联党主席,有“台独教父”之... |
all_str = all_data["text"].values.tolist()
dict_set = set() # 保证每个字符只有唯一的对应数字
for content in all_str:
for s in content:
dict_set.add(s)
# 添加未知字符
dict_set.add("" )
# 把元组转换成字典,一个字对应一个数字
dict_list = []
i = 0
for s in dict_set:
dict_list.append([s, i])
i += 1
dict_txt = dict(dict_list)
# 字典保存到本地
with open("dict.txt", 'w', encoding='utf-8') as f:
f.write(str(dict_txt))
# 获取字典的长度
def get_dict_len(dict_path):
with open(dict_path, 'r', encoding='utf-8') as f:
line = eval(f.readlines()[0])
return len(line.keys())
print(get_dict_len("dict.txt"))
4410
all_data_list = all_data.values.tolist()
train_length = len(all_data) // 10 * 7
dev_length = len(all_data) // 10 * 2
train_data = []
dev_data = []
test_data = []
for i in range(train_length):
text = ""
for s in all_data_list[i][1]:
text = text + str(dict_txt[s]) + ","
text = text[:-1]
train_data.append([text, all_data_list[i][0]])
for i in range(train_length, train_length+dev_length):
text = ""
for s in all_data_list[i][1]:
text = text + str(dict_txt[s]) + ","
text = text[:-1]
dev_data.append([text, all_data_list[i][0]])
for i in range(train_length+dev_length, len(all_data)):
text = ""
for s in all_data_list[i][1]:
text = text + str(dict_txt[s]) + ","
text = text[:-1]
test_data.append([text, all_data_list[i][0]])
print(len(train_data))
print(len(dev_data))
print(len(test_data))
df_train = pd.DataFrame(columns=["text", "label"], data=train_data)
df_dev = pd.DataFrame(columns=["text", "label"], data=dev_data)
df_test = pd.DataFrame(columns=["text", "label"], data=test_data)
df_train.to_csv("train_data.csv", index=False)
df_dev.to_csv("dev_data.csv", index=False)
df_test.to_csv("test_data.csv", index=False)
2366
676
345
import numpy as np
import paddle
from paddle.io import Dataset, DataLoader
import pandas as pd
class MyDataset(Dataset):
"""
步骤一:继承paddle.io.Dataset类
"""
def __init__(self, mode='train'):
"""
步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
"""
super(MyDataset, self).__init__()
self.label = True
if mode == 'train':
text = pd.read_csv("train_data.csv")["text"].values.tolist()
label = pd.read_csv("train_data.csv")["label"].values.tolist()
self.data = []
for i in range(len(text)):
self.data.append([])
self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))
self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["" ]]*(256-len(self.data[-1][0])))]).astype('int64')
self.data[-1].append(np.array(int(label[i])).astype('int64'))
elif mode == 'dev':
text = pd.read_csv("dev_data.csv")["text"].values.tolist()
label = pd.read_csv("dev_data.csv")["label"].values.tolist()
self.data = []
for i in range(len(text)):
self.data.append([])
self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))
self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["" ]]*(256-len(self.data[-1][0])))]).astype('int64')
self.data[-1].append(np.array(int(label[i])).astype('int64'))
else:
text = pd.read_csv("test_data.csv")["text"].values.tolist()
label = pd.read_csv("test_data.csv")["label"].values.tolist()
self.data = []
for i in range(len(text)):
self.data.append([])
self.data[-1].append(np.array([int(i) for i in text[i].split(",")]))
self.data[-1][0] = self.data[-1][0][:256].astype('int64')if len(self.data[-1][0])>=256 else np.concatenate([self.data[-1][0], np.array([dict_txt["" ]]*(256-len(self.data[-1][0])))]).astype('int64')
self.data[-1].append(np.array(int(label[i])).astype('int64'))
self.label = False
def __getitem__(self, index):
"""
步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
"""
text_ = self.data[index][0]
label_ = self.data[index][1]
if self.label:
return text_, label_
else:
return text_
def __len__(self):
"""
步骤四:实现__len__方法,返回数据集总数目
"""
return len(self.data)
train_data = MyDataset(mode="train")
dev_data = MyDataset(mode="dev")
test_data = MyDataset(mode="test")
BATCH_SIZE = 128
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
dev_loader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
import paddle.nn as nn
inputs_dim = get_dict_len("dict.txt")
class myLSTM(nn.Layer):
def __init__(self):
super(myLSTM, self).__init__()
# num_embeddings (int) - 嵌入字典的大小, input中的id必须满足 0 =< id < num_embeddings 。 。
# embedding_dim (int) - 每个嵌入向量的维度。
# padding_idx (int|long|None) - padding_idx的配置区间为 [-weight.shape[0], weight.shape[0],如果配置了padding_idx,那么在训练过程中遇到此id时会被用
# sparse (bool) - 是否使用稀疏更新,在词嵌入权重较大的情况下,使用稀疏更新能够获得更快的训练速度及更小的内存/显存占用。
# weight_attr (ParamAttr|None) - 指定嵌入向量的配置,包括初始化方法,具体用法请参见 ParamAttr ,一般无需设置,默认值为None。
self.embedding = nn.Embedding(inputs_dim, 256)
# input_size (int) - 输入的大小。
# hidden_size (int) - 隐藏状态大小。
# num_layers (int,可选) - 网络层数。默认为1。
# direction (str,可选) - 网络迭代方向,可设置为forward或bidirect(或bidirectional)。默认为forward。
# time_major (bool,可选) - 指定input的第一个维度是否是time steps。默认为False。
# dropout (float,可选) - dropout概率,指的是出第一层外每层输入时的dropout概率。默认为0。
# weight_ih_attr (ParamAttr,可选) - weight_ih的参数。默认为None。
# weight_hh_attr (ParamAttr,可选) - weight_hh的参数。默认为None。
# bias_ih_attr (ParamAttr,可选) - bias_ih的参数。默认为None。
# bias_hh_attr (ParamAttr,可选) - bias_hh的参数。默认为None。
self.lstm = nn.LSTM(256, 256, num_layers=2, direction='bidirectional',dropout=0.5)
# in_features (int) – 线性变换层输入单元的数目。
# out_features (int) – 线性变换层输出单元的数目。
# weight_attr (ParamAttr, 可选) – 指定权重参数的属性。默认值为None,表示使用默认的权重参数属性,将权重参数初始化为0。具体用法请参见 ParamAttr 。
# bias_attr (ParamAttr|bool, 可选) – 指定偏置参数的属性。 bias_attr 为bool类型且设置为False时,表示不会为该层添加偏置。 bias_attr 如果设置为True或者None,则表示使用默认的偏置参数属性,将偏置参数初始化为0。具体用法请参见 ParamAttr 。默认值为None。
# name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。
self.linear = nn.Linear(in_features=256*2, out_features=2)
self.dropout = nn.Dropout(0.5)
def forward(self, inputs):
emb = self.dropout(self.embedding(inputs))
print(emb)
output, (hidden, _) = self.lstm(emb)
print("output:", output)
#output形状大小为[batch_size,seq_len,num_directions * hidden_size]
#hidden形状大小为[num_layers * num_directions, batch_size, hidden_size]
#把前向的hidden与后向的hidden合并在一起
hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)
print(hidden)
hidden = self.dropout(hidden)
#hidden形状大小为[batch_size, hidden_size * num_directions]
return self.linear(hidden)
lstm_model = paddle.Model(myLSTM())
lstm_model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=lstm_model.parameters()),
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy())
lstm_model.fit(train_loader,
dev_loader,
epochs=1,
batch_size=BATCH_SIZE,
verbose=1,
save_dir="work/lstm")
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
Tensor(shape=[128, 256, 256], dtype=float32, place=CPUPlace, stop_gradient=False,
[[[-0.04651996, 0. , 0. , ..., 0.03685956, 0. , 0.01133572],
[ 0.06982405, 0.01399607, 0. , ..., 0. , 0.04192550, 0. ],
[ 0. , 0. , -0.07115781, ..., -0.04474840, 0. , -0.00947660],
...,
[ 0.00036002, 0.02194860, 0. , ..., 0. , 0. , -0.06254490],
[ 0. , 0.02194860, 0. , ..., 0. , 0. , 0. ],
[ 0. , 0.02194860, -0.04798179, ..., -0.00808975, 0. , -0.06254490]],
[[-0.02243162, 0. , -0.04848601, ..., 0.01318943, 0.05083232, 0. ],
[-0.03468777, 0. , 0. , ..., 0. , 0. , 0. ],
[-0.00250597, 0. , 0.00956149, ..., -0.05242553, 0.04173937, 0.04689348],
...,
[ 0.00036002, 0. , 0. , ..., -0.00808975, -0.05875421, -0.06254490],
[ 0.00036002, 0. , -0.04798179, ..., -0.00808975, 0. , -0.06254490],
[ 0. , 0. , -0.04798179, ..., 0. , 0. , 0. ]],
[[-0.05900055, 0.03810406, -0.02360879, ..., 0.00572890, 0.02618536, 0. ],
[ 0.01131869, 0. , 0. , ..., 0. , 0. , -0.02667695],
[ 0.03942792, -0.05557661, 0.02318244, ..., 0.04537995, 0.01851625, 0. ],
...,
[ 0. , 0.02194860, -0.04798179, ..., -0.00808975, -0.05875421, -0.06254490],
[ 0. , 0. , -0.04798179, ..., 0. , 0. , 0. ],
[ 0. , 0.02194860, -0.04798179, ..., 0. , 0. , 0. ]],
...,
[[-0.05900055, 0.03810406, -0.02360879, ..., 0.00572890, 0. , 0. ],
[ 0. , -0.03945251, 0. , ..., -0.03621368, 0. , 0. ],
[ 0.04446360, 0.01865993, 0. , ..., 0. , 0.07054503, -0.05907314],
...,
[ 0. , 0. , -0.04798179, ..., -0.00808975, 0. , -0.06254490],
[ 0.00036002, 0.02194860, -0.04798179, ..., 0. , 0. , 0. ],
[ 0. , 0.02194860, -0.04798179, ..., -0.00808975, -0.05875421, 0. ]],
[[ 0. , 0. , 0. , ..., 0.03685956, -0.06564632, 0. ],
[ 0. , 0.02998513, 0. , ..., 0.01318943, 0.05083232, -0.02697797],
[-0.02399679, 0. , 0. , ..., 0.03550680, 0. , 0. ],
...,
[ 0. , 0. , -0.04798179, ..., -0.00808975, -0.05875421, 0. ],
[ 0.00036002, 0. , -0.04798179, ..., 0. , 0. , 0. ],
[ 0. , 0. , 0. , ..., 0. , 0. , -0.06254490]],
[[-0.04651996, 0. , 0. , ..., 0. , -0.06564632, 0.01133572],
[ 0.05890563, -0.04621770, 0.02617000, ..., 0. , 0.02645203, 0. ],
[ 0. , 0. , 0. , ..., -0.01541087, -0.01680669, 0.06438880],
...,
[ 0.00036002, 0.02194860, 0. , ..., -0.00808975, -0.05875421, 0. ],
[ 0.00036002, 0.02194860, -0.04798179, ..., -0.00808975, 0. , -0.06254490],
[ 0.00036002, 0.02194860, -0.04798179, ..., -0.00808975, 0. , -0.06254490]]])
output: Tensor(shape=[128, 256, 512], dtype=float32, place=CPUPlace, stop_gradient=False,
[[[ 0.00938810, -0.00546899, 0.00342875, ..., 0.00907911, 0.03100064, 0.00325108],
[ 0.00690732, 0.00087167, 0.00982505, ..., 0.00243614, 0.02229443, 0.00180063],
[ 0.00080585, 0.00150115, 0.01130185, ..., 0.01030400, 0.02393062, 0.00775731],
...,
[-0.00172433, 0.01737818, 0.01828881, ..., -0.01210868, 0.01913041, 0.01308940],
[ 0.00401915, 0.01313572, 0.02785316, ..., -0.01023939, 0.01315692, 0.01490174],
[ 0.00594976, 0.00380577, 0.02129058, ..., -0.00132644, 0.00966622, 0.00500632]],
[[ 0.01139693, 0.00300420, 0.01319209, ..., 0.00867867, 0.02928928, 0.00532537],
[ 0.00867943, 0.00134485, 0.00759263, ..., 0.00232504, 0.03548140, 0.01460487],
[ 0.00545911, 0.00598109, 0.00665692, ..., -0.00179417, 0.02968132, 0.02142053],
...,
[ 0.00684091, -0.00332550, 0.01627226, ..., -0.00768097, 0.01906899, 0.00495052],
[ 0.00714776, 0.00359129, 0.00969665, ..., -0.00518371, 0.01270603, 0.02092562],
[ 0.00996211, -0.00300248, 0.01471507, ..., -0.00518170, 0.00491907, 0.01977420]],
[[ 0.00454305, 0.00557371, 0.00644129, ..., 0.00742812, 0.03228268, 0.00123904],
[ 0.01620130, 0.01076532, 0.00993367, ..., 0.01371160, 0.02940202, 0.00991743],
[ 0.00499136, 0.01006329, 0.00912400, ..., 0.00395425, 0.02999862, 0.00863243],
...,
[ 0.01431614, 0.00472833, 0.02303233, ..., -0.01040955, 0.02376084, 0.00746928],
[ 0.01137663, 0.00380168, 0.02216066, ..., -0.01175049, 0.01654375, 0.01299890],
[ 0.00578821, 0.01187050, 0.01773241, ..., -0.01594530, 0.01655355, 0.01706639]],
...,
[[ 0.00859276, -0.00315960, 0.01166933, ..., -0.00881801, 0.03494525, 0.01678915],
[ 0.01290557, 0.00317856, 0.00955159, ..., -0.01111973, 0.03810613, 0.01380979],
[ 0.01253064, 0.00664895, 0.02108969, ..., -0.01780477, 0.03851430, 0.01267604],
...,
[-0.00838452, 0.01019156, 0.01459483, ..., -0.01348441, 0.02485748, 0.01388637],
[-0.00674383, 0.00988053, 0.01526485, ..., -0.01920962, 0.02392912, 0.01231303],
[ 0.00178657, 0.00858035, 0.01583224, ..., -0.01086871, 0.00722162, 0.01038192]],
[[ 0.00699786, 0.00642271, 0.00745836, ..., -0.00313805, 0.03383606, 0.00034484],
[ 0.02290456, 0.01042972, 0.01237193, ..., 0.00948680, 0.03684252, 0.00118934],
[ 0.01437230, 0.01381475, 0.02222175, ..., 0.01054095, 0.03601860, 0.00561095],
...,
[-0.00898388, 0.00819177, 0.02560077, ..., -0.01066977, 0.02100313, 0.03179439],
[-0.00742552, 0.01296771, 0.02097484, ..., -0.00363565, 0.02471528, 0.01478868],
[-0.00114184, 0.00497478, 0.01264016, ..., -0.01004613, 0.00440474, 0.01142078]],
[[ 0.00540706, 0.00030641, 0.00700968, ..., 0.00050059, 0.03791260, -0.00777113],
[ 0.00696417, -0.00121238, 0.00928910, ..., 0.00291850, 0.03325276, -0.00500016],
[ 0.00232116, 0.00674826, 0.01029603, ..., 0.00446389, 0.03645320, -0.00496576],
...,
[-0.00207149, 0.01562872, 0.02173612, ..., -0.01239564, 0.01776450, 0.01830550],
[ 0.00535780, 0.00764524, 0.02233512, ..., -0.01778406, 0.01849146, 0.01427940],
[ 0.00507328, 0.00707623, 0.02679622, ..., -0.00783817, 0.01783962, 0.01461937]]])
Tensor(shape=[128, 512], dtype=float32, place=CPUPlace, stop_gradient=False,
[[ 0.00594976, 0.00380577, 0.02129058, ..., 0.00907911, 0.03100064, 0.00325108],
[ 0.00996211, -0.00300248, 0.01471507, ..., 0.00867867, 0.02928928, 0.00532537],
[ 0.00578821, 0.01187050, 0.01773241, ..., 0.00742812, 0.03228268, 0.00123904],
...,
[ 0.00178657, 0.00858035, 0.01583224, ..., -0.00881801, 0.03494525, 0.01678915],
[-0.00114184, 0.00497478, 0.01264016, ..., -0.00313805, 0.03383606, 0.00034484],
[ 0.00507328, 0.00707623, 0.02679622, ..., 0.00050059, 0.03791260, -0.00777113]])
result = lstm_model.predict(test_loader)
Predict begin...
step 3/3 [==============================] - 38ms/step
Predict samples: 345
class myGRU(nn.Layer):
def __init__(self):
super(myGRU, self).__init__()
# num_embeddings (int) - 嵌入字典的大小, input中的id必须满足 0 =< id < num_embeddings 。 。
# embedding_dim (int) - 每个嵌入向量的维度。
# padding_idx (int|long|None) - padding_idx的配置区间为 [-weight.shape[0], weight.shape[0],如果配置了padding_idx,那么在训练过程中遇到此id时会被用
# sparse (bool) - 是否使用稀疏更新,在词嵌入权重较大的情况下,使用稀疏更新能够获得更快的训练速度及更小的内存/显存占用。
# weight_attr (ParamAttr|None) - 指定嵌入向量的配置,包括初始化方法,具体用法请参见 ParamAttr ,一般无需设置,默认值为None。
self.embedding = nn.Embedding(inputs_dim, 256)
# input_size (int) - 输入的大小。
# hidden_size (int) - 隐藏状态大小。
# num_layers (int,可选) - 网络层数。默认为1。
# direction (str,可选) - 网络迭代方向,可设置为forward或bidirect(或bidirectional)。默认为forward。
# time_major (bool,可选) - 指定input的第一个维度是否是time steps。默认为False。
# dropout (float,可选) - dropout概率,指的是出第一层外每层输入时的dropout概率。默认为0。
# weight_ih_attr (ParamAttr,可选) - weight_ih的参数。默认为None。
# weight_hh_attr (ParamAttr,可选) - weight_hh的参数。默认为None。
# bias_ih_attr (ParamAttr,可选) - bias_ih的参数。默认为None。
# bias_hh_attr (ParamAttr,可选) - bias_hh的参数。默认为None。
self.gru = nn.GRU(256, 256, num_layers=2, direction='bidirectional',dropout=0.5)
# in_features (int) – 线性变换层输入单元的数目。
# out_features (int) – 线性变换层输出单元的数目。
# weight_attr (ParamAttr, 可选) – 指定权重参数的属性。默认值为None,表示使用默认的权重参数属性,将权重参数初始化为0。具体用法请参见 ParamAttr 。
# bias_attr (ParamAttr|bool, 可选) – 指定偏置参数的属性。 bias_attr 为bool类型且设置为False时,表示不会为该层添加偏置。 bias_attr 如果设置为True或者None,则表示使用默认的偏置参数属性,将偏置参数初始化为0。具体用法请参见 ParamAttr 。默认值为None。
# name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。
self.linear = nn.Linear(in_features=256*2, out_features=2)
self.dropout = nn.Dropout(0.5)
def forward(self, inputs):
emb = self.dropout(self.embedding(inputs))
output, hidden = self.gru(emb)
#output形状大小为[batch_size,seq_len,num_directions * hidden_size]
#hidden形状大小为[num_layers * num_directions, batch_size, hidden_size]
#把前向的hidden与后向的hidden合并在一起
hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)
hidden = self.dropout(hidden)
#hidden形状大小为[batch_size, hidden_size * num_directions]
return self.linear(hidden)
GRU_model = paddle.Model(myGRU())
GRU_model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=GRU_model.parameters()),
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy())
GRU_model.fit(train_loader,
dev_loader,
epochs=10,
batch_size=BATCH_SIZE,
verbose=1,
save_dir="work/GRU")
result = GRU_model.predict(test_loader)
Predict begin...
step 3/3 [==============================] - 35ms/step
Predict samples: 345
本文主要注重在于PaddlePaddle2.0在nlp基础任务的全流程如何实现,因此并未对两个模型的最终结果进行对比。
GRU和LSTM的性能在很多任务上效果相差不大,不过GRU 参数更少因此更容易收敛,而在数据集很大的情况下,LSTM表达性能更好。
在简单任务上,LSTM和GRU其实都是不错的选择,从完成代码来说,两者差别也不大,都可以简单方便的实现。
运行代码请点击:https://aistudio.baidu.com/aistudio/projectdetail/1491175?shared=1
欢迎三连!