data_batch是如何实现的?

1.把图片地址列表、标签列表读入队列

image = tf.train.slice_input_producer([train_list],num_epochs=1,shuffle=False)

train_list可以是任意多个list,可以组合传也可以单独传,num_epochs用来控制整个数据集遍历几次。

2.读取得到的image[0],做图片的预处理。

#2.读取图片并解码
image_train = tf.read_file(image[0])
image_train = tf.image.decode_jpeg(image_train, channels=3)
image_train = tf.image.resize_images(image_train, [208,208])
image_train = tf.cast(image_train, tf.float32) / 255.

3.把图片载入到get_batch中

#3.合并成一个batch,可以传入图片的tensor_list,也可以传入地址、标签的list
img_batch,img_dir = tf.train.batch([image_train,image[0]],
                                           batch_size=10, 
                                           capacity=100,#队列长度
                                           num_threads=2,#线程个数
                                           allow_smaller_final_batch=True)#允许不足的小批次

4.把上面返回的batch_tensor喂进网络里面,得到结果tensor

prediction=inference(img_batch,2)
prediction=tf.nn.softmax(prediction)

5.开启sess,在sess里初始化local_varibel(num_epoch需要)),并开启线程

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())#注意这里要初始化local为了num_epochs

    saver.restore(sess, tf.train.latest_checkpoint('./logs/'))# 恢复权重,run出结果

    coord = tf.train.Coordinator()#开线程
    thread = tf.train.start_queue_runners(sess, coord)
    i=0

    pres=[]
    dirs=[]
    try:
        while not coord.should_stop():#这里的循环会根据你设置的num_epochs自动结束
            i+=1
            pre,dir = sess.run([prediction,img_dir])
            pre = np.argmax(pre, 1)
            dir=list(dir)#对得到的结果,先转为list再处理
            pre=list(pre)

            dirs=dirs+dir
            pres=pres+pre

    except tf.errors.OutOfRangeError:
        dirs=[str(i) for i in dirs]
        pres=[str(i) for i in pres]

        temp=np.array([dirs,pres])
        temp=temp.transpose()
        np.savetxt('test.csv',temp,fmt='%s',delimiter=',')
    finally:
        coord.request_stop()
    coord.join(thread)

 

读取单张图片,多张图片上传到

 

 

你可能感兴趣的:(data_batch是如何实现的?)