Udacity——基于RNN网络的情感分析(一)

在Udacity上课也有一阵子了,所以趁着假期时间把一些小项目总结一下,这样也有助于自己今后对知识的回顾。闲话不多说,让我们开始这个有趣的项目吧!该项目是基于RNN的情感分析(项目链接),用RNN网络而不用前馈网络是因为RNN可以通过对序列的学习提高预测的准确性。数据集我们采用的是电影评论及对应的标签,RNN网络结构如图所示。

Udacity——基于RNN网络的情感分析(一)_第1张图片
network_diagram

这里,我们通过嵌入层传入数据,原因是嵌入层输入数据要比独热编码有更高的表示效率(具体的内容可以参考Word2Vec原理笔记),嵌入层的数据将传入LSTM单元,最终LSTM输出的的数据将传入一个sigmoid输出层。我们使用sigmoid输出层是因为我们尝试预测文本的积极情感或者消极情感,最终输出层只有一个sigmoid激活函数单元。


数据的预处理

首先我们要将数据调整成适合传入网络的形式。既然我们用嵌入层,我们就要将文字转换为整数,这样的形式更加简洁。
具体的处理的方法要基于数据本身,以我们现有的数据集为例首先观察评论数据的形式

'bromwell high is a cartoon comedy . it ran at the same time as some other programs about school life such as teachers . my years in the teaching profession lead me to believe that bromwell high s satire is much closer to reality than is teachers . the scramble to survive financially the insightful students who can see right through their pathetic teachers pomp the pettiness of the whole situation all remind me of the schools i knew and their students . when i saw the episode in which a student repeatedly tried to burn down the school i immediately recalled . . . . . . . . . at . . . . . . . . . . high . a classic line inspector i m here to sack one of your teachers . student welcome to bromwel......'

因此我们要做以下两个工作:

  • 删去所有的标点符号
  • 根据回车键'\n'将评论逐条存入训练数组中

from string import punctuation
all_text = ''.join([c for c in reviews if c not in punctuation])
reviews = all_text.split('\n')

此时我们就得到了评论的数组。由于我们要将文字转换为整数,因此我们要对所有文字进行统计:

  • 合并所有评论
  • 逐词将单词存入数组

all_text = ' '.join(reviews)
words = all_text.split()

最终得到的单词数组形式如下:

['bromwell',
'high',
'is',
'a',
'cartoon',
'comedy',
'it',
'ran',
'at',
'the',
'same',
'time',
'as',
'some',
'other',
'programs',
'about',
'school',
'life',
'such',
'as',
'teachers'...]


评论文字转数字

我们可以通过加工word数组,得到word:num的字典,这里值得注意的一点是由于我们会将0作为输入网络向量的初始元素,所以文字对应的数字应该从1开始,代码简单易懂如下:

vocab_set = set(words)
vocab_list = list(words)
vocab_to_int = {word: (num+1) for num,word in enumerate(vocab_list)}

得到这张word:num字典,我们就可以将评论中的文字转换为数字:

reviews_ints = []
for review in reviews:
reviews_ints.append([vocab_to_int[word] for word in review.split()])

此时评论转化后的形式:

[[2287959,
6015634,
6020195,
6020170,
6005621,
6013460,
6020180,
6003912,
6020174,
6020166,
6018044,
6020193,
6020107,
6020139,
6020049,
5991273,
6019751,
6015635,
6018542,
6017437,
6020107,
6005105,
6020143,
6019526,
6020091,
6020166,
5988232,
5703051,
6018165,
6019810,
6020114,
6019578,
6020161,
2287959,
6015634,
6020111,
6004283...]]


标签文字转数字

标签中只有'positve','negative'以及'\n',我们的目的是将'positve'设定为1,'negative'设定为0并删去'\n'。

labels = labels.split('\n')
labels = np.array([1 if label == 'positive' else 0 for label in labels])代码如下:

最后得到的标签数组如下所示:

array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 1, 0, 1, 0])


目前为止,我们已经将评论和标签的文字均转换为了数字,但这只是数据处理的第一步,下一步我们要做的是删除长度为0的评论,因为这些评论毫无意义。要注意的一点是我们不能破坏评论—标签的位置对应关系,所以我们删减根据的是位置信息。

non_zero_idx = [num for num,review in enumerate(reviews_ints) if len(review) != 0]
reviews_ints = [reviews_ints[num] for num in non_zero_idx]
labels = np.array([labels[num] for num in non_zero_idx])

此时我们就可以对数据进行最后的统一维度处理,我们不难发现,评论文字的长短差距会很大,短则几个字,多则上千个字。因此我们设定200这个阈值,多于200个字的评论我们只保留前200个字;少于200个字的评论用0来填充剩余的长度,这样我们就统一了输入向量的维度。代码如下:

seq_len = 200
features = np.zeros((len(reviews_ints),seq_len),dtype=int)
for num,review in enumerate(reviews_ints):
features[num,-len(review):] = np.array(review)[:seq_len]

此时我们得到的最终可以输入到网络数据形式如下:

array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 21025, 308, 6,
3, 1050, 207, 8, 2138, 32, 1, 171, 57,
15, 49, 81, 5785, 44, 382, 110, 140, 15,
5194, 60, 154, 9, 1, 4975, 5852, 475, 71,
5, 260, 12, 21025, 308, 13, 1978, 6, 74,
2395...]...])


截止目前,我们已经将输入网络的数据和标签数据进行了处理,能够达到用于训练的要求,那么下一篇笔记将主要讲解该项目的网络模型。

你可能感兴趣的:(Udacity——基于RNN网络的情感分析(一))