IMDB电影评论分类

import os
import keras

import numpy as np

from keras import  models, layers
from keras import optimizers
from keras.datasets import imdb

# Load the IMDB review data from an `imdb.npz` archive addressed by an absolute
# path inside the current working directory. `num_words=10000` restricts the
# vocabulary to the 10,000 most frequent words (per the Keras dataset docs).
# os.path.join builds the path with the platform's separator instead of a
# hand-concatenated '/'.
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(
    path=os.path.join(os.getcwd(), 'imdb.npz'), num_words=10000)
# Sanity check: shapes of the training samples and their labels.
print(train_data.shape)
print(train_labels.shape)


def vectorize(seqs, dim=10000):
    """Multi-hot encode a collection of integer index sequences.

    Args:
        seqs: Sized iterable whose items are collections of integer indices;
            every index must be valid for an axis of length ``dim``.
        dim: Width of each encoded row.

    Returns:
        A ``(len(seqs), dim)`` NumPy array of zeros in which, for row ``i``,
        every column listed in ``seqs[i]`` is set to 1. Repeated indices
        within a sequence still yield a single 1.
    """
    encoded = np.zeros(shape=(len(seqs), dim))
    for row_idx, token_ids in enumerate(seqs):
        # Fancy indexing sets all listed columns of this row in one step.
        encoded[row_idx, token_ids] = 1
    return encoded


# Multi-hot encode the review index sequences (default row width of
# `vectorize`), training split first, then the test split.
x_train, x_test = vectorize(train_data), vectorize(test_data)

# Convert the labels to float32 arrays.
y_train, y_test = (np.asarray(labels).astype('float32')
                   for labels in (train_labels, test_labels))

# Quick inspection: both shapes, then the first encoded sample and its label,
# one item per line.
print(x_train.shape, y_train.shape, x_train[0], y_train[0], sep='\n')

# Binary classifier: two 16-unit ReLU hidden layers over the 10000-wide input,
# followed by a single sigmoid output unit.
model = models.Sequential([
    layers.Dense(16, activation='relu', input_shape=(10000,)),
    layers.Dense(16, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])
model.summary()

# Write a diagram of the architecture to imdb.png.
keras.utils.plot_model(model, to_file='imdb.png')


# Hold out the first 10,000 training samples for validation and train on the
# remainder.
x_val = x_train[:10000]
partial_x_train = x_train[10000:]
y_val = y_train[:10000]
partial_y_train = y_train[10000:]

# `learning_rate` is the supported argument name since Keras 2.3 (the same
# release that reports the metric under the 'accuracy' key read later in this
# script); the old `lr` alias is deprecated and rejected by newer releases.
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=0.001),   # learning rate 0.001
              metrics=['accuracy'])

# Train for 10 epochs in mini-batches of 128, evaluating on the held-out
# validation split after every epoch.
history = model.fit(partial_x_train,
                    partial_y_train,
                    batch_size=128,
                    epochs=10,
                    verbose=1,
                    validation_data=(x_val, y_val))

def _show_curves(x_values, train_values, val_values,
                 train_label, val_label, title, y_label):
    """Plot a training curve against its validation curve and show the figure.

    The training series is drawn with the 'bo' format string and the
    validation series with 'b'; the x axis is always labelled 'Epochs'.
    """
    plt.plot(x_values, train_values, 'bo', label=train_label)
    plt.plot(x_values, val_values, 'b', label=val_label)
    plt.title(title)        # figure title
    plt.xlabel('Epochs')    # x-axis label
    plt.ylabel(y_label)
    plt.legend()            # add the legend built from the `label` arguments
    plt.show()


# Per-epoch metrics recorded by `model.fit`.
history_dict = history.history
print(history_dict.keys())

loss, val_loss = history_dict['loss'], history_dict['val_loss']

# Epoch numbers start at 1 for the x axis.
epochs = range(1, 1 + len(loss))

import matplotlib.pyplot as plt

_show_curves(epochs, loss, val_loss,
             'Training loss', 'Validation loss',
             'Training and validation loss', 'Loss')

# The accuracy keys are looked up only after the loss figure has been shown.
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
_show_curves(epochs, acc, val_acc,
             'Training acc', 'Validation acc',
             'Training and validation accuracy', 'Accuracy')

# Persist the trained model to the relative path 'imdb.h5'.
model.save('imdb.h5')

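有些关键词的写法会因 Keras 版本不同而有差异:

accuracy/acc(`history.history` 中的键名:Keras 2.3 起键名与 `compile` 的 `metrics` 中传入的名称一致,传入 'accuracy' 时为 `accuracy`/`val_accuracy`;更早的版本统一为 `acc`/`val_acc`)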

metrics/metrice

你可能感兴趣的:(计算机视觉)