将当前目录下的 cifar-10-batches-py 数据集转换成 train 图片和 test 图片，并同时生成 train.txt 和 test.txt

from __future__ import print_function
import matplotlib.pyplot as pyplot
import PIL.Image as Image
import pickle
import numpy as np
import random
import os,re

def unpickle(file):
    """Load one CIFAR-10 python-format batch file and return (labels, data).

    The batch is a pickled dict with ``bytes`` keys (hence
    ``encoding='bytes'``):

    * ``b'data'``: a 10000x3072 numpy array of uint8s. Each row stores one
      32x32 colour image. The first 1024 entries are the red channel values,
      the next 1024 the green, and the final 1024 the blue. Each channel is
      stored in row-major order, so the first 32 entries of a row are the red
      values of the image's first pixel row.
    * ``b'labels'``: a list of 10000 numbers in the range 0-9. The number at
      index i is the label of the i-th row of ``b'data'``.

    Args:
        file: path to the batch file.

    Returns:
        A ``(labels, data)`` tuple taken from ``b'labels'`` and ``b'data'``.

    Raises:
        KeyError: if the unpickled dict lacks ``b'labels'`` or ``b'data'``.

    Security: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted dataset files.
    """
    # ``with`` guarantees the file is closed even if unpickling fails.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch[b'labels'], batch[b'data']

def saveImg(save_path, class_index, labels=None, source_arr=None):
    """Save every image of one class as ``0.png``, ``1.png``, ... in a folder.

    Args:
        save_path: destination directory, e.g. ``<root>/train/airplane/``.
            It is created if missing. When ``labels``/``source_arr`` are not
            both given, the name of its parent directory selects the data
            source. ``'train'`` uses the module-level ``train_labels`` /
            ``train_array``. Anything else uses ``test_labels`` /
            ``test_array``.
        class_index: label value whose images are written.
        labels: optional per-row labels. It overrides the global lookup only
            when ``source_arr`` is also given.
        source_arr: optional array with one image per row, laid out as in
            a CIFAR-10 batch (3072 values: 1024 red, then 1024 green, then
            1024 blue, each a row-major 32x32 plane).

    Files are numbered in the order the matching rows appear in
    ``source_arr``.
    """
    if labels is None or source_arr is None:
        # normpath drops a trailing separator, so the parent of the class
        # folder is found whether or not save_path ends with '/'.
        split = os.path.basename(os.path.dirname(os.path.normpath(save_path)))
        if split == 'train':
            labels, source_arr = train_labels, train_array
        else:
            labels, source_arr = test_labels, test_array

    # Select all matching rows at once with a boolean mask. Growing an array
    # with np.concatenate per row was quadratic and upcast to float64.
    mask = np.asarray(labels) == class_index
    images = np.asarray(source_arr)[mask].reshape(-1, 3, 32, 32)
    # CIFAR-10 data is documented as uint8, so this is normally a no-op.
    images = images.astype(np.uint8, copy=False)

    os.makedirs(save_path, exist_ok=True)

    for index, channels in enumerate(images):
        # Each channel is a 32x32 uint8 plane, which Pillow loads as an
        # 8-bit ('L') image. Merging the three planes gives the RGB picture.
        r, g, b = (Image.fromarray(plane) for plane in channels)
        image = Image.merge("RGB", (r, g, b))
        image.save(os.path.join(save_path, str(index) + ".png"), 'png')

# Create the <split>.txt index file for a train/ or test/ image folder.
def initTxt(train_or_test_set_path, output_dir=None):
    """Write a shuffled ``<split>.txt`` listing every image in a split folder.

    ``train_or_test_set_path`` is expected to look like ``<root>/train/`` and
    to contain one sub-folder per class. Each output line has the form
    ``"<class_folder>/<file_name> <class_index>"``. The class index is the
    folder's position among the alphabetically sorted class folders.

    NOTE(review): because indices come from the sorted folder list, they match
    the label numbers used by ``saveImg`` only when the present folders sort
    in label order. This is presumably true for the full set of ten CIFAR-10
    class names, but the mapping differs if only a subset of classes exists.
    Verify against ``batches.meta``.

    Args:
        train_or_test_set_path: split directory. Its last path component
            (for example ``train`` or ``test``) names the output file.
        output_dir: directory that receives the ``.txt`` file. Defaults to
            the module-level ``my_path``.
    """
    if output_dir is None:
        output_dir = my_path

    # normpath strips a trailing separator, so basename is the split name.
    split_dir = os.path.normpath(train_or_test_set_path)
    train_or_test = os.path.basename(split_dir)

    # Sort so the folder -> index mapping is deterministic, since os.listdir
    # order is arbitrary. Only sub-directories count as classes.
    class_folders = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )

    lines = []
    # Each folder gets its own index. Previously every folder was labelled 0.
    for class_index, folder in enumerate(class_folders):
        for file_name in sorted(os.listdir(os.path.join(split_dir, folder))):
            lines.append(folder + "/" + file_name + ' ' + str(class_index))

    # Shuffle the order, then write the .txt file.
    random.shuffle(lines)
    with open(os.path.join(output_dir, train_or_test + '.txt'), mode='w') as f:
        f.writelines(line + '\n' for line in lines)

# Root directory: the current working directory, which is not necessarily the
# folder containing this script. cifar-10-batches-py must sit directly inside
# it, and the train/, test/ folders and .txt files are written under it.
my_path = os.getcwd()

if __name__ == "__main__":
    # Training data: only the third training batch (data_batch_3) is converted.
    train_cifar_10_path = my_path + "/cifar-10-batches-py/data_batch_3"
    train_labels, train_array = unpickle(train_cifar_10_path)
    # Test-set path.
    test_cifar_10_path = my_path + "/cifar-10-batches-py/test_batch"
    test_labels, test_array = unpickle(test_cifar_10_path)

    # Export only the class with label 0 into the "airplane" folders. The
    # trailing '/' keeps the paths valid for path-splitting helpers.
    saveImg(my_path + "/train/airplane/", 0)
    saveImg(my_path + "/test/airplane/", 0)

    initTxt(my_path + '/train/')
    initTxt(my_path + '/test/')

你可能感兴趣的:(caffe)