import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
model_name = 'rygh_logistic_save.h5'
# print(data.iloc[:, -1].value_counts()) 这里可知,y都是-1,1的数据,明显是个二分类问题,但是我们需要将-1替换成0
data = pd.read_csv('./datas/rygh/credit-a.csv')
print(data.head())
x = data.iloc[:, : -1]
y = data.iloc[:, -1].replace(-1, 0)
try:
model = keras.models.load_model('./models_lei/{model_name}'.format(model_name=model_name))
print('Successfully reloaded the model...')
except:
print('The model does not exist, we have to train a new one...')
model = keras.Sequential()
model.add(keras.layers.Dense(4, input_shape=(15, ), activation='relu'))
model.add(keras.layers.Dense(4, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
# print(model.summary())
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['acc']
)
model.fit(x, y, epochs=100)
# history = model.fit(x, y, epochs=100)
# 查看损失和正确率的曲线
# print(history.history.keys()) #dict_keys(['loss', 'acc'])
# plt.plot(history.epoch, history.history.get('loss'))
# plt.plot(history.epoch, history.history.get('acc'))
# plt.show()
model.save('./models_lei/{model_name}'.format(model_name=model_name))
test = data.iloc[:5, :-1]
real = data.iloc[:5, -1]
predict = model.predict(test)
print('real:', real)
print('predict:', predict)
在搞keras的时候出现了一个问题,当我试图启动已经保存的模型的时候,即使打断点调试,
print('Successfully reloaded the model...')
上面这一句代码始终没有执行,说明我启动这个模型失败了。搞了半天也没找到问题的原因。
然后我把try和except注释掉了,运行:
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
model_name = 'rygh_logistic_save.h5'
# print(data.iloc[:, -1].value_counts()) 这里可知,y都是-1,1的数据,明显是个二分类问题,但是我们需要将-1替换成0
data = pd.read_csv('./datas/rygh/credit-a.csv')
print(data.head())
x = data.iloc[:, : -1]
y = data.iloc[:, -1].replace(-1, 0)
model = keras.models.load_model('./models_lei/{model_name}'.format(model_name=model_name))
test = data.iloc[:5, :-1]
real = data.iloc[:5, -1]
predict = model.predict(test)
print('real:', real)
print('predict:', predict)
有人说
我的tensorflow版本是2.1.0,当我查看我已被迫安装的h5py时,发现它的版本真的是3.1.0,所以把3.1.0版本的h5py先删除掉,再添加2.10.0版本的h5py。
然后运行,没问题了: