import os
# Set before the keras imports below so the TensorFlow backend picks up the log level.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import numpy as np
from keras.datasets import imdb
from keras import models
from keras import layers
# Load the IMDB dataset through keras. load_data returns two (data, labels)
# pairs, unpacked here into the training and test splits. num_words=10000
# caps the vocabulary (per the Keras docs: only the most frequent words are kept).
# NOTE: "lables" is a misspelling of "labels"; the names are kept because
# later statements reference them.
(train_data, train_lables), (test_data, test_lables) = imdb.load_data(num_words=10000)
def vectorize_sequences(sequences, dim=10000):
    """Multi-hot encode integer index sequences into a 2-D float array.

    Each input sequence becomes one row of length ``dim``. The row holds 1.0
    in every column whose index appears in the sequence and 0.0 elsewhere;
    an index that occurs several times in a sequence still yields 1.0.

    :param sequences: sized iterable of integer index sequences; every index
        must be a valid column index for a row of length ``dim``.
    :param dim: width of each encoded row (number of columns).
    :return: ``numpy.ndarray`` of shape ``(len(sequences), dim)``.
    """
    results = np.zeros((len(sequences), dim))
    # A distinct loop name keeps the ``sequences`` argument from being
    # rebound while it is iterated.
    for i, sequence in enumerate(sequences):
        # Index-list assignment sets all listed columns of row i at once.
        results[i, sequence] = 1.0
    return results
# One-hot encode the data, turning each index sequence into a vector.
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
# Convert the label collections to float32 NumPy arrays.
y_train = np.asarray(train_lables).astype("float32")
y_test = np.asarray(test_lables).astype("float32")
```
```python
# Model definition: three fully connected layers stacked in order. Two
# 16-unit ReLU layers feed a single sigmoid output unit; the first layer
# declares the 10000-wide input shape.
model = models.Sequential([
    layers.Dense(16, activation="relu", input_shape=(10000,)),
    layers.Dense(16, activation="relu"),
    layers.Dense(1, activation="sigmoid"),
])
```
```python
# Split off a validation set: the first 10000 training samples (and their
# labels) are held out, the remainder is used for fitting.
x_val, partial_x_train = x_train[:10000], x_train[10000:]
y_val, partial_y_train = y_train[:10000], y_train[10000:]
# Compile the model: RMSprop optimizer, binary cross-entropy loss, and
# accuracy tracked as the reported metric.
model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])
# Train the model for 10 epochs in mini-batches of 512 samples, evaluating
# on the held-out validation split as part of the fit call.
history = model.fit(
    x=partial_x_train,
    y=partial_y_train,
    batch_size=512,
    epochs=10,
    validation_data=(x_val, y_val),
)
# Recorded training history; indexed by metric name further down.
history_dict = history.history
# Plot the training loss and the validation loss.
import matplotlib.pyplot as plt
loss_values = history_dict["loss"]
val_loss_values = history_dict["val_loss"]
# 1-based epoch numbers for the x-axis, one per recorded loss value.
epchos = list(range(1, len(loss_values) + 1))
# "bo" draws blue dots (training); "b" draws a solid blue line (validation).
plt.plot(epchos, loss_values, "bo", label="Train_loss")
plt.plot(epchos, val_loss_values, "b", label="Validation_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()
# Plot the training accuracy and the validation accuracy, reusing the
# epoch axis (epchos) built for the loss plot.
acc = history_dict["accuracy"]
val_acc = history_dict["val_accuracy"]
# "bo" draws blue dots (training); "b" draws a solid blue line (validation).
plt.plot(epchos, acc, "bo", label="Train_acc")
plt.plot(epchos, val_acc, "b", label="Validation_acc")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.show()
# Predict on the test data. The return value is not assigned to a variable.
model.predict(x_test)
# Sample output:
# array([[0.0365333 ],
#        [0.99999344],
#        [0.7845194 ],
#        ...,
#        [0.01798147],
#        [0.01858079],
#        [0.6560805 ]], dtype=float32)