First, here is the training code:
def train(self, train_generator, validation_generator, pre_model_path=None):
    '''
    :param train_generator: training set generator
    :param validation_generator: validation set generator
    :param pre_model_path: pre-trained model to continue training from; currently only h5 models are supported
    '''
    # Continue training from an existing model
    if pre_model_path:
        self.model = load_model(pre_model_path)
    # Configure the model
    with open(pjoin(TXT_DIR, 'message.txt'), 'r') as f:
        _, TRAIN_SIZE, VAL_SIZE, _ = list(map(int, f.readline().split(',')))
    STEP_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE + 1
    VALIDATION_STEPS = VAL_SIZE // BATCH_SIZE + 1
    optimizer = optimizers.RMSprop(lr=LEARNING_RATE)
    self.model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
    # Train
    self.history = self.model.fit(train_generator, steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCH,
                                  validation_data=validation_generator, validation_steps=VALIDATION_STEPS)
    # Save
    self.save()
    return self.model
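For reference, train() derives STEP_PER_EPOCH and VALIDATION_STEPS from message.txt. Judging only from the parsing line above, the file is expected to start with one line of four comma-separated integers whose second and third values are the training and validation set sizes; a hypothetical example of producing such a file (the concrete numbers are made up):

# Hypothetical writer for message.txt, matching the parsing logic in train();
# 8000 and 2000 stand in for the real training and validation set sizes.
with open(pjoin(TXT_DIR, 'message.txt'), 'w') as f:
    f.write('0,8000,2000,0')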
After fit, the model is saved:
def save(self):
    if os.path.exists(MODEL_SAVE_DIR):
        shutil.rmtree(MODEL_SAVE_DIR)
    os.mkdir(MODEL_SAVE_DIR)
    # Save the h5 model
    h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
    self.model.save(h5_path)
    print('Successfully saved h5 model: %s' % h5_path)
    # Save the pb model
    # Define the inputs and outputs of the serving signature
    model_signature = predict_signature_def(inputs={INPUT_KEY: self.model.input},
                                            outputs={OUTPUT_KEY: self.model.output})
    with tf.keras.backend.get_session() as sess:
        pb_path = pjoin(MODEL_SAVE_DIR, VERSION + '_pb')
        try:
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
            builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
            builder.add_meta_graph_and_variables(sess, [MODEL_TAG],
                                                 clear_devices=True,
                                                 signature_def_map={SIGNATURE_DEF_KEY: model_signature},
                                                 legacy_init_op=legacy_init_op)
            builder.save()
            print('Successfully saved pb model: %s' % pb_path)
        except Exception as e:
            print("Fail to export saved model, exception: {}".format(e))
With this setup the model can only be saved after fit() has returned, i.e. only after all of the epochs have finished. I have not found another way around this yet, so here is my workaround.
To save the model without interrupting the training code, training and saving have to happen inside one large session. That session has to be pulled out of the save function, otherwise it cannot be reused across repeated calls: once a session block finishes executing, the session that was opened is closed automatically. So we create a session on the outside and pass it into the function as a parameter. In principle the session could also be created outside the while True loop, but then all operations would run inside a single session, which consumes a lot of memory, especially during long training runs. Placing the session inside the while True loop means each iteration releases its memory before training continues, which solves the memory problem. The concrete steps are as follows:
import tensorflow.keras.backend as backend

while True:
    net.train(train_generator, val_generator)
    backend.clear_session()
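The snippet above leaves out how the session reaches train(); here is a minimal sketch of the full loop under the same assumptions, using the modified train() signature introduced below (the exact session setup is my own addition, not part of the original code):

import tensorflow as tf
import tensorflow.keras.backend as backend

while True:
    # Create a fresh session for this round of training; it lives inside the
    # loop so that its resources are released before the next iteration.
    sess = tf.Session()
    backend.set_session(sess)
    net.train(train_generator, val_generator, sess)
    # Reset the Keras graph/session state, then close the session explicitly.
    backend.clear_session()
    sess.close()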
The path of the saved .h5 model is wired into train() so that the model from the previous round of training can be loaded again. An os.path.exists(model_path) check decides whether a trained model already exists; if it does, it is loaded and training continues on top of it. The path used is h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5'). The code is as follows:
def train(self, train_generator, validation_generator, sess, pre_model_path=None):
    '''
    :param train_generator: training set generator
    :param validation_generator: validation set generator
    :param pre_model_path: pre-trained model to continue training from; currently only h5 models are supported
    '''
    # ********************** modified part *************************
    pre_model_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
    # ***************************************************************
    # Continue training from an existing model
    if os.path.exists(pre_model_path):
        self.model = load_model(pre_model_path)
    # Configure the model
    with open(pjoin(TXT_DIR, 'message.txt'), 'r') as f:
        _, TRAIN_SIZE, VAL_SIZE, _ = list(map(int, f.readline().split(',')))
    STEP_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE + 1
    VALIDATION_STEPS = VAL_SIZE // BATCH_SIZE + 1
    optimizer = optimizers.RMSprop(lr=LEARNING_RATE)
    self.model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
    # Train
    self.history = self.model.fit(train_generator, steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCH,
                                  validation_data=validation_generator, validation_steps=VALIDATION_STEPS)
    # Save
    self.save(sess)
    return self.model
def save(self, sess):
    if os.path.exists(MODEL_SAVE_DIR):
        shutil.rmtree(MODEL_SAVE_DIR)
    os.mkdir(MODEL_SAVE_DIR)
    # Save the h5 model
    h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
    self.model.save(h5_path)
    print('Successfully saved h5 model: %s' % h5_path)
    # Save the pb model (disabled here)
    # Define the inputs and outputs of the serving signature
    # model_signature = predict_signature_def(inputs={INPUT_KEY: self.model.input},
    #                                         outputs={OUTPUT_KEY: self.model.output})
    # pb_path = pjoin(MODEL_SAVE_DIR, VERSION + '_pb')
    # try:
    #     legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    #     builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
    #     builder.add_meta_graph_and_variables(sess, [MODEL_TAG],
    #                                          clear_devices=True,
    #                                          signature_def_map={SIGNATURE_DEF_KEY: model_signature},
    #                                          legacy_init_op=legacy_init_op)
    #     builder.save()
    #     print('Successfully saved pb model: %s' % pb_path)
    # except Exception as e:
    #     print("Fail to export saved model, exception: {}".format(e))