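Keras is a deep learning framework. Unlike PyTorch, where users can define a network's forward and backward passes themselves, Keras provides high-level building blocks for models and does not itself handle low-level logic such as tensor operations and differentiation. It delegates those operations to a backend, which can be TensorFlow or Theano. Theano is no longer maintained, so TensorFlow, which is also the default backend in Keras, is the practical choice. Keras makes it simple and quick to build networks, and Keras code is typically more concise than native TensorFlow code.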
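Keras offers two ways to define a network, and PyTorch has a counterpart for each. Let's look at the first one, the Sequential model.
This approach connects the layers in the order they are added, and the input data passes through each layer in turn: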
from keras import models
from keras import layers
model = models.Sequential()
model.add(layers.Dense(32, activation='relu', input_shape=(784,)))
model.add(layers.Dense(10, activation='softmax'))
PyTorch has a similar construct:
import torch.nn as nn
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
The difference is that in PyTorch you still have to define the forward pass yourself:
def forward(self, x):
    x = self.classifier(x)
    return x
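To make "the input passes through each layer in turn" concrete, here is a minimal NumPy sketch of what a sequential stack computes for the 784 → 32 → 10 network above. The `Dense` class, `sequential_forward`, the random weights, and the input batch are all hypothetical stand-ins for illustration; this is not how Keras or PyTorch implement their layers.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # stabilise the exponentials
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

class Dense:
    """Hypothetical stand-in for a fully connected layer: activation(x @ W + b)."""
    def __init__(self, n_in, n_out, activation, rng):
        self.W = rng.normal(scale=0.01, size=(n_in, n_out))
        self.b = np.zeros(n_out)
        self.activation = activation

    def __call__(self, x):
        return self.activation(x @ self.W + self.b)

    def param_count(self):
        return self.W.size + self.b.size

def sequential_forward(layer_list, x):
    # Feed the output of each layer into the next one, in order.
    for layer in layer_list:
        x = layer(x)
    return x

rng = np.random.default_rng(0)
toy_layers = [Dense(784, 32, relu, rng), Dense(32, 10, softmax, rng)]
batch = rng.random((4, 784))             # 4 made-up input rows
probs = sequential_forward(toy_layers, batch)
n_params = sum(layer.param_count() for layer in toy_layers)
print(probs.shape, n_params)             # (4, 10) 25450
```

The parameter count is 784 * 32 + 32 = 25120 for the first layer plus 32 * 10 + 10 = 330 for the second.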
The second way is the functional API. Defining the model this way makes explicit how the input data is passed through each layer:
input_tensor = layers.Input(shape=(784,))
x = layers.Dense(32, activation='relu')(input_tensor)
output_tensor = layers.Dense(10, activation='softmax')(x)
model = models.Model(inputs=input_tensor, outputs=output_tensor)
Again, PyTorch has a similar pattern:
import torch.nn as nn
import torch.nn.functional as F
self.linear1 = nn.Linear(512 * 7 * 7, 4096)
self.linear2 = nn.Linear(4096, 4096)
self.linear3 = nn.Linear(4096, num_classes)
With the layers defined, the forward pass comes next:
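def forward(self, x):
    x = F.relu(self.linear1(x))
    # F.dropout defaults to training=True, so pass the module's mode explicitly
    x = F.dropout(x, training=self.training)
    x = F.relu(self.linear2(x))
    x = F.dropout(x, training=self.training)
    x = self.linear3(x)
    return x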
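Note that in PyTorch, nn.ReLU and nn.Dropout are module classes that must be instantiated (typically in __init__) before they can be called. To apply the same operations directly inside forward, use the functional forms F.relu and F.dropout instead. Because F.dropout defaults to training=True, pass training=self.training so that dropout is switched off in evaluation mode.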
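Why the training flag matters can be seen in a small NumPy sketch of inverted dropout, the scheme in which surviving activations are scaled by 1 / (1 - p) during training. The `dropout` function below is a hypothetical illustration, not PyTorch's implementation, and the input is made up.

```python
import numpy as np

def dropout(x, p=0.5, training=True, rng=None):
    """Inverted dropout: while training, zero each element with probability p
    and scale the survivors by 1 / (1 - p); otherwise return x unchanged."""
    if not training or p == 0.0:
        return x
    rng = np.random.default_rng() if rng is None else rng
    keep_mask = rng.random(x.shape) >= p
    return x * keep_mask / (1.0 - p)

x = np.ones((2, 4))                      # made-up activations
train_out = dropout(x, p=0.5, training=True, rng=np.random.default_rng(0))
eval_out = dropout(x, p=0.5, training=False)
print(train_out)                         # each entry is either 0.0 or 2.0
print(eval_out)                          # identical to x
```

If the flag stayed at True during evaluation, predictions would remain randomly perturbed, which is why the module's own mode should be passed through.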
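Back in Keras, once the network is defined, the model has to be compiled:
from keras import optimizers
# Recent Keras releases name this argument learning_rate; older ones use lr.
model.compile(optimizer=optimizers.RMSprop(learning_rate=0.001), loss='mse', metrics=['accuracy'])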
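Compiling the model specifies the optimizer, the loss function, and the metrics to compute, after which the model is ready for training. (mse only serves as an illustration here; for a softmax classifier like the one above, categorical_crossentropy is the usual loss.)
In Keras the training function has the same name as in scikit-learn: call fit to start training. Here input_tensor and target_tensor stand for NumPy arrays holding the training samples and their labels, not the symbolic layers.Input tensor defined earlier: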
model.fit(input_tensor, target_tensor, batch_size=128, epochs=10)
If the classes are imbalanced, you may want to weight them during training, which can be done as follows:
class_weight = {0: 1.,
1: 50.,
2: 2.}
model.fit(input_tensor, target_tensor, batch_size=128, epochs=10, class_weight=class_weight)
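The effect of class weights can be sketched in NumPy: each sample's loss is multiplied by the weight of its true class, so mistakes on a heavily weighted class cost more. The function name, labels, and probabilities below are made up for illustration, and how Keras reduces the weighted losses to a single number is deliberately left out.

```python
import numpy as np

def weighted_sample_losses(y_true, probs, class_weight):
    """Per-sample cross-entropy, each scaled by the weight of its true class."""
    y_true = np.asarray(y_true)
    # Probability the model assigned to the correct class of each sample.
    p_correct = probs[np.arange(len(y_true)), y_true]
    losses = -np.log(p_correct)
    weights = np.array([class_weight[int(c)] for c in y_true])
    return losses * weights

class_weight = {0: 1., 1: 50., 2: 2.}
y_true = [0, 1, 2]                        # made-up integer labels
probs = np.array([[0.5, 0.25, 0.25],      # made-up predicted probabilities
                  [0.25, 0.5, 0.25],
                  [0.25, 0.25, 0.5]])
scaled = weighted_sample_losses(y_true, probs, class_weight)
unscaled = weighted_sample_losses(y_true, probs, {0: 1., 1: 1., 2: 1.})
print(scaled / unscaled)                  # ratios of 1, 50 and 2
```

With these weights, one misclassified sample of class 1 contributes as much loss as fifty equally wrong samples of class 0.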
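Now that you have a basic idea of how Keras is used, let's build a complete classification application as an example.
We will use the IMDB movie review dataset and classify each review as positive or negative.
First, load the data: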
from keras.datasets import imdb
# Keep only the 10,000 most frequent words
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
print(train_data[0])
print(train_labels[0])
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
1
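The data stores word indices, which can be mapped back to the original words. The indices are offset by 3 because 0, 1, and 2 are reserved for padding, start of sequence, and unknown words: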
word_index = imdb.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decoded_review = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])
print(decoded_review)
? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little boy's that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
Next, the data has to be vectorized:
import numpy as np
def vectorize_sequences(sequences, dimension=10000):
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.  # set specific indices of results[i] to 1s
    return results
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')
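A toy run shows what this encoding produces. The two "reviews" and the dimension of 5 below are made up so that the result fits on screen; the function is repeated to keep the sketch self-contained.

```python
import numpy as np

def vectorize_sequences(sequences, dimension=10000):
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.  # mark every word index that occurs
    return results

toy_reviews = [[1, 3, 3], [0, 2]]         # made-up word-index sequences
encoded = vectorize_sequences(toy_reviews, dimension=5)
print(encoded)
# [[0. 1. 0. 1. 0.]
#  [1. 0. 1. 0. 0.]]
```

Each row is a multi-hot vector: a word that appears several times (index 3 in the first review) is still encoded as a single 1, so word order and counts are discarded.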
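Next we build a three-layer neural network and train it on the data, holding out the first 10,000 training samples for validation: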
model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy'])
x_val = x_train[:10000]
partial_x_train = x_train[10000:]
y_val = y_train[:10000]
partial_y_train = y_train[10000:]
history = model.fit(partial_x_train,
partial_y_train,
epochs=20,
batch_size=512,
validation_data=(x_val, y_val))
We plot the training and validation loss and accuracy:
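Plots of this kind could be produced with matplotlib roughly as follows. This is a sketch that assumes matplotlib is installed; the helper names `history_curves` and `plot_history` and the output file name are hypothetical. The accuracy key in the history dictionary is 'acc' in older Keras releases and 'accuracy' in newer ones, so the code checks for both.

```python
def history_curves(history_dict):
    """Return (epochs, loss, val_loss, acc, val_acc) from a Keras-style history dict."""
    acc_key = 'accuracy' if 'accuracy' in history_dict else 'acc'
    loss = history_dict['loss']
    epochs = list(range(1, len(loss) + 1))
    return (epochs, loss, history_dict['val_loss'],
            history_dict[acc_key], history_dict['val_' + acc_key])

def plot_history(history_dict, path='history.png'):
    # Imported here so the helper above works even without matplotlib.
    import matplotlib
    matplotlib.use('Agg')  # render to a file; no display needed
    import matplotlib.pyplot as plt

    epochs, loss, val_loss, acc, val_acc = history_curves(history_dict)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(epochs, loss, 'bo', label='Training loss')
    ax_loss.plot(epochs, val_loss, 'b', label='Validation loss')
    ax_loss.set_xlabel('Epochs')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()
    ax_acc.plot(epochs, acc, 'bo', label='Training accuracy')
    ax_acc.plot(epochs, val_acc, 'b', label='Validation accuracy')
    ax_acc.set_xlabel('Epochs')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.legend()
    fig.savefig(path)
    plt.close(fig)

# Usage with the object returned by model.fit:
# plot_history(history.history)
```

Overfitting shows up as the validation loss turning upward while the training loss keeps falling.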
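The plots show that the model starts overfitting after only a few epochs, so the remaining epochs are wasted. We can stop training early with the EarlyStopping callback: if the monitored quantity (here val_loss) has not improved for patience consecutive epochs, training stops. Here patience is set to 5: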
from keras.callbacks import EarlyStopping
earlystop = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
history = model.fit(partial_x_train,
partial_y_train,
epochs=20,
batch_size=512,
validation_data=(x_val, y_val),
callbacks=[earlystop])
The training log shows that training stopped at epoch 9:
Epoch 00009: early stopping
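The patience rule can be sketched in plain Python: remember the best validation loss seen so far, count the epochs since it last improved, and stop once that count reaches patience. The function name and the validation losses below are made up for illustration, and refinements such as a minimum improvement threshold are ignored.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss   # new best value: reset the counter
            wait = 0
        else:
            wait += 1     # another epoch without improvement
            if wait >= patience:
                return epoch
    return None

# Made-up validation losses whose minimum falls at epoch 4.
val_losses = [0.40, 0.32, 0.29, 0.28, 0.30, 0.31, 0.33, 0.35, 0.38, 0.40]
print(early_stopping_epoch(val_losses, patience=5))   # 9
```

With patience set to 5, a stop at epoch 9 means the last improvement happened at epoch 4.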
Let's look at the training and validation loss and accuracy again:
Next come model evaluation, prediction, and saving:
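results = model.evaluate(x_test, y_test)  # [test loss, test accuracy]
print(results)
predictions = model.predict(x_test)  # sigmoid probabilities, shape (num_samples, 1)
# argmax(1) over a single column is always 0, so threshold at 0.5 instead
predictions = (predictions > 0.5).astype('int32').ravel()
print(predictions)
model.save('classify_movie.h5')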
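Because the network ends in a single sigmoid unit, its output has one column per sample, and turning it into class labels requires thresholding rather than argmax. A small NumPy sketch with made-up probabilities shows the difference:

```python
import numpy as np

probs = np.array([[0.91], [0.08], [0.55], [0.49]])   # made-up sigmoid outputs, shape (4, 1)
argmax_labels = probs.argmax(1)
print(argmax_labels)                                  # [0 0 0 0]
labels = (probs > 0.5).astype('int32').ravel()
print(labels)                                         # [1 0 1 0]
```

Taking argmax along axis 1 only makes sense for a softmax output with one column per class.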
The code and data are available from the WeChat public account 机器工匠: reply with the keyword "movie".