【python】tensorflow框架下sess.run()读取数据卡住---解决方案

最近在tensorflow框架下调试代码时遇到sess.run()读取数据卡住的情况,搜索尝试了许多方法终于找到解决方案,希望遇到同样问题的小伙伴能尽快解决问题。

问题描述:img, label = sess.run([image, labels])

程序运行到改行进行数据读取时,不报错,但是程序卡在这条语句,无法往下执行

解决方案:在with tf.session() as sess: 部分加如下语句

coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)

详细解释:

image—图片保存路径列表(string类型);
labels—标签列表(int32类型);


由于数据量太大,因此在对图片数据我们通常采用批量处理的方式,也即一次处理batch_size(最小可以取为1)张图片,减轻CPU的负担,保证程序的正常运行。

而在此过程当中,数据经历的两次变换(现将原类型string / int32 均记为 numpy ):

numpy — tensor — numpy

所述的问题就是在第二次变换(由tensor转为numpy)时出现的


(1)由numpy转为tensor代码

def get_batch(image, label, image_W, image_H, batch_size, capacity):
    '''
    Args:
        image: list type
        label: list type
        image_W: image width
        image_H: image height
        batch_size: batch size
        capacity: the maximum elements in queue
    Returns:
        image_batch: 4D tensor [batch_size, width, height, 3], dtype=tf.float32
        label_batch: 1D tensor [batch_size], dtype=tf.int32
    '''
    # image_W, image_H, :设置好固定的图像高度和宽度
    # 设置batch_size:每个batch要放多少张图片
    # capacity:一个队列最大多少

    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])  # read img from a queue

    # step2:将图像解码,不同类型的图像不能混在一起,要么只用jpeg,要么只用png等。
    image = tf.image.decode_png(image_contents, channels=3)
    # image = tf.image.resize_images(image,(image_W,image_H),0)

    # step3:数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮。
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity)

    # 重新排列label,行数为[batch_size]
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, label_batch

由于在进行网络训练时,喂入feed_dict的数据类型是已经确定了的(这就是占位符placeholder做的事),
所以batch_images, batch_labels = sess.run([train_batch, train_label_batch]) 的目的就是将tensor重新读取为其原始类型的数据,得到的(batch_images, batch_labels)可以直接喂入feed_dict。


(2)tensor转为numpy代码

	train, train_label = read_data.get_files(train_dir)

    train_batch, train_label_batch = read_data.get_batch(train,
                                                             train_label,
                                                             IMG_W,
                                                             IMG_H,
                                                             BATCH_SIZE,
                                                             CAPACITY)
    with tf.Session(config=config) as sess:
    # with tf.Session(config=tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)) as sess:
        # init = tf.initialize_all_variables()
        init = tf.global_variables_initializer()
        sess.run(init)
     	coord = tf.train.Coordinator()
		thread = tf.train.start_queue_runners(sess, coord)
  
        writer = tf.summary.FileWriter(save_dir, sess.graph)
        sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

        print("start_sess")

        for step in range(MaxStep):

            batch_images, batch_labels = sess.run([train_batch, train_label_batch])
     
            batch_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
          

            # 更新 D 的参数
            _, summary_str = sess.run([d_optim, d_sum],
                                      feed_dict={images: batch_images,
                                                 z: batch_z,
                                                 y: batch_labels})

很好理解了,[train_batch, train_label_batch]是tensor表示,image中存储的是所有图片路径,这里用到了队列结构来完成。

  • 在运行块中要先start这个队列,让它工作起来,数据才能正常读取
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)

这也是这个问题的核心所在。

你可能感兴趣的:(【python】tensorflow框架下sess.run()读取数据卡住---解决方案)