多线程使用keras训练模型错误-"is not an element of this graph"

错误场景

为了每隔固定时间训练一次模型, luffy在线程函数中设置timer再次调用线程函数, 简化版代码如下

def _thread_func(interval=10):
	model = train_model()				
	timer = threading.Timer(interval, _thread_func, args=(interval))		#设置定时器间隔interval后再次调用_thread_func, 无限循环	
	timer.start()
			
def train():
	x_train, y_train = get_dataset()		#获取训练集
	
	model = build_model()					#构建模型的输入输出和中间层, 不详细展开
	model.compile("adam", loss="mse")
	model.fit(x_train, y_train, batch_size=32, epochs=10)
	
	return model

错误内容

TypeError: Can not interpret feed_dict key as Tensor: Tensor Tensor("func:0", shape=(?,?), dtype=int32) is not an element of this graph.

原因分析

 查找资料后发现, keras是基于tensorflow的(luffy用的backend是tensorflow), 而tensorflow有两个关键的对象, 计算图(Graph)和会话(Session). 以luffy的智商进行了理解就是: Graph和Session两者通常是一一对应的.
 放到luffy碰到的问题中:

  • _thread_func1第一次执行时, 变量是定义在默认的计算图上的, 我们称之为default_graph,第一次训练, 也是运行在默认会话上的, 我们称之为default_session.
  • 隔了一段时间后, _thread_func1调用了自己, 新开辟了一个线程_thread_func2, 此时在_thread_func2也调用了train()函数, 但是此时变量是定义在新的graph上的, 我们称之为new_graph, 但是进程并没有切换新的session, 用的还是default_session, 所以就出错啦.

代码修正

非常简单, 加两行代码即可, 保证计算图和会话一致.

def train():
	with tf.Graph().as_default():				#第一行增加代码
		x_train, y_train = get_dataset()		#获取训练集
		
		model = build_model()						#构建模型的输入输出和中间层, 不详细展开
		model.compile("adam", loss="mse")
		with tf.Session() as sess:				#第二行增加代码
			model.fit(x_train, y_train, batch_size=32, epochs=10)
		
		return model

你可能感兴趣的:(tensorflow)