Tensorflow图像预处理,Numpy读取数据踩坑

在Tensorflow读入图片数据的时候,往往会遇到各种数据类型上的细微问题(subtle problem)。今天遇到的是在将图片转换成tfrecord的过程中,读取图片时出的问题。最后竟然发现……错误发生在Numpy对字符串的处理上:原来是为了与C兼容,np.array 会把字符串(string)末尾的'\x00'截掉。制作tfrecord时,要先用tobytes()把图片数据(像素值数组)转换成原始字节串,显示出来就是'\x92\x99\...'这样带十六进制转义的字符串;之后我需要将图片list转换为array来进行后面的乱序处理。正是在这个把list变成ndarray的过程中,出现了问题。


读取图片数据时发现,保存下来的图片数据长度与图片应有的数据维数(227×227×3=154587)不相符:总有一些图像的字节数小于154587。加入assert(len(img) == 154587)语句后运行,控制台报错如下:

Traceback (most recent call last):
  File "/home/mokii/RGB-D/SUN-AlexNet/tmp.py", line 280, in <module>
    img2tfrecords(['/home/mokii/RGB-D/SUNRGBD/statistic/trainrgblist.csv'])
  File "/home/mokii/RGB-D/SUN-AlexNet/tmp.py", line 96, in img2tfrecords
    ndarray2tfrecords(train, '/home/mokii/RGB-D/SUNRGBD/data/train/rgb/trainrgb.tfrecords')
  File "/home/mokii/RGB-D/SUN-AlexNet/tmp.py", line 53, in ndarray2tfrecords
    assert (len(img) == 154587)
AssertionError

代码如下,其实这是用经典方法改编的,其他地方都是对的,main函数包括了tfrecords的制作和debug时对decode出的数据的长度的检验代码,最后发现错误就是出在注释代码处啦:

# -*- coding: utf-8 -*-
import os
import csv
import numpy as np
from PIL import Image
import tensorflow as tf
import sys

# Python 2 idiom: re-import sys, then set the interpreter-wide default string
# encoding to UTF-8.
reload(sys)
sys.setdefaultencoding('utf8')

# TODO: still to be filled in.
# Read by inputs() to derive the minimum number of queued examples.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 300
# Not referenced anywhere in this listing.
Test_num = 0

# Not referenced anywhere in this listing.
file_num = 1
# Side length (pixels) images are resized to before serialization and
# reshaped to after decoding.
image_size = 227


def createcsv(folderlist):
    """Write one '<folder>.csv' index file per image folder.

    For every file found in a folder, a row is written holding the file's
    path ('<folder>/<filename>') and the first three characters of the file
    name followed by a newline.

    Args:
        folderlist: iterable of folder paths; each folder gets its own CSV
            next to it, named '<folder>.csv'.
    """
    for imgfolder in folderlist:
        # open() replaces the Python 2-only file() builtin, and the with-block
        # closes the CSV exactly once per folder, even if a write fails.
        with open(imgfolder + '.csv', 'w+') as csvfile:
            writer = csv.writer(csvfile)
            # The row loop must live inside the folder loop so that every
            # folder (not just the last one) gets its rows written.
            for filename in os.listdir(imgfolder):
                writer.writerow([(imgfolder + '/' + filename), filename[0:3] + "\n"])


def ndarray2tfrecords(ndarray, tfpath):
    """Serialize each row of `ndarray` into a TFRecord file at `tfpath`.

    Each row is indexed as img[0] (stored under 'img_raw' as bytes) and
    img[1] (converted with int() and stored under 'label').

    NOTE: this is the buggy version the article analyses. It is kept
    unchanged on purpose; the debug section below is what raises the
    AssertionError shown in the traceback.
    """
    writer = tf.python_io.TFRecordWriter(tfpath)
    for img in ndarray:
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(img[1])])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img[0]]))
        # BUG (the subject of this article): reading the single element img[0]
        # yields a string whose trailing '\x00' bytes have been dropped.
        # Replacing [img[0]] with [img[0:1].tobytes()] or
        # [img[0:1].tostring()] made the lengths come out right.
        }))
        writer.write(example.SerializeToString())

        # BEGIN TODO: temporary debug check; reads the values back out of the
        # Example that was just built.
        print('Here in ndarray2tfrecords')
        imgs = example.features.feature['img_raw'].bytes_list.value
        labels = example.features.feature['label'].int64_list.value

        # Note: this rebinds the loop variable `img` to the stored byte string.
        img = imgs[0]
        label = labels[0]

        assert (len(imgs) == 1)
        assert (len(labels) == 1)
        print('len(img) =', len(img))
        # 154587 = 227 * 227 * 3, the expected byte count per image.
        assert (len(img) == 154587)

    # End TODO

    writer.close()


def img2tfrecords(file_list):
    """Read the images listed in the CSV files and write train/validation TFRecords.

    Each CSV row supplies an image path (column 0) and a class name
    (column 1). Every image is resized to image_size x image_size,
    serialized with tobytes(), and paired with the index of its class name
    in `categories`.

    NOTE: this is the buggy version the article analyses; the code is kept
    unchanged on purpose.
    """
    array = []

    for i in file_list:
        with open(i, 'rb') as f:
            reader = csv.reader(f)
            # cnt = 0
            for line in reader:
                img_path = line[0]
                img = Image.open(img_path)
                img = img.resize((image_size, image_size))
                img = np.asarray(img)
                # Request a C-ordered uint8 array before serializing.
                img = np.require(img, dtype=np.uint8, requirements='C')
                if img.ndim != 3:
                    # Only reported; the image is still processed below.
                    print 'Error'
                    print img_path
                img_raw = img.tobytes()
                img_class = line[1]
                categories = ['bathroom', 'bedroom', 'classroom', 'computer_room', 'conference_room', 'corridor',
                              'dining_area', 'dining_room'
                    , 'discussion_area', 'furniture_store', 'home_office', 'kitchen', 'lab', 'lecture_theatre',
                              'library', 'living_room'
                    , 'office', 'rest_space', 'study_space']
                # NOTE(review): reuses the outer loop's variable name `i`. If no
                # category matches, `label` keeps its previous value (or is
                # unbound on the very first row).
                for i in xrange(19):
                    if categories[i] == img_class:
                        label = int(i)
                array.append([img_raw, label])

    # Packing the [bytes, label] pairs into one ndarray is what sets up the
    # bug: as the interpreter session later in this article shows, reading a
    # single element of a NumPy string array drops its trailing '\x00' bytes.
    array = np.array(array)
    # NOTE(review): hard-codes 7984 samples in total, 7200 of them for training.
    perm = np.arange(7984)
    np.random.shuffle(perm)
    train = array[perm[:7200]]
    validation = array[perm[7200:]]

    ndarray2tfrecords(train, '/home/mokii/RGB-D/SUNRGBD/data/train/rgb/trainrgb.tfrecords')
    ndarray2tfrecords(validation, '/home/mokii/RGB-D/SUNRGBD/data/validation/rgb/validationrgb.tfrecords')


def inputs(file_list, batch_size):
    """Build the graph ops that read, decode and batch examples from TFRecord files.

    Args:
        file_list: list of TFRecord file paths fed to the filename queue.
        batch_size: forwarded to _generate_image_and_label_batch.

    Returns:
        Whatever _generate_image_and_label_batch returns for the decoded
        image and label tensors.
    """
    filename_queue = tf.train.string_input_producer(file_list)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Everything below has to stay inside the function body: at module level
    # the trailing `return` is a SyntaxError and `serialized_example` is
    # undefined.
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })

    # 'img_raw' holds raw bytes; view them as uint8 and restore the image shape.
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [image_size, image_size, 3])
    label = tf.cast(features['label'], tf.int64)

    min_fraction_of_examples_in_queue = 0.5
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
    return _generate_image_and_label_batch(img, label, min_queue_examples, batch_size)


def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size):
    """Batch a single image/label pair with tf.train.batch.

    This debugging variant always uses batch_size=1, one thread and a
    capacity of 1; the `min_queue_examples` and `batch_size` arguments are
    accepted but not used.

    Returns:
        (images, label_batch) as produced by tf.train.batch.
    """
    batched = tf.train.batch([image, label], batch_size=1, num_threads=1, capacity=1)
    image_batch, label_batch = batched
    return image_batch, label_batch


if __name__ == "__main__":
    # createcsv(folder_list)
    img2tfrecords(['/home/mokii/RGB-D/SUNRGBD/statistic/trainrgblist.csv'])

    print('DONE!')

    # Debug pass: read the training TFRecord file back and verify every record.
    dataset = "/home/mokii/RGB-D/SUNRGBD/data/train/rgb/trainrgb.tfrecords"
    cnt = 0
    for serialized_example in tf.python_io.tf_record_iterator(dataset):
        cnt += 1
        print(cnt)

        # The checks run inside the loop so that each record is verified (not
        # only the last one) and an empty file cannot trigger a NameError on
        # `serialized_example`.
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        imgs = example.features.feature['img_raw'].bytes_list.value
        labels = example.features.feature['label'].int64_list.value

        img = imgs[0]
        label = labels[0]

        assert (len(imgs) == 1)
        assert (len(labels) == 1)
        # 154587 = 227 * 227 * 3 bytes per image.
        assert (len(img) == 154587)
        img = tf.decode_raw(img, tf.uint8)
        img = tf.reshape(img, [image_size, image_size, 3])
        label = tf.cast(label, tf.int64)



经过搜索,最后得出,问题是这样发生的,真是有点无语啦:


>>> a='\x01\x00'
>>> a = np.array([a])
>>> a
array(['\x01'], 
      dtype='|S2')
>>> a.tobytes()
'\x01\x00'
>>> a[0].tobytes()
'\x01'
>>> a[0:1].tobytes()
'\x01\x00'
>>> 




可以像上面代码注释里写的那样改正Bug,但师兄说这样做还不是很稳定。其实如果在编码时我们能够绕过Numpy字符串自动截断末尾'\0'的这个“bug”,那就最好了:下面的版本把图片保持为ndarray、把标签单独存放,直到写入tfrecord时才调用tobytes()。



# -*- coding: utf-8 -*-
import os
import csv
import numpy as np
from PIL import Image
import tensorflow as tf
import sys
# Python 2 idiom: re-import sys, then set the interpreter-wide default string
# encoding to UTF-8.
reload(sys)
sys.setdefaultencoding('utf8')

# TODO: still to be filled in.
# Read by inputs() to derive the minimum number of queued examples.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 300
# Not referenced anywhere in this listing.
Test_num = 0

# Not referenced anywhere in this listing.
file_num = 1
# Side length (pixels) images are resized to before serialization and
# reshaped to after decoding.
image_size = 227


def createcsv(folderlist):
    """Write one '<folder>.csv' index file per image folder.

    For every file found in a folder, a row is written holding the file's
    path ('<folder>/<filename>') and the first three characters of the file
    name followed by a newline.

    Args:
        folderlist: iterable of folder paths; each folder gets its own CSV
            next to it, named '<folder>.csv'.
    """
    for imgfolder in folderlist:
        # open() replaces the Python 2-only file() builtin, and the with-block
        # guarantees the CSV is closed even if listing or writing fails.
        with open(imgfolder + '.csv', 'w+') as csvfile:
            writer = csv.writer(csvfile)
            for filename in os.listdir(imgfolder):
                writer.writerow([(imgfolder + '/' + filename), filename[0:3] + "\n"])


def ndarray2tfrecords(images, labels, tfpath):
    """Write paired images and labels to a TFRecord file at `tfpath`.

    Args:
        images: iterable of array objects; each is serialized with its own
            tobytes() right here, so the bytes never pass through a NumPy
            string array.
        labels: iterable of values convertible with int(), paired with
            `images` via zip().
        tfpath: destination path of the TFRecord file.
    """
    writer = tf.python_io.TFRecordWriter(tfpath)
    try:
        for img, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
            }))
            writer.write(example.SerializeToString())
    finally:
        # Close the writer even if building or writing an example fails, so
        # the file handle is not leaked.
        writer.close()


def img2tfrecords(file_list, num_train=7200):
    """Read the images listed in the CSV files and write train/validation TFRecords.

    Each CSV row supplies an image path (column 0) and a class name
    (column 1). Images are resized to image_size x image_size and kept as
    arrays; labels are stored separately as the index of the class name in
    the category list. The samples are shuffled, the first `num_train` go to
    the training file and the rest to the validation file.

    Args:
        file_list: list of CSV file paths.
        num_train: number of shuffled samples written to the training file
            (default 7200, the split used originally).
    """
    categories = ['bathroom', 'bedroom', 'classroom', 'computer_room', 'conference_room', 'corridor',
                  'dining_area', 'dining_room', 'discussion_area', 'furniture_store', 'home_office',
                  'kitchen', 'lab', 'lecture_theatre', 'library', 'living_room', 'office',
                  'rest_space', 'study_space']
    image_list = []
    label_list = []

    for csv_path in file_list:
        with open(csv_path, 'rb') as f:
            for line in csv.reader(f):
                img_path = line[0]
                # This assignment was commented out, which left `img_class`
                # undefined when the label lookup ran.
                img_class = line[1]
                if img_class not in categories:
                    # Skip instead of silently reusing the previous row's label.
                    print('Unknown class %r, skipping %s' % (img_class, img_path))
                    continue

                img = Image.open(img_path)
                img = img.resize((image_size, image_size))
                img_raw = np.asarray(img)
                if img_raw.ndim != 3:
                    print('Error')
                    print(img_path)
                    # A non-3-D image cannot be stacked with the others and
                    # would not have the expected byte count, so leave it out.
                    continue

                image_list.append(img_raw)
                label_list.append(categories.index(img_class))

    image_list = np.array(image_list)
    label_list = np.array(label_list)

    # Shuffle over the samples actually loaded rather than a hard-coded 7984.
    perm = np.arange(len(image_list))
    np.random.shuffle(perm)
    train_idx = perm[:num_train]
    validation_idx = perm[num_train:]

    ndarray2tfrecords(image_list[train_idx], label_list[train_idx],
                      '/home/mokii/RGB-D/SUNRGBD/data/train/rgb/trainrgb.tfrecords')
    ndarray2tfrecords(image_list[validation_idx], label_list[validation_idx],
                      '/home/mokii/RGB-D/SUNRGBD/data/validation/rgb/validationrgb.tfrecords')


def inputs(file_list,batch_size):
    """Build the graph ops that read, decode and batch examples from TFRecord files.

    Args:
        file_list: list of TFRecord file paths fed to the filename queue.
        batch_size: forwarded to _generate_image_and_label_batch.

    Returns:
        Whatever _generate_image_and_label_batch returns for the decoded
        image and label tensors.
    """
    filename_queue = tf.train.string_input_producer(file_list)
    _, serialized_example = tf.TFRecordReader().read(filename_queue)

    # Both features are scalars: an int64 label and the raw image bytes.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # View the raw bytes as uint8 and restore the image shape.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    image = tf.reshape(pixels, [image_size, image_size, 3])
    label = tf.cast(parsed['label'], tf.int64)

    # Half of the nominal epoch size must be queued before batches are drawn.
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.5)
    return _generate_image_and_label_batch(image, label, min_queue_examples, batch_size)


def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size):
    """Create shuffled batches of images and labels.

    Args:
        image: a single decoded image tensor.
        label: the matching label tensor.
        min_queue_examples: passed as min_after_dequeue and used to size the
            queue capacity.
        batch_size: number of examples per batch.

    Returns:
        (images, labels), where labels is the label batch reshaped to
        [batch_size, -1].
    """
    num_preprocess_threads = 5

    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
    # The alternative batching experiments that followed this return were
    # unreachable and have been removed.
    return images, tf.reshape(label_batch, [batch_size, -1])


if __name__ == "__main__":
    # createcsv(folder_list)
    img2tfrecords(['/home/mokii/RGB-D/SUNRGBD/statistic/trainrgblist.csv'])

    print('DONE!')

    # Debug pass: read the training TFRecord file back and verify every record.
    dataset = "/home/mokii/RGB-D/SUNRGBD/data/train/rgb/trainrgb.tfrecords"
    # dataset = "/home/mokii/RGB-D/SUNRGBD/data/validation/rgb/validationrgb.tfrecords"
    cnt = 0
    for cnt, serialized_example in enumerate(tf.python_io.tf_record_iterator(dataset), 1):
        print(cnt)

        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        feature_map = example.features.feature
        imgs = feature_map['img_raw'].bytes_list.value
        labels = feature_map['label'].int64_list.value

        img, label = imgs[0], labels[0]

        # Exactly one image and one label per record.
        assert(len(imgs) == 1)
        assert(len(labels) == 1)
        print('len(img) =', len(img))
        # 154587 = 227 * 227 * 3 bytes per image.
        assert(len(img) == 154587)

        # Decode the bytes back into an image-shaped uint8 tensor.
        img = tf.reshape(tf.decode_raw(img, tf.uint8), [image_size, image_size, 3])
        label = tf.cast(label, tf.int64)





你可能感兴趣的:(Tensorflow)