tensorflow TypeError: run() got multiple values for argument 'feed_dict'

TensorFlow raises the error above when a session's run() is called as in the code below:

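import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

sess = tf.InteractiveSession()

# 4x4 input matrix
a = np.array([[1.0, 2.0, 3.0, 4.0],
              [5.0, 6.0, 7.0, 8.0],
              [9.0, 10.0, 11.0, 12.0],
              [1.0, 1.0, 1.0, 1.0]])
# 3x3 all-ones filter, shaped [height, width, in_channels, out_channels].
# Dimensions must be integers: recent NumPy versions reject float shapes.
w = np.ones([3, 3, 1, 1])

W_conv1 = tf.Variable(w)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

x = tf.placeholder(tf.float64, shape=[4, 4])
# Reshape to NHWC: batch=1, height=4, width=4, channels=1
x_image = tf.reshape(x, [1, 4, 4, 1])

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))

# initialize_all_variables() is deprecated; global_variables_initializer() replaces it
sess.run(tf.global_variables_initializer())

# Wrong: raises TypeError: run() got multiple values for argument 'feed_dict'
i, h1 = sess.run(x_image, h_conv1, feed_dict={x: a})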

Solution:

i,h1 = sess.run([x_image, h_conv1], feed_dict={x:a})

Cause:
The signature of Session.run() is:

run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)

Meaning of each argument (from the TensorFlow API docs):
fetches: A single graph element, a list of graph elements, or a dictionary whose values are graph elements or lists of graph elements.
feed_dict: A dictionary that maps graph elements to values.
options: A [RunOptions] protocol buffer.
run_metadata: A [RunMetadata] protocol buffer.
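The fetches description above also determines the shape of what run() returns: one graph element gives one value, a list gives a list in the same order, and a dict gives a dict with the same keys. The toy below imitates only that shape rule in plain Python so it runs without TensorFlow; fake_run and the graph_values table are hypothetical stand-ins, not TensorFlow code.

```python
# Hypothetical stand-in: pretend these graph elements were already evaluated.
graph_values = {"x_image": "<x_image value>", "h_conv1": "<h_conv1 value>"}


def fake_run(fetches, feed_dict=None):
    """Mimic only how Session.run() shapes its result to match `fetches`."""
    if isinstance(fetches, dict):
        # A dict of fetches yields a dict with the same keys.
        return {k: fake_run(v) for k, v in fetches.items()}
    if isinstance(fetches, (list, tuple)):
        # A list (or tuple) of fetches yields a sequence in the same order.
        return type(fetches)(fake_run(f) for f in fetches)
    # A single graph element yields a single value.
    return graph_values[fetches]


single = fake_run("h_conv1")
pair = fake_run(["x_image", "h_conv1"])
named = fake_run({"img": "x_image", "conv": ["h_conv1"]})
print(single)  # <h_conv1 value>
print(pair)    # ['<x_image value>', '<h_conv1 value>']
print(named)   # {'img': '<x_image value>', 'conv': ['<h_conv1 value>']}
```

This is why the fixed call can unpack two results with i, h1 = ...: a two-element list of fetches comes back as a two-element list of values.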

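x_image and h_conv1 are both meant for the fetches argument, so they must be passed together as a list. Passed separately, x_image is bound to fetches and h_conv1 is bound positionally to the second parameter, feed_dict; the explicit feed_dict={x:a} then supplies feed_dict a second time, which raises the TypeError.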
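The argument binding can be reproduced without TensorFlow. The sketch below defines a hypothetical plain-Python run() with the same parameter list as Session.run() and uses inspect.signature to show where each argument lands; the strings are placeholders for the real tensors and the input array.

```python
import inspect


def run(fetches, feed_dict=None, options=None, run_metadata=None):
    """Hypothetical stub with the same parameters as Session.run()."""
    return fetches, feed_dict


sig = inspect.signature(run)

# Without feed_dict=..., the second positional argument silently becomes feed_dict.
bound = sig.bind("x_image", "h_conv1")
print(dict(bound.arguments))  # {'fetches': 'x_image', 'feed_dict': 'h_conv1'}

# Adding feed_dict=... as a keyword now supplies feed_dict twice.
try:
    run("x_image", "h_conv1", feed_dict={"x": "a"})
    err_msg = None
except TypeError as e:
    err_msg = str(e)
print(err_msg)  # run() got multiple values for argument 'feed_dict'

# Wrapping both items in a list makes them a single fetches argument.
fetches, feed = run(["x_image", "h_conv1"], feed_dict={"x": "a"})
print(fetches)  # ['x_image', 'h_conv1']
```

The failure is ordinary Python argument binding, not anything specific to TensorFlow, which is why the message matches the one in the title.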

Reference: https://www.tensorflow.org/api_docs/python/tf/Session
