这几天看论文、跑实验,然后就被程序中的tf.placeholder难住了,整了好几天,各种报错,各路大神的分享都看了还是百思不得其解,终于搞了几天能正常运行了,记录一下,以免大家跟我一样浪费很多时间,有不对的地方欢迎大家指正!
tensorflow现在是深度学习必备的框架之一,有着强大的功能和完备的社区服务,一直深受广大深度学习爱好者的追捧。可是由于tensorflow1.x和2.x互相之间不兼容甚至差一个小版本都不兼容的现象的存在,再加上文档混乱,经常一个方法可以在不同的包里出现,弄得大家是怨声载道、苦不堪言,本文就记录下楼主遇到的问题。
tf.placeholder它的作用是预先定义好格式,以便后期数据进行处理。如图所示:
这里就定义了一个三维数组,并且数据类型为float32,第三维为3(楼主是做CV的,这里3指的是RGB三通道),定义好以后,可以按照正常的业务流程,依次定义后面需要的中间变量和最终结果:
然后往里面塞数据:
开启session:
最后按照预先定义的数据格式和流程往session里塞数据:
这里的feed_dict就是预先定义的数据格式,放在{}里,不同数据之间用逗号分隔,run后面的第一个参数是输出的数据类型,这里还有一个坑:最后接收数据的变量名不要和第一个参数的相同,不然在第二次进入循环的时候会报错,切记!!!
楼主踩了好几个坑:
刚开始不知道,在刚定义好占位符的时候我的数据就已经塞进去了,然后就是报错:run里面不能是个tensor,我还一直没有发现,傻傻的跑了一天,各种查啊
session.run会根据你定义的数据格式的流程去运行,也就是说只要定义好了以后,它就会自动去找你的数据,楼主输出的数据需要四个参数,然后就各种试参数都是啥,然后各种报错,报完错又各种找资料,最后浪费了大量时间也没弄好
楼主一直没有意识到初始化是个啥、怎么运行以及它的重要性,以至于中间一度程序啥都没问题了初始化报错,而且我根本没有看出来是报的啥错(也可能是一直看代码迷糊了),反正就是各种找啊,最后才发现我好像没有初始化
至于为什么这么重要,我是没有搞清楚,只知道不加它就出错
最后分享一下比较完整的代码片段供大家参考:
def main():
tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
img_holder = tf.placeholder(tf.float32, shape=[None, None, 3])
hei = tf.placeholder(tf.int32)
wid = tf.placeholder(tf.int32)
img = tf.expand_dims(img_holder, 0)
img_v = tf.reduce_max(img, axis=-1, keepdims=True)
img_v = close_op(img_v)
img_i, img_r = rdnet(img_v, img, hei, wid)
img_i = tf.clip_by_value(img_i[0], 0, 1)
img_r = tf.clip_by_value(img_r[0], 0, 1)
print('Loading...')
ckpt = tf.train.latest_checkpoint(rd_dir)
rd_vars = tl.layers.get_variables_with_name('retinex', printable=False)
rd_saver = tf.train.Saver(rd_vars)
rd_saver.restore(sess, ckpt)
img_files = os.listdir(in_dir)
img_num = len(img_files)
avg_time = 0
for img_file in img_files:
in_img = Image.open(in_dir + img_file).convert("RGB")
assert in_img is not None
w = in_img.size[0]
h = in_img.size[1]
in_img = np.array(in_img) / 255
img_i1, img_r1 = sess.run([img_i, img_r], feed_dict={img_holder: in_img, hei: h, wid: w})
out_name = img_file.split('.', 1)[0] + '.jpg'
img_i1 = img_i1[:, :, 0]
img_i1 = Image.fromarray(np.uint8(img_i1 * 255))
img_i1.save(out_dir + out_name)