Tensorflow(一)训练自己的数据—制作自己的数据集

1. 采集图片

以猫、狗各采集20张图片为例。

图片下载地址:已上传,审核中

以下为存放的路径:(使用ubuntu自带screenshot进行截图)

Tensorflow(一)训练自己的数据—制作自己的数据集_第1张图片


2. 制作 TFRecord 格式的数据集
简易版本

import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


# Build a TFRecord file from a directory that holds one sub-directory per
# class (<cwd>/dog/*, <cwd>/cat/*). Each record stores the integer class
# label and the raw bytes of the image resized to 128x128 RGB.
cwd = '/home/cae_b8/Documents/project1/cats_vs_dog/data/train/'
# Use an ordered sequence, not a set: iterating a set of strings has no
# guaranteed order, so enumerate() could assign different labels to the same
# class on different runs. With a list the mapping is fixed: dog -> 0, cat -> 1.
classes = ['dog', 'cat']  # two hand-picked classes


writer = tf.python_io.TFRecordWriter("dog_vs_cat_train.tfrecords")  # output file

try:
    for index, name in enumerate(classes):
        class_path = cwd + name + '/'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # full path of one image
            # Force 3-channel RGB so every record holds exactly 128*128*3
            # bytes; grayscale/RGBA/palette images would otherwise produce a
            # different byte count and break a later reshape to [128, 128, 3].
            img = Image.open(img_path).convert('RGB')
            img = img.resize((128, 128))
            img_raw = img.tobytes()  # raw pixel bytes, no image header
            # Wrap the label and the image bytes in one Example message.
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())  # serialize and append
finally:
    # Close the writer even if an image fails to load, so the records that
    # were already written are flushed to disk.
    writer.close()

另一版本(同时生成 txt 标签文件)

import os
import numpy as np
import math
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

# Scan a flat directory of images whose file names contain "cat" or "dog",
# write a "<path> <label>" listing to a text file, then shuffle the samples
# and store up to NUM of them in a TFRecord file (128x128 RGB raw bytes plus
# an integer label: cat -> 0, dog -> 1).
file_dir = '/home/cae_b8/Documents/project2/data_set/train/'
txt_path = '/home/cae_b8/Documents/project2/data_set/train/data.txt'
writer = tf.python_io.TFRecordWriter("/home/cae_b8/Documents/project2/data_set/train/train.tfrecords")

dogs = []
label_dogs = []
cats = []
label_cats = []
NUM = 38  # upper bound on the number of records written

# The context manager guarantees the listing is flushed and closed even if
# the directory scan raises.
with open(txt_path, 'w') as fw:
    for file in os.listdir(file_dir):
        # Only the part before the first '.' is inspected, so extensions and
        # numeric suffixes do not influence the class decision.
        name = file.split(sep='.')
        if 'cat' in name[0]:
            cats.append(file_dir + file)
            label_cats.append(0)
            fw.write(file_dir + file + ' ' + '0\n')
        elif 'dog' in name[0]:
            dogs.append(file_dir + file)
            label_dogs.append(1)
            fw.write(file_dir + file + ' ' + '1\n')
print('generate txt file successfully')

print("There are %d dogs\nThere are %d cats\n" % (len(dogs), len(cats)), end="")

# Shuffle paths and labels together through one shared permutation. This
# keeps the labels as ints instead of round-tripping them through a NumPy
# string array.
image_list = cats + dogs
label_list = label_cats + label_dogs
order = np.random.permutation(len(image_list))
image_list = [image_list[i] for i in order]
label_list = [int(label_list[i]) for i in order]

# Never index past the samples that were actually found; the fixed NUM alone
# raises IndexError when the directory holds fewer matching images.
num_records = min(NUM, len(image_list))

try:
    for j in range(num_records):
        image_name = image_list[j]
        label_name = label_list[j]
        # Force 3-channel RGB so each record holds exactly 128*128*3 bytes.
        image = Image.open(image_name).convert('RGB')
        image = image.resize((128, 128))
        img_raw = image.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_name])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))

        writer.write(example.SerializeToString())
        print('Creating train record in ', j + 1)
finally:
    # Close the writer even when an image fails, so earlier records persist.
    writer.close()
print("Create train_record successful!")

3. 读取 TFRecord 文件,并将其中的图像还原保存为图片

import os
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


# Read records back from the TFRecord file with the TF 1.x queue-based input
# pipeline and save each decoded image as a JPEG whose file name carries the
# record index and its label.
show_image = '/home/cae_b8/Documents/project2/data_set/train/show_image/'
file_name = '/home/cae_b8/Documents/project2/data_set/train/train.tfrecords'
filename_queue = tf.train.string_input_producer([file_name])  # queue of input file names

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # returns (key, serialized record)
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
)  # extract the label and the raw image bytes

image = tf.decode_raw(features['img_raw'], tf.uint8)
# The writer must have stored exactly 128*128*3 bytes per image for this
# reshape to succeed.
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(features['label'], tf.int32)

# Make sure the output directory exists; img.save() does not create it.
os.makedirs(show_image, exist_ok=True)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    coord = tf.train.Coordinator()  # coordinates the queue-runner threads
    threads = tf.train.start_queue_runners(coord=coord)  # start filling the queues
    try:
        for i in range(38):
            example, l = sess.run([image, label])  # fetch one image and its label
            img = Image.fromarray(example, 'RGB')
            img.save(show_image + str(i) + '_''Label_' + str(l) + '.jpg')

            print('Creating image in ', i + 1)
    finally:
        # Always stop and join the reader threads, even if an iteration fails,
        # so the session can shut down cleanly.
        coord.request_stop()
        coord.join(threads)

Tensorflow(一)训练自己的数据—制作自己的数据集_第2张图片


 

你可能感兴趣的:(tensorflow)