pickle包可以用于各类模型的保存和读取,比如sklearn和keras里的所有模型。补充:pickle包也可以用于字典、数据集的保存和读取。
import pickle
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
'''
-------------此处省略模型的训练步骤--------------
'''
#创建一个pickle文件并命名model.pickle,注意后缀不要漏了
with open('model.pickle', 'wb') as f:
#把模型倒入文件中,dump可以说很形象生动了~
pickle.dump(model, f)
with open('model.pickle', 'rb') as f:
model = pickle.load(f)
保存为h5格式的文件
'''
经过一系列复杂的定义和训练得到了训练好的model
'''
model.save('model.h5')
这样保存的模型结果,它既保持了模型的结构,又保存了模型的参数。
from keras.models import load_model
model = load_model('model.h5')
如果仅仅想保存模型训练好得到的参数(w, b),可以用model.save_weights()。
'''
经过一系列复杂的定义和训练得到了训练好的model
'''
model.save_weights('model_weights.h5')
注意这里要先获得一个有结构的空白模型,而且这个空白模型的结构比如和我之前model.save_weights(‘model_weights.h5’) 的模型结构一模一样。
def create_model():
.......
model = create_model()
model.load_weights('model3.h5')
读取某一层的参数:
"""
假如原模型为:
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1"))
model.add(Dense(3, name="dense_2"))
...
model.save_weights(fname)
"""
# new model
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1")) # will be loaded
model.add(Dense(10, name="new_dense")) # will not be loaded
# load weights from first model; will only affect the first layer, dense_1.
model.load_weights(fname, by_name=True)
附:通过 model.summary() 查看模型结构
model.summary()
得到:
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_22 (Dense) (None, 512) 401920
_________________________________________________________________
activation_21 (Activation) (None, 512) 0
_________________________________________________________________
dense_23 (Dense) (None, 256) 131328
_________________________________________________________________
activation_22 (Activation) (None, 256) 0
_________________________________________________________________
dropout_5 (Dropout) (None, 256) 0
_________________________________________________________________
dense_24 (Dense) (None, 256) 65792
_________________________________________________________________
activation_23 (Activation) (None, 256) 0
_________________________________________________________________
dense_25 (Dense) (None, 10) 2570
_________________________________________________________________
activation_24 (Activation) (None, 10) 0
=================================================================
Total params: 601,610
Trainable params: 601,610
Non-trainable params: 0
_________________________________________________________________