【深度学习】IMDB数据集上电影评论二分类

任务描述

根据电影评论的文字内容来将电影划分为正面或者负面。

IMDB数据集

50000条两级分化的评论。正面负面各为50%。

# 加载数据
from keras.datasets import imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000) # 仅保留训练数据中前10000个最经常出现的单词,低频单词被舍弃
Using TensorFlow backend.


Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 1s 0us/step
train_data.shape
(25000,)
len(train_data[0])
218
train_labels[0] # 0表示负面,1表示正面
1
len(train_data[1])
189
len(train_data[100])
158
max([max(sequence) for sequence in train_data])
9999
# 将某条评论解码为英文单词
word_index = imdb.get_word_index()
Downloading data from https://s3.amazonaws.com/text-datasets/imdb_word_index.json
1646592/1641221 [==============================] - 0s 0us/step
word_index
{'fawn': 34701,
 'tsukino': 52006,
 'nunnery': 52007,
 'sonja': 16816,
 'vani': 63951,
 'woods': 1408,
 'spiders': 16115,
 'hanging': 2345,
 'woody': 2289,
 ...}
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()]) # 键值翻转
decode_review = ' '.join([reverse_word_index.get(i-3, '?') for i in train_data[0]]) # 评论解码,索引去掉3,0为填充,1为序列开始,2位unknown
decode_review
"? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little boy's that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all"
train_data
array([list([1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]),
       list([1, 194, 1153, 194, 8255, 78, 228, 5, 6, 1463, 4369, 5012, 134, 26, 4, 715, 8, 118, 1634, 14, 394, 20, 13, 119, 954, 189, 102, 5, 207, 110, 3103, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2300, 1523, 5, 647, 4, 116, 9, 35, 8163, 4, 229, 9, 340, 1322, 4, 118, 9, 4, 130, 4901, 19, 4, 1002, 5, 89, 29, 952, 46, 37, 4, 455, 9, 45, 43, 38, 1543, 1905, 398, 4, 1649, 26, 6853, 5, 163, 11, 3215, 2, 4, 1153, 9, 194, 775, 7, 8255, 2, 349, 2637, 148, 605, 2, 8003, 15, 123, 125, 68, 2, 6853, 15, 349, 165, 4362, 98, 5, 4, 228, 9, 43, 2, 1157, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 4373, 228, 8255, 5, 2, 656, 245, 2350, 5, 4, 9837, 131, 152, 491, 18, 2, 32, 7464, 1212, 14, 9, 6, 371, 78, 22, 625, 64, 1382, 9, 8, 168, 145, 23, 4, 1690, 15, 16, 4, 1355, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95]),
       list([1, 14, 47, 8, 30, 31, 7, 4, 249, 108, 7, 4, 5974, 54, 61, 369, 13, 71, 149, 14, 22, 112, 4, 2401, 311, 12, 16, 3711, 33, 75, 43, 1829, 296, 4, 86, 320, 35, 534, 19, 263, 4821, 1301, 4, 1873, 33, 89, 78, 12, 66, 16, 4, 360, 7, 4, 58, 316, 334, 11, 4, 1716, 43, 645, 662, 8, 257, 85, 1200, 42, 1228, 2578, 83, 68, 3912, 15, 36, 165, 1539, 278, 36, 69, 2, 780, 8, 106, 14, 6905, 1338, 18, 6, 22, 12, 215, 28, 610, 40, 6, 87, 326, 23, 2300, 21, 23, 22, 12, 272, 40, 57, 31, 11, 4, 22, 47, 6, 2307, 51, 9, 170, 23, 595, 116, 595, 1352, 13, 191, 79, 638, 89, 2, 14, 9, 8, 106, 607, 624, 35, 534, 6, 227, 7, 129, 113]),
       ...,
       list([1, 11, 6, 230, 245, 6401, 9, 6, 1225, 446, 2, 45, 2174, 84, 8322, 4007, 21, 4, 912, 84, 2, 325, 725, 134, 2, 1715, 84, 5, 36, 28, 57, 1099, 21, 8, 140, 8, 703, 5, 2, 84, 56, 18, 1644, 14, 9, 31, 7, 4, 9406, 1209, 2295, 2, 1008, 18, 6, 20, 207, 110, 563, 12, 8, 2901, 2, 8, 97, 6, 20, 53, 4767, 74, 4, 460, 364, 1273, 29, 270, 11, 960, 108, 45, 40, 29, 2961, 395, 11, 6, 4065, 500, 7, 2, 89, 364, 70, 29, 140, 4, 64, 4780, 11, 4, 2678, 26, 178, 4, 529, 443, 2, 5, 27, 710, 117, 2, 8123, 165, 47, 84, 37, 131, 818, 14, 595, 10, 10, 61, 1242, 1209, 10, 10, 288, 2260, 1702, 34, 2901, 2, 4, 65, 496, 4, 231, 7, 790, 5, 6, 320, 234, 2766, 234, 1119, 1574, 7, 496, 4, 139, 929, 2901, 2, 7750, 5, 4241, 18, 4, 8497, 2, 250, 11, 1818, 7561, 4, 4217, 5408, 747, 1115, 372, 1890, 1006, 541, 9303, 7, 4, 59, 2, 4, 3586, 2]),
       list([1, 1446, 7079, 69, 72, 3305, 13, 610, 930, 8, 12, 582, 23, 5, 16, 484, 685, 54, 349, 11, 4120, 2959, 45, 58, 1466, 13, 197, 12, 16, 43, 23, 2, 5, 62, 30, 145, 402, 11, 4131, 51, 575, 32, 61, 369, 71, 66, 770, 12, 1054, 75, 100, 2198, 8, 4, 105, 37, 69, 147, 712, 75, 3543, 44, 257, 390, 5, 69, 263, 514, 105, 50, 286, 1814, 23, 4, 123, 13, 161, 40, 5, 421, 4, 116, 16, 897, 13, 2, 40, 319, 5872, 112, 6700, 11, 4803, 121, 25, 70, 3468, 4, 719, 3798, 13, 18, 31, 62, 40, 8, 7200, 4, 2, 7, 14, 123, 5, 942, 25, 8, 721, 12, 145, 5, 202, 12, 160, 580, 202, 12, 6, 52, 58, 2, 92, 401, 728, 12, 39, 14, 251, 8, 15, 251, 5, 2, 12, 38, 84, 80, 124, 12, 9, 23]),
       list([1, 17, 6, 194, 337, 7, 4, 204, 22, 45, 254, 8, 106, 14, 123, 4, 2, 270, 2, 5, 2, 2, 732, 2098, 101, 405, 39, 14, 1034, 4, 1310, 9, 115, 50, 305, 12, 47, 4, 168, 5, 235, 7, 38, 111, 699, 102, 7, 4, 4039, 9245, 9, 24, 6, 78, 1099, 17, 2345, 2, 21, 27, 9685, 6139, 5, 2, 1603, 92, 1183, 4, 1310, 7, 4, 204, 42, 97, 90, 35, 221, 109, 29, 127, 27, 118, 8, 97, 12, 157, 21, 6789, 2, 9, 6, 66, 78, 1099, 4, 631, 1191, 5, 2642, 272, 191, 1070, 6, 7585, 8, 2197, 2, 2, 544, 5, 383, 1271, 848, 1468, 2, 497, 2, 8, 1597, 8778, 2, 21, 60, 27, 239, 9, 43, 8368, 209, 405, 10, 10, 12, 764, 40, 4, 248, 20, 12, 16, 5, 174, 1791, 72, 7, 51, 6, 1739, 22, 4, 204, 131, 9])],
      dtype=object)

输出train_data可以看出,这是没办法直接用的,外层是一个一维数组,里面是list,从含义上看这是二维数组,而形式上则不是我们需要的(samples, word_indices)这种格式。

# 数据准备,将列表转换为张量
import numpy as np

def vectorize_sequences(sequences, dimension=10000):
  results = np.zeros((len(sequences), dimension))
  for i, sequence in enumerate(sequences):
    results[i, sequence] = 1. # sequence是个数组,按照数组对results进行选择
  return results

x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
x_train.shape
(25000, 10000)
x_train[0]
array([0., 1., 1., ..., 0., 0., 0.])
len(x_train[0])
10000
for i, sequence in enumerate(train_data):
  if i == 10:
    break
  print(i, sequence)
0 [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
1 [1, 194, 1153, 194, 8255, 78, 228, 5, 6, 1463, 4369, 5012, 134, 26, 4, 715, 8, 118, 1634, 14, 394, 20, 13, 119, 954, 189, 102, 5, 207, 110, 3103, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2300, 1523, 5, 647, 4, 116, 9, 35, 8163, 4, 229, 9, 340, 1322, 4, 118, 9, 4, 130, 4901, 19, 4, 1002, 5, 89, 29, 952, 46, 37, 4, 455, 9, 45, 43, 38, 1543, 1905, 398, 4, 1649, 26, 6853, 5, 163, 11, 3215, 2, 4, 1153, 9, 194, 775, 7, 8255, 2, 349, 2637, 148, 605, 2, 8003, 15, 123, 125, 68, 2, 6853, 15, 349, 165, 4362, 98, 5, 4, 228, 9, 43, 2, 1157, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 4373, 228, 8255, 5, 2, 656, 245, 2350, 5, 4, 9837, 131, 152, 491, 18, 2, 32, 7464, 1212, 14, 9, 6, 371, 78, 22, 625, 64, 1382, 9, 8, 168, 145, 23, 4, 1690, 15, 16, 4, 1355, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95]
2 [1, 14, 47, 8, 30, 31, 7, 4, 249, 108, 7, 4, 5974, 54, 61, 369, 13, 71, 149, 14, 22, 112, 4, 2401, 311, 12, 16, 3711, 33, 75, 43, 1829, 296, 4, 86, 320, 35, 534, 19, 263, 4821, 1301, 4, 1873, 33, 89, 78, 12, 66, 16, 4, 360, 7, 4, 58, 316, 334, 11, 4, 1716, 43, 645, 662, 8, 257, 85, 1200, 42, 1228, 2578, 83, 68, 3912, 15, 36, 165, 1539, 278, 36, 69, 2, 780, 8, 106, 14, 6905, 1338, 18, 6, 22, 12, 215, 28, 610, 40, 6, 87, 326, 23, 2300, 21, 23, 22, 12, 272, 40, 57, 31, 11, 4, 22, 47, 6, 2307, 51, 9, 170, 23, 595, 116, 595, 1352, 13, 191, 79, 638, 89, 2, 14, 9, 8, 106, 607, 624, 35, 534, 6, 227, 7, 129, 113]
3 [1, 4, 2, 2, 33, 2804, 4, 2040, 432, 111, 153, 103, 4, 1494, 13, 70, 131, 67, 11, 61, 2, 744, 35, 3715, 761, 61, 5766, 452, 9214, 4, 985, 7, 2, 59, 166, 4, 105, 216, 1239, 41, 1797, 9, 15, 7, 35, 744, 2413, 31, 8, 4, 687, 23, 4, 2, 7339, 6, 3693, 42, 38, 39, 121, 59, 456, 10, 10, 7, 265, 12, 575, 111, 153, 159, 59, 16, 1447, 21, 25, 586, 482, 39, 4, 96, 59, 716, 12, 4, 172, 65, 9, 579, 11, 6004, 4, 1615, 5, 2, 7, 5168, 17, 13, 7064, 12, 19, 6, 464, 31, 314, 11, 2, 6, 719, 605, 11, 8, 202, 27, 310, 4, 3772, 3501, 8, 2722, 58, 10, 10, 537, 2116, 180, 40, 14, 413, 173, 7, 263, 112, 37, 152, 377, 4, 537, 263, 846, 579, 178, 54, 75, 71, 476, 36, 413, 263, 2504, 182, 5, 17, 75, 2306, 922, 36, 279, 131, 2895, 17, 2867, 42, 17, 35, 921, 2, 192, 5, 1219, 3890, 19, 2, 217, 4122, 1710, 537, 2, 1236, 5, 736, 10, 10, 61, 403, 9, 2, 40, 61, 4494, 5, 27, 4494, 159, 90, 263, 2311, 4319, 309, 8, 178, 5, 82, 4319, 4, 65, 15, 9225, 145, 143, 5122, 12, 7039, 537, 746, 537, 537, 15, 7979, 4, 2, 594, 7, 5168, 94, 9096, 3987, 2, 11, 2, 4, 538, 7, 1795, 246, 2, 9, 2, 11, 635, 14, 9, 51, 408, 12, 94, 318, 1382, 12, 47, 6, 2683, 936, 5, 6307, 2, 19, 49, 7, 4, 1885, 2, 1118, 25, 80, 126, 842, 10, 10, 2, 2, 4726, 27, 4494, 11, 1550, 3633, 159, 27, 341, 29, 2733, 19, 4185, 173, 7, 90, 2, 8, 30, 11, 4, 1784, 86, 1117, 8, 3261, 46, 11, 2, 21, 29, 9, 2841, 23, 4, 1010, 2, 793, 6, 2, 1386, 1830, 10, 10, 246, 50, 9, 6, 2750, 1944, 746, 90, 29, 2, 8, 124, 4, 882, 4, 882, 496, 27, 2, 2213, 537, 121, 127, 1219, 130, 5, 29, 494, 8, 124, 4, 882, 496, 4, 341, 7, 27, 846, 10, 10, 29, 9, 1906, 8, 97, 6, 236, 2, 1311, 8, 4, 2, 7, 31, 7, 2, 91, 2, 3987, 70, 4, 882, 30, 579, 42, 9, 12, 32, 11, 537, 10, 10, 11, 14, 65, 44, 537, 75, 2, 1775, 3353, 2, 1846, 4, 2, 7, 154, 5, 4, 518, 53, 2, 2, 7, 3211, 882, 11, 399, 38, 75, 257, 3807, 19, 2, 17, 29, 456, 4, 65, 7, 27, 205, 113, 10, 10, 2, 4, 2, 2, 9, 242, 4, 91, 1202, 2, 5, 2070, 307, 22, 7, 5168, 126, 93, 40, 2, 13, 188, 1076, 3222, 19, 4, 2, 7, 2348, 537, 23, 53, 537, 21, 82, 40, 2, 13, 2, 14, 280, 13, 219, 4, 2, 431, 758, 859, 4, 953, 1052, 2, 7, 5991, 5, 94, 40, 25, 238, 60, 2, 4, 2, 804, 2, 7, 4, 9941, 132, 8, 67, 6, 22, 15, 9, 283, 8, 5168, 14, 31, 9, 242, 955, 48, 25, 279, 2, 23, 12, 1685, 195, 25, 238, 60, 796, 2, 4, 671, 7, 2804, 5, 4, 559, 154, 888, 7, 726, 50, 26, 49, 7008, 15, 566, 30, 579, 21, 64, 2574]
4 [1, 249, 1323, 7, 61, 113, 10, 10, 13, 1637, 14, 20, 56, 33, 2401, 18, 457, 88, 13, 2626, 1400, 45, 3171, 13, 70, 79, 49, 706, 919, 13, 16, 355, 340, 355, 1696, 96, 143, 4, 22, 32, 289, 7, 61, 369, 71, 2359, 5, 13, 16, 131, 2073, 249, 114, 249, 229, 249, 20, 13, 28, 126, 110, 13, 473, 8, 569, 61, 419, 56, 429, 6, 1513, 18, 35, 534, 95, 474, 570, 5, 25, 124, 138, 88, 12, 421, 1543, 52, 725, 6397, 61, 419, 11, 13, 1571, 15, 1543, 20, 11, 4, 2, 5, 296, 12, 3524, 5, 15, 421, 128, 74, 233, 334, 207, 126, 224, 12, 562, 298, 2167, 1272, 7, 2601, 5, 516, 988, 43, 8, 79, 120, 15, 595, 13, 784, 25, 3171, 18, 165, 170, 143, 19, 14, 5, 7224, 6, 226, 251, 7, 61, 113]
5 [1, 778, 128, 74, 12, 630, 163, 15, 4, 1766, 7982, 1051, 2, 32, 85, 156, 45, 40, 148, 139, 121, 664, 665, 10, 10, 1361, 173, 4, 749, 2, 16, 3804, 8, 4, 226, 65, 12, 43, 127, 24, 2, 10, 10]
6 [1, 6740, 365, 1234, 5, 1156, 354, 11, 14, 5327, 6638, 7, 1016, 2, 5940, 356, 44, 4, 1349, 500, 746, 5, 200, 4, 4132, 11, 2, 9363, 1117, 1831, 7485, 5, 4831, 26, 6, 2, 4183, 17, 369, 37, 215, 1345, 143, 2, 5, 1838, 8, 1974, 15, 36, 119, 257, 85, 52, 486, 9, 6, 2, 8564, 63, 271, 6, 196, 96, 949, 4121, 4, 2, 7, 4, 2212, 2436, 819, 63, 47, 77, 7175, 180, 6, 227, 11, 94, 2494, 2, 13, 423, 4, 168, 7, 4, 22, 5, 89, 665, 71, 270, 56, 5, 13, 197, 12, 161, 5390, 99, 76, 23, 2, 7, 419, 665, 40, 91, 85, 108, 7, 4, 2084, 5, 4773, 81, 55, 52, 1901]
7 [1, 4, 2, 716, 4, 65, 7, 4, 689, 4367, 6308, 2343, 4804, 2, 2, 5270, 2, 2315, 2, 2, 2, 2, 4, 2, 628, 7685, 37, 9, 150, 4, 9820, 4069, 11, 2909, 4, 2, 847, 313, 6, 176, 2, 9, 6202, 138, 9, 4434, 19, 4, 96, 183, 26, 4, 192, 15, 27, 5842, 799, 7101, 2, 588, 84, 11, 4, 3231, 152, 339, 5206, 42, 4869, 2, 6293, 345, 4804, 2, 142, 43, 218, 208, 54, 29, 853, 659, 46, 4, 882, 183, 80, 115, 30, 4, 172, 174, 10, 10, 1001, 398, 1001, 1055, 526, 34, 3717, 2, 5262, 2, 17, 4, 6706, 1094, 871, 64, 85, 22, 2030, 1109, 38, 230, 9, 4, 4324, 2, 251, 5056, 1034, 195, 301, 14, 16, 31, 7, 4, 2, 8, 783, 2, 33, 4, 2945, 103, 465, 2, 42, 845, 45, 446, 11, 1895, 19, 184, 76, 32, 4, 5310, 207, 110, 13, 197, 4, 2, 16, 601, 964, 2152, 595, 13, 258, 4, 1730, 66, 338, 55, 5312, 4, 550, 728, 65, 1196, 8, 1839, 61, 1546, 42, 8361, 61, 602, 120, 45, 7304, 6, 320, 786, 99, 196, 2, 786, 5936, 4, 225, 4, 373, 1009, 33, 4, 130, 63, 69, 72, 1104, 46, 1292, 225, 14, 66, 194, 2, 1703, 56, 8, 803, 1004, 6, 2, 155, 11, 4, 2, 3231, 45, 853, 2029, 8, 30, 6, 117, 430, 19, 6, 8941, 9, 15, 66, 424, 8, 2337, 178, 9, 15, 66, 424, 8, 1465, 178, 9, 15, 66, 142, 15, 9, 424, 8, 28, 178, 662, 44, 12, 17, 4, 130, 898, 1686, 9, 6, 5623, 267, 185, 430, 4, 118, 2, 277, 15, 4, 1188, 100, 216, 56, 19, 4, 357, 114, 2, 367, 45, 115, 93, 788, 121, 4, 2, 79, 32, 68, 278, 39, 8, 818, 162, 4165, 237, 600, 7, 98, 306, 8, 157, 549, 628, 11, 6, 2, 13, 824, 15, 4104, 76, 42, 138, 36, 774, 77, 1059, 159, 150, 4, 229, 497, 8, 1493, 11, 175, 251, 453, 19, 8651, 189, 12, 43, 127, 6, 394, 292, 7, 8253, 4, 107, 8, 4, 2826, 15, 1082, 1251, 9, 906, 42, 1134, 6, 66, 78, 22, 15, 13, 244, 2519, 8, 135, 233, 52, 44, 10, 10, 466, 112, 398, 526, 34, 4, 1572, 4413, 6706, 1094, 225, 57, 599, 133, 225, 6, 227, 7, 541, 4323, 6, 171, 139, 7, 539, 2, 56, 11, 6, 3231, 21, 164, 25, 426, 81, 33, 344, 624, 19, 6, 4617, 7, 2, 2, 6, 5802, 4, 22, 9, 1082, 629, 237, 45, 188, 6, 55, 655, 707, 6371, 956, 225, 1456, 841, 42, 1310, 225, 6, 2493, 1467, 7722, 2828, 21, 4, 2, 9, 364, 23, 4, 2228, 2407, 225, 24, 76, 133, 18, 4, 189, 2293, 10, 10, 814, 11, 2, 11, 2642, 14, 47, 15, 682, 364, 352, 168, 44, 12, 45, 24, 913, 93, 21, 247, 2441, 4, 116, 34, 35, 1859, 8, 72, 177, 9, 164, 8, 901, 344, 44, 13, 191, 135, 13, 126, 421, 233, 18, 259, 10, 10, 4, 2, 6847, 4, 2, 3074, 7, 112, 199, 753, 357, 39, 63, 12, 115, 2, 763, 8, 15, 35, 3282, 1523, 65, 57, 599, 6, 1916, 277, 1730, 37, 25, 92, 202, 6, 8848, 44, 25, 28, 6, 22, 15, 122, 24, 4171, 72, 33, 32]
8 [1, 43, 188, 46, 5, 566, 264, 51, 6, 530, 664, 14, 9, 1713, 81, 25, 1135, 46, 7, 6, 20, 750, 11, 141, 4299, 5, 2, 4441, 102, 28, 413, 38, 120, 5533, 15, 4, 3974, 7, 5369, 142, 371, 318, 5, 955, 1713, 571, 2, 2, 122, 14, 8, 72, 54, 12, 86, 385, 46, 5, 14, 20, 9, 399, 8, 72, 150, 13, 161, 124, 6, 155, 44, 14, 159, 170, 83, 12, 5, 51, 6, 866, 48, 25, 842, 4, 1120, 25, 238, 79, 4, 547, 15, 14, 9, 31, 7, 148, 2, 102, 44, 35, 480, 3823, 2380, 19, 120, 4, 350, 228, 5, 269, 8, 28, 178, 1314, 2347, 7, 51, 6, 87, 65, 12, 9, 979, 21, 95, 24, 3186, 178, 11, 2, 14, 9, 24, 15, 20, 4, 84, 376, 4, 65, 14, 127, 141, 6, 52, 292, 7, 4751, 175, 561, 7, 68, 3866, 137, 75, 2541, 68, 182, 5, 235, 175, 333, 19, 98, 50, 9, 38, 76, 724, 4, 6750, 15, 166, 285, 36, 140, 143, 38, 76, 53, 3094, 1301, 4, 6991, 16, 82, 6, 87, 3578, 44, 2527, 7612, 5, 800, 4, 3033, 11, 35, 1728, 96, 21, 14, 22, 9, 76, 53, 7, 6, 406, 65, 13, 43, 219, 12, 639, 21, 13, 80, 140, 5, 135, 15, 14, 9, 31, 7, 4, 118, 3672, 13, 28, 126, 110]
9 [1, 14, 20, 47, 111, 439, 3445, 19, 12, 15, 166, 12, 216, 125, 40, 6, 364, 352, 707, 1187, 39, 294, 11, 22, 396, 13, 28, 8, 202, 12, 1109, 23, 94, 2, 151, 111, 211, 469, 4, 20, 13, 258, 546, 1104, 7273, 12, 16, 38, 78, 33, 211, 15, 12, 16, 2849, 63, 93, 12, 6, 253, 106, 10, 10, 48, 335, 267, 18, 6, 364, 1242, 1179, 20, 19, 6, 1009, 7, 1987, 189, 5, 6, 8419, 7, 2723, 2, 95, 1719, 6, 6035, 7, 3912, 7144, 49, 369, 120, 5, 28, 49, 253, 10, 10, 13, 1041, 19, 85, 795, 15, 4, 481, 9, 55, 78, 807, 9, 375, 8, 1167, 8, 794, 76, 7, 4, 58, 5, 4, 816, 9, 243, 7, 43, 50]
test_labels
array([0, 1, 1, ..., 0, 0, 0])
# 标签向量化
y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')

至此数据就准备好了,可以输入到神经网络中了。

# 构建网络
# 输入数据是一条向量,目标值为标量
# 这是最简单的情况
# 这类问题有个表现很好的模型:带relu激活函数的全连接层的简单堆叠
# 两个中间层,每层16个隐藏单元
# 第三层输出一个标量,预测当前评论的情感
# 最后一层用sigmoid激活函数输出一个概率值
# 16 --> 16 --> 1
from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
# 现在我们需要选择损失函数和优化器,
# 问题是一个二分类问题
# https://blog.csdn.net/u011240016/article/details/85150443
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
# 留出验证集
x_val = x_train[:5000]
partial_x_train = x_train[5000:]

y_val = y_train[:5000]
partial_y_train = y_train[5000:]

# 训练模型
history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512, validation_data=(x_val, y_val))
Train on 20000 samples, validate on 5000 samples
Epoch 1/20
20000/20000 [==============================] - 3s 136us/step - loss: 0.4713 - acc: 0.8052 - val_loss: 0.3679 - val_acc: 0.8578
Epoch 2/20
20000/20000 [==============================] - 2s 102us/step - loss: 0.2720 - acc: 0.9052 - val_loss: 0.2945 - val_acc: 0.8864
Epoch 3/20
20000/20000 [==============================] - 2s 102us/step - loss: 0.2026 - acc: 0.9308 - val_loss: 0.2698 - val_acc: 0.8910
Epoch 4/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.1671 - acc: 0.9416 - val_loss: 0.2799 - val_acc: 0.8906
Epoch 5/20
20000/20000 [==============================] - 2s 102us/step - loss: 0.1409 - acc: 0.9517 - val_loss: 0.2828 - val_acc: 0.8904
Epoch 6/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.1187 - acc: 0.9604 - val_loss: 0.3065 - val_acc: 0.8864
Epoch 7/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.1067 - acc: 0.9644 - val_loss: 0.3269 - val_acc: 0.8826
Epoch 8/20
20000/20000 [==============================] - 2s 102us/step - loss: 0.0866 - acc: 0.9718 - val_loss: 0.3504 - val_acc: 0.8816
Epoch 9/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.0778 - acc: 0.9752 - val_loss: 0.4557 - val_acc: 0.8616
Epoch 10/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.0706 - acc: 0.9775 - val_loss: 0.4003 - val_acc: 0.8736
Epoch 11/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0571 - acc: 0.9841 - val_loss: 0.4181 - val_acc: 0.8748
Epoch 12/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.0511 - acc: 0.9845 - val_loss: 0.4942 - val_acc: 0.8662
Epoch 13/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0430 - acc: 0.9881 - val_loss: 0.5034 - val_acc: 0.8662
Epoch 14/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0397 - acc: 0.9882 - val_loss: 0.5054 - val_acc: 0.8722
Epoch 15/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.0297 - acc: 0.9921 - val_loss: 0.5720 - val_acc: 0.8630
Epoch 16/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0296 - acc: 0.9914 - val_loss: 0.5725 - val_acc: 0.8674
Epoch 17/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0214 - acc: 0.9952 - val_loss: 0.6093 - val_acc: 0.8672
Epoch 18/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0181 - acc: 0.9963 - val_loss: 0.6477 - val_acc: 0.8636
Epoch 19/20
20000/20000 [==============================] - 2s 101us/step - loss: 0.0173 - acc: 0.9963 - val_loss: 0.6985 - val_acc: 0.8608
Epoch 20/20
20000/20000 [==============================] - 2s 100us/step - loss: 0.0137 - acc: 0.9973 - val_loss: 0.7097 - val_acc: 0.8656
history_dict = history.history
history_dict.keys()
dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])
# 绘制训练损失和验证损失
import matplotlib.pyplot as plt

history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']

epochs = range(1, len(loss_values) + 1)

plt.plot(epochs, loss_values, 'bo', label='Traning Loss')
plt.plot(epochs, val_loss_values, 'b', label='Validation Loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

【深度学习】IMDB数据集上电影评论二分类_第1张图片

分析这个图可以看到,训练损失逐渐减小,但是验证损失先降低后提升,所以越往后,对训练数据越优化使得模型过拟合了。所以这里可以先在第三轮之后停止训练。

# 新建一个模型
model2 = models.Sequential()
model2.add(layers.Dense(16, activation='relu', input_shape=(10000,))) # 输入层
model2.add(layers.Dense(16, activation='relu'))
model2.add(layers.Dense(1, activation='sigmoid'))

model2.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model2.fit(x_train, y_train, epochs=4, batch_size=512)

results = model2.evaluate(x_test, y_test)

Epoch 1/4
25000/25000 [==============================] - 3s 107us/step - loss: 0.4570 - acc: 0.8204
Epoch 2/4
25000/25000 [==============================] - 2s 92us/step - loss: 0.2620 - acc: 0.9092
Epoch 3/4
25000/25000 [==============================] - 2s 89us/step - loss: 0.2014 - acc: 0.9293
Epoch 4/4
25000/25000 [==============================] - 2s 88us/step - loss: 0.1672 - acc: 0.9402
25000/25000 [==============================] - 2s 69us/step
print(results)
[0.29334229942321777, 0.88356]
res = model2.predict(x_test) # 网络输出的是sigmoid概率
res
array([[0.16456558],
       [0.9996093 ],
       [0.7976105 ],
       ...,
       [0.10431715],
       [0.07220688],
       [0.6143431 ]], dtype=float32)
res.shape
(25000, 1)
res[0]
array([0.16456558], dtype=float32)
res[1]
array([0.9996093], dtype=float32)
y_test[0]
0.0
y_test[1]
1.0

结果解释

我们还记得sigmoid激活函数,如果评论为负面的话,标签为0,所以sigmoid之后的结果是0,只有正面评价才有大于0的sigmoid概率,所以这里的输出结果是对评论为正面的可能性的判断。

# 更换隐藏层单元大小

model3 = models.Sequential()
model3.add(layers.Dense(32, activation='tanh', input_shape=(10000,))) # 输入层
model3.add(layers.Dense(64, activation='tanh'))
model3.add(layers.Dense(1, activation='sigmoid'))

model3.compile(optimizer='rmsprop',
              loss='mse',
              metrics=['accuracy'])
model3.fit(x_train, y_train, epochs=4, batch_size=512)

results = model2.evaluate(x_test, y_test)

Epoch 1/4
25000/25000 [==============================] - 3s 123us/step - loss: 0.1249 - acc: 0.8280
Epoch 2/4
25000/25000 [==============================] - 3s 106us/step - loss: 0.0650 - acc: 0.9140
Epoch 3/4
25000/25000 [==============================] - 3s 105us/step - loss: 0.0526 - acc: 0.9324
Epoch 4/4
25000/25000 [==============================] - 3s 104us/step - loss: 0.0436 - acc: 0.9439
25000/25000 [==============================] - 2s 75us/step
results
[0.29334229942321777, 0.88356]

总结

  • 需要对原始数据进行大量的预处理,使其转换为张量输入到神经网络
  • relu激活函数的Dense层堆是经典模型,能够解决多种问题,比如情感分类等
  • 二分类问题,网络的最后一层只有一个单元,且使用sigmoid激活函数的Dense层,输出尾0~1之间的标量,表示概率值
  • 二分类问题的sigmoid输出,应该使用binary_crossentropy损失函数
  • 不管问题是什么,rmsprop优化器都是足够好的选择
  • 神经网络在训练数据上训练过久,最终会过拟合,使得在未见过的数据上表现越来越差,所以我们需要监控模型在训练集之外的数据上的性能表现

END.

参考:

《Deep Learning with Python》

你可能感兴趣的:(Deep,Learning,Keras)