我的第一个深度学习（DL）程序

以下代码来自《Python深度学习》一书中的 IMDB 电影评论二分类示例。

#Author:KXG
# Load the IMDB movie-review dataset that ships with Keras.
from keras.datasets import imdb

# Keep only the 10,000 most frequent words in the training data; rarer words
# are discarded.
# NOTE(review): Keras documents `path` as the cache file name relative to
# ~/.keras/datasets, not the current directory — confirm './imdb.npz' is intended.
(train_data, train_lables), (test_data, test_lables) = imdb.load_data(
    path='./imdb.npz',
    num_words=10000,
)
# Encode integer sequences as a binary (multi-hot) matrix.
import numpy as np


def vectorize_sequences(sequences, dimension=10000):
    """Multi-hot encode a collection of integer sequences.

    Args:
        sequences: a sized iterable whose items are sequences of integer
            indices (e.g. word indices of one review each).
        dimension: width of each output row; every index must be a valid
            column index for this width.

    Returns:
        A float64 array of shape (len(sequences), dimension) where row i has
        1.0 at every index that occurs in sequences[i] and 0.0 elsewhere.
        Repeated indices still produce a single 1.0.
    """
    n_rows = len(sequences)
    encoded = np.zeros(shape=(n_rows, dimension))
    row = 0
    for word_indices in sequences:
        # Fancy-index assignment sets all listed columns of this row at once.
        encoded[row, word_indices] = 1.0
        row += 1
    return encoded
# Multi-hot encode the training and test reviews (one row per review, as wide
# as vectorize_sequences' default dimension of 10000).
x_train = vectorize_sequences(train_data)
# The test reviews are features and belong in x_test; y_train is reserved for
# the training labels assigned just below.
x_test = vectorize_sequences(test_data)
# Vectorize the labels as float32 arrays, matching the single sigmoid output
# unit and binary_crossentropy loss used by the model in this script.
y_train = np.asarray(train_lables).astype('float32')
y_test = np.asarray(test_lables).astype('float32')
# Model definition: two 16-unit ReLU hidden layers followed by one sigmoid
# output unit. The input width (10000,) matches the multi-hot vectors produced
# by vectorize_sequences.
from keras import models
from keras import layers

model = models.Sequential([
    layers.Dense(16, activation='relu', input_shape=(10000,)),
    layers.Dense(16, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

# Hold out the first 10,000 training samples as a validation set; the
# remaining samples are used for the actual training.
x_val, partial_x_train = x_train[:10000], x_train[10000:]
y_val, partial_y_train = y_train[:10000], y_train[10000:]

# Compile with the RMSprop optimizer and binary cross-entropy loss, and track
# accuracy as a metric.
model.compile(
    optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Train for 20 epochs in mini-batches of 512 samples, evaluating on the
# held-out (x_val, y_val) pair. The returned History object is used for the
# plots below.
history = model.fit(
    partial_x_train,
    partial_y_train,
    epochs=20,
    batch_size=512,
    validation_data=(x_val, y_val),
)
# Plot the training loss and validation loss for each epoch.
import matplotlib.pyplot as plt

history_dict = history.history  # per-epoch metric lists recorded by model.fit
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs = range(1, len(loss_values) + 1)  # 1-based epoch numbers for the x-axis

plt.plot(epochs, loss_values, 'bo', label='Training loss')       # 'bo': blue dots
plt.plot(epochs, val_loss_values, 'b', label='Validation loss')  # 'b': solid blue line
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
# Plot the training accuracy and validation accuracy for each epoch.
plt.clf()  # clear the current figure before drawing the accuracy curves
# NOTE(review): the key names depend on the Keras version (older releases
# record 'acc'/'val_acc') — confirm against the installed version.
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']

# Reuses the `epochs` range built for the loss plot above.
plt.plot(epochs, acc, 'bo', label='Training acc')       # 'bo': blue dots
plt.plot(epochs, val_acc, 'b', label='Validation acc')  # 'b': solid blue line
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()  # show the legend
plt.show()

 

你可能感兴趣的:(python深度学习,深度学习,python)