解决服务器上tensorflow协同的时候,并没有真正的协同,而是将数据打乱的问题

即使queue_runner中shuffle=False 在多线程的循环的时候还是谁快谁先出

感觉是异步的。

 

所以只好fetch,将文件名字也抛出来,来解决文件名字对不上的问题。

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input input_sample.jpg \
                       --output output_sample.jpg \
                       --image_size 128
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils
from glob import glob

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('model', '', 'model path (.pb)')
tf.flags.DEFINE_string('input_dir', 'samples/input_n2v', 'input image path (.jpg)')
tf.flags.DEFINE_string('output_dir', 'samples/output_n2v', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', '128', 'image size, default: 128')

os.environ["CUDA_VISIBLE_DEVICES"] = '1'
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.allow_growth = True

def get_all_org_files(file_dir):
    """Recursively collect every .bmp file under *file_dir*.

    Args:
        file_dir: root directory to walk; a non-existent path yields [].

    Returns:
        A list of full paths for all files whose extension is exactly
        '.bmp' (case-sensitive).  os.path.splitext splits a path into
        (stem, extension), so [1] is the extension including the dot.
    """
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(file_dir)
        for name in names
        if os.path.splitext(name)[1] == '.bmp'
    ]

def inference():
  """Run a frozen CycleGAN generator over every .bmp in FLAGS.input_dir.

  Loads the GraphDef from FLAGS.model, pipes each input image through it
  via a TF1 input queue, and writes the generated image bytes into
  FLAGS.output_dir under the same file name.  The file name is fetched in
  the SAME sess.run call as the output tensor so outputs can be matched
  to their inputs even though the queue-runner threads are asynchronous
  (see the note at the top of the file).
  """
  graph = tf.Graph()
  # old_img_file_path = get_all_org_files(FLAGS.input_dir)

  # Pick the input shape / channel count / input tensor name based on
  # which exported model is being used (grayscale n2v vs 3-channel v2n).
  model_name = FLAGS.model.split('/')[-1]
  if model_name == 'n2v.pb':
    input_shape = [FLAGS.image_size, FLAGS.image_size, 1]
    input_channel = 1
    input_name = 'input_image_x'
  else:
    input_shape = [FLAGS.image_size, FLAGS.image_size, 3]
    input_channel = 3
    input_name = 'input_image_y'

  with graph.as_default():
      # Find all input files with a wildcard; every input here is a .bmp,
      # so the decoder is fixed to decode_bmp.
      paths = glob("{}/*.{}".format(FLAGS.input_dir, 'bmp'))

      tf_decode = tf.image.decode_bmp



      # `paths` now fixes the full set of input images.
      # shuffle must be False, otherwise the enqueue order would not
      # match the order of `paths`.
      filename_queue = tf.train.string_input_producer(list(paths), shuffle=False)

      # Reader that returns (filename, raw file bytes) per dequeued element.
      reader = tf.WholeFileReader()

      # Graph construction only — the actual file read happens inside
      # sess.run below.
      filename, data = reader.read(filename_queue)

      # Decode the raw bytes into an image tensor (runs in the session).
      # channels=0 (use the bmp's own channel count) did not work here.
      image = tf_decode(data, channels=input_channel)

      # Declare the static shape expected by the imported graph.
      image.set_shape(input_shape)

      image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # Dead code kept for reference: a placeholder-based single-image path.
    # Feeding a Tensor as the file name this way does not work.
    # input = tf.placeholder(dtype=tf.string)
    # with tf.gfile.FastGFile(input, 'rb') as f:
    #   image_data = f.read()
    #   input_image = tf.image.decode_jpeg(image_data, channels=1)
    #   input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
    #   input_image = utils.convert2float(input_image)
    #   input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 1])

      with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_file.read())
      # NOTE(review): `input_name` must match the tensor name used when the
      # graph was frozen by export_graph — confirm if that script changes.
      [output_image] = tf.import_graph_def(graph_def,
                              input_map={input_name: image},
                              return_elements=['output_image:0'],
                              name='output')

  with tf.Session(graph=graph,config=config) as sess:
    coord = tf.train.Coordinator()  # coordinates the queue-runner threads
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)  # start the queue threads

    for i in range(len(paths)):

        # Fetch the generated bytes AND the source file name in one run()
        # so both are guaranteed to come from the same dequeued element.
        generated, it_path = sess.run((output_image, filename))
        # import pdb;pdb.set_trace()
        it_name = it_path.decode('utf8').split('/')[-1]

        with open('{}/{}'.format(FLAGS.output_dir,it_name), 'wb') as f:
              f.write(generated)
    # Earlier version: output_image.eval() with sequential numbering —
    # replaced because outputs could not be matched to their input files.
    # for i in range(len(paths)):
        # generated = output_image.eval()
        # with open('{}/{:05d}.jpg'.format(FLAGS.output_dir,i+1), 'wb') as f:
              # f.write(generated)

    coord.request_stop()  # stop all queue-runner threads
    coord.join(threads)

def main(unused_argv):
  """Entry point invoked by tf.app.run() after flag parsing."""
  del unused_argv  # unused; flags are read via FLAGS
  inference()

# tf.app.run() parses the command-line flags, then calls main().
if __name__ == '__main__':
  tf.app.run()

 

你可能感兴趣的:(tf)