Fixing "TypeError: Fetch argument None has invalid type"

TensorFlow raises the following error:

TypeError: Fetch argument None has invalid type 

Full traceback:

  File "D:/Workspace/SpyderWorkspace/Models/RBM_mnist.py", line 211, in main
    sess.run(rbm.train_ops(k, step, i),feed_dict={X: X_batch})

  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 767, in run
    run_metadata_ptr)

  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)

  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)

  File "D:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 227, in for_fetch
    (fetch, type(fetch)))

TypeError: Fetch argument None has invalid type 

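The traceback shows that the error originates from the sess.run() line.

In the graph-building code, RBM is a class and train_ops is a method of that class. rbm is an instance of RBM, and the failing line calls its train_ops method.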


train_ops is defined as follows:

def train_ops(self, k, step, batch_count):
        ...(omitted)
            with tf.name_scope("update_params"):
                # The assign ops are created here, but their results are discarded
                # and the method ends without a return statement.
                tf.assign(self.W, self.W + self.learning_rate * delta_W, name="update_W")
                tf.assign(self.bv, self.bv + self.learning_rate * delta_bv, name="update_bv")
                tf.assign(self.bh, self.bh + self.learning_rate * delta_bh, name="update_bh")
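The method's main job is to update the parameters, so it defines the parameter-update ops. The error occurs because train_ops has no return statement: a Python function without one returns None, so sess.run() receives None as its fetch argument.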
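The cause can be reproduced without TensorFlow. The sketch below is only an illustration: check_fetch is a hypothetical stand-in for the validation that sess.run() applies to its fetch argument, and the two train_ops_* functions and the string "update_W" are made-up placeholders for the real method and its assign op.

```python
def check_fetch(fetch):
    """Hypothetical stand-in for the fetch validation done by sess.run()."""
    if fetch is None:
        raise TypeError("Fetch argument %r has invalid type %r" % (fetch, type(fetch)))
    return fetch


def train_ops_without_return(created_ops):
    # The "op" is created, but nothing is returned, so the caller gets None.
    created_ops.append("update_W")


def train_ops_with_return(created_ops):
    # Same work, but the created "op" is handed back to the caller.
    created_ops.append("update_W")
    return created_ops[-1]


ops = []
missing = train_ops_without_return(ops)   # None: the function has no return
returned = train_ops_with_return(ops)     # "update_W"

try:
    check_fetch(missing)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(missing, returned, error_message)
```

Both functions do the same work on ops. Only the second gives the caller something that can be passed on as a fetch.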

The corrected code:

def train_ops(self, k, step, batch_count):
        ...(omitted)
            with tf.name_scope("update_params"):
                # Keep a reference to each assign op so it can be returned.
                new_W = tf.assign(self.W, self.W + self.learning_rate * delta_W, name="update_W")
                new_bv = tf.assign(self.bv, self.bv + self.learning_rate * delta_bv, name="update_bv")
                new_bh = tf.assign(self.bh, self.bh + self.learning_rate * delta_bh, name="update_bh")
            # Return the ops so that sess.run() has something to fetch.
            return (new_W, new_bv, new_bh)
With this change, the error is resolved.
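To catch this kind of mistake earlier, a small guard can check the value before it reaches sess.run(). This is a minimal sketch, not part of TensorFlow: require_fetches and its name parameter are hypothetical, and it inspects only None and flat lists or tuples.

```python
def require_fetches(fetches, name="fetches"):
    """Hypothetical helper: fail early, with a pointed message, if a fetch is None."""
    if fetches is None:
        raise ValueError(
            "%s is None; does the function that builds it end with a return statement?" % name
        )
    if isinstance(fetches, (list, tuple)):
        for index, item in enumerate(fetches):
            if item is None:
                raise ValueError("%s[%d] is None" % (name, index))
    return fetches


# Placeholder strings stand in for the three assign ops returned by train_ops.
checked = require_fetches(("update_W", "update_bv", "update_bh"), name="train_ops")
print(checked)
```

In the training loop, the call would wrap the existing argument, for example sess.run(require_fetches(rbm.train_ops(k, step, i), "train_ops"), feed_dict={X: X_batch}).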
