tensorflow的placeholder踩坑

问题描述

这几天看论文、跑实验,然后就被程序中的tf.placeholder难住了,整了好几天,各种报错,各路大神的分享都看了还是百思不得其解,终于搞了几天能正常运行了,记录一下,以免大家跟我一样浪费很多时间,有不对的地方欢迎大家指正!

正文

tensorflow现在是深度学习必备的框架之一,有着强大的功能和完备的社区服务,一直深受广大深度学习爱好者的追捧。可是由于tensorflow1.x和2.x互相之间不兼容甚至差一个小版本都不兼容的现象的存在,再加上文档混乱,经常一个方法可以在不同的包里出现,弄得大家是怨声载道、苦不堪言,本文就记录下楼主遇到的问题。
tf.placeholder它的作用是预先定义好格式,以便后期数据进行处理。如图所示:预定义图片数据格式
这里就定义了一个三维数组,并且数据类型为float32,第三维为3(楼主是做CV的,这里3指的是RGB三通道),定义好以后,可以按照正常的业务流程,依次定义后面需要的中间变量和最终结果:
预定义变量数据类型
然后往里面塞数据:
加载数据开启session:
tensorflow的placeholder踩坑_第1张图片
最后按照预先定义的数据格式和流程往session里塞数据:
运行session
这里的feed_dict就是预先定义的数据格式,放在{}里,不同数据之间用逗号分隔,run后面的第一个参数是输出的数据类型,这里还有一个坑:最后接收数据的变量名不要和第一个参数的相同,不然在第二次进入循环的时候会报错,切记!!!
楼主踩了好几个坑:

定义了占位符之后就塞数据

刚开始不知道,在刚定义好占位符的时候我的数据就已经塞进去了,然后就是报错:run里面不能是个tensor,我还一直没有发现,傻傻的跑了一天,各种查啊

没有意识到session.run它的运行流程

session.run会根据你定义的数据格式的流程去运行,也就是说只要定义好了以后,它就会自动去找你的数据,楼主输出的数据需要四个参数,然后就各种试参数都是啥,然后各种报错,报完错又各种找资料,最后浪费了大量时间也没弄好

tf.session.run需要初始化

楼主一直没有意识到初始化是个啥、怎么运行以及它的重要性,以至于中间一度程序啥都没问题了初始化报错,而且我根本没有看出来是报的啥错(也可能是一直看代码迷糊了),反正就是各种找啊,最后才发现我好像没有初始化

tf.clip_by_value很重要

至于为什么这么重要,我是没有搞清楚,只知道不加它就出错

最后分享一下比较完整的代码片段供大家参考:

def main():
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    img_holder = tf.placeholder(tf.float32, shape=[None, None, 3])
    hei = tf.placeholder(tf.int32)
    wid = tf.placeholder(tf.int32)
    img = tf.expand_dims(img_holder, 0)
    img_v = tf.reduce_max(img, axis=-1, keepdims=True)
    img_v = close_op(img_v)

  
    img_i, img_r = rdnet(img_v, img, hei, wid)
    img_i = tf.clip_by_value(img_i[0], 0, 1)
    img_r = tf.clip_by_value(img_r[0], 0, 1)
    print('Loading...')
    ckpt = tf.train.latest_checkpoint(rd_dir)
    rd_vars = tl.layers.get_variables_with_name('retinex', printable=False)
    rd_saver = tf.train.Saver(rd_vars)
    rd_saver.restore(sess, ckpt)

    img_files = os.listdir(in_dir)
    img_num = len(img_files)
    avg_time = 0

    for img_file in img_files:
        in_img = Image.open(in_dir + img_file).convert("RGB")
        assert in_img is not None
        w = in_img.size[0]
        h = in_img.size[1]
        in_img = np.array(in_img) / 255
        img_i1, img_r1 = sess.run([img_i, img_r], feed_dict={img_holder: in_img, hei: h, wid: w})
        out_name = img_file.split('.', 1)[0] + '.jpg'
        img_i1 = img_i1[:, :, 0]
        img_i1 = Image.fromarray(np.uint8(img_i1 * 255))
        img_i1.save(out_dir + out_name)

你可能感兴趣的:(python,tensorflow,python)