TensorFlow error:
TypeError: Fetch argument None has invalid type
  File "D:/Workspace/SpyderWorkspace/Models/RBM_mnist.py", line 211, in main
    sess.run(rbm.train_ops(k, step, i),feed_dict={X: X_batch})
  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 767, in run
    run_metadata_ptr)
  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 227, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type
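In the graph-building code, RBM is a class and train_ops is a method of that class. rbm is an instance of RBM, and the training loop calls its train_ops method.
The train_ops method is defined as follows: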
def train_ops(self, k, step, batch_count):
    ...  # (omitted)
    with tf.name_scope("update_params"):
        # The assign ops are created here, but their handles are discarded
        tf.assign(self.W, self.W + self.learning_rate * delta_W, name="update_W")
        tf.assign(self.bv, self.bv + self.learning_rate * delta_bv, name="update_bv")
        tf.assign(self.bh, self.bh + self.learning_rate * delta_bh, name="update_bh")
    # No return statement, so the method returns None
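The method's main job is to update the parameters, so it defines the parameter-update ops. The error occurs because train_ops, as defined, has no return statement: the call rbm.train_ops(k, step, i) therefore evaluates to None, and sess.run receives None as its fetch argument.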
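The cause can be reproduced without TensorFlow. The following is only a minimal plain-Python sketch: make_op and check_fetch are hypothetical stand-ins invented for illustration (they are not TensorFlow functions), and the error message is abbreviated to match the one shown above.

```python
# Plain-Python sketch; TensorFlow is not needed to run it.
# make_op and check_fetch are hypothetical stand-ins, not TensorFlow APIs.

def make_op(name):
    """Stand-in for tf.assign: builds and returns an object describing an op."""
    return {"op": name}


def train_ops_without_return():
    # The ops are built, but nothing is returned to the caller.
    make_op("update_W")
    make_op("update_bv")
    make_op("update_bh")


def train_ops_with_return():
    # Keep a handle to every op and hand them back to the caller.
    new_W = make_op("update_W")
    new_bv = make_op("update_bv")
    new_bh = make_op("update_bh")
    return (new_W, new_bv, new_bh)


def check_fetch(fetch):
    """Stand-in for the type check that sess.run applies to its fetch argument."""
    if fetch is None:
        raise TypeError("Fetch argument None has invalid type")
    return fetch


if __name__ == "__main__":
    print(train_ops_without_return())   # a function with no return yields None
    check_fetch(train_ops_with_return())  # a tuple of ops passes the check
    try:
        check_fetch(train_ops_without_return())
    except TypeError as err:
        print("TypeError:", err)
```

A Python function that finishes without reaching a return statement always returns None, so the value handed to the fetch check is None no matter how many ops were created inside the function.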
The corrected code is as follows:
def train_ops(self, k, step, batch_count):
    ...  # (omitted)
    with tf.name_scope("update_params"):
        # Keep a handle to each assign op so it can be passed to sess.run
        new_W = tf.assign(self.W, self.W + self.learning_rate * delta_W, name="update_W")
        new_bv = tf.assign(self.bv, self.bv + self.learning_rate * delta_bv, name="update_bv")
        new_bh = tf.assign(self.bh, self.bh + self.learning_rate * delta_bh, name="update_bh")
    # Return the update ops so that sess.run has something to fetch
    return (new_W, new_bv, new_bh)
With this change, the error is resolved.