tf.image.convert_image_dtype(image, dtype)踩坑注意

前几天把网上找的一个检测抓取框的代码改成ROS node发现结果一直很离谱,抓取框坐标比图片还大,今天突然想起来又看了看代码,发现了问题出在tf.image.convert_image_dtype(image, dtype)这个函数在进行类型转换时自动的scale上。
这个是我一开始写的用来读取png或者jpg图片,然后输入到网络里的代码:

def image_input(image_path):
    height =224
    width = 224
    image = tf.image.decode_png(tf.read_file(image_path), channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize_images(image, [height,width])
    return image

这里tf.image.decode_png 得到的是uint8格式,范围在0-255之间,经过convert_image_dtype 就会被转换为区间在0-1之间的float32格式,网络输出的结果也没毛病。

在ROS node里,因为要从我的相机那里订阅图片的message,所以用了cv3_bridge先把ros msg转换成opencv的图片格式,再转成numpy.array输入到placeholder里。我犯的错误就是把placeholder的数据类型写成了float32。

self.images_original = tf.placeholder(tf.uint8, [None,None,3], name="input_images")
self.image = self.image_process(self.images_original)
...

    def image_process(self, image):
        height =224
        width = 224
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.resize_images(image, [height,width])
        return image

这样的话,0-255区间的uint8格式的numpy.array被直接转成了float32格式,而tf.image.convert_image_dtype(image, dtype) 对于输入为float类型的数认为已经处在0-1区间,就不会进行scale了,所以最后输入到网络中的图像并不是0-1区间的值。

作者:记忆力衰退来写博客的李同学
来源:CSDN
原文:https://blog.csdn.net/Cyril__Li/article/details/78968425
版权声明:本文为博主原创文章,转载请附上博文链接!

你可能感兴趣的:(tf.image.convert_image_dtype(image, dtype)踩坑注意)