自定义制作python版本的CIFAR数据集

自定义制作python版本的CIFAR数据集

CIFAR-10/CIFAR-100数据集

1、准备图像

(以制作小数据集为例,便于理解)

这里自定义制作的数据集只包含2个类:dog,parrot,每个类有121张图像。数据集共有242张图像,测试图像30张,训练图像212张。将数据集分为1个测试批次和2个训练批次。测试批次包含每个类的15张图像。每个训练批次包含106张图像,但是其中属于各个类的图像数量随机(即不同训练批次中相同类的图像数量不一定相等)。

图片的命名规则为 “label_类别名_编号.jpg”,这里规定,label为0时类别名为dog,label为1时类别名为parrot。

 

2、数据集理解

首先调整所有图像的大小,这里调整为256×256(img_dim=256)。

def img_resize(img_dir, img_dim):
    '''Args:
        img_dir: 该批次图像文件夹路径
        img_dim: 调整后的大小
    '''
    img_resized_dir = img_dir + '_resize'  # 调整后图像的保存路径
    os.makedirs(img_resized_dir, exist_ok=True)
    img_list = os.listdir(img_dir)
    for img_name in img_list:
        img_path = os.path.join(img_dir, img_name)
        img = Image.open(img_path)
        x_new = img_dim
        y_new = img_dim
        out = img.resize((x_new, y_new), Image.ANTIALIAS)
        out.save('{}/{}.jpg'.format(img_resized_dir, img_name))
    print('Images in {} are resized as {}×{}.\n'.format(img_dir, img_dim, img_dim))
    return img_resized_dir

cifar数据集中每个批次文件包含一个字典,字典内有4个键,分别是:'batch_label','data','filenames','labels'。可以使用以下代码查看。

  • 'batch_label' = 当前批次的名字。
  • 'data' = 形状为(106,256×256×3)的uint8的numpy数组。数组的每行存储一张图像的数字信息,按通道顺序为红、绿、蓝存储,每个通道按行优先。
  • 'filenames' = 一个包含该批次所有图像名称的列表,长度为106。
  • 'labels' = 一个取值为0、1的列表,长度为106。索引i处的数字为第i个图像的标签。标签0表示dog,标签1表示parrot。
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='latin-1')
    return dict


cc = unpickle("images_cifar/data_batch_1")
print(cc.keys())
print(cc['filenames'])

 

3、完整运行代码

自定义的图像批次保存在images文件夹中,生成的cifar数据集文件保存在images_cifar文件夹中。

自定义制作python版本的CIFAR数据集_第1张图片自定义制作python版本的CIFAR数据集_第2张图片

from PIL import Image
from numpy import *

import numpy as np
import os
import pickle


def img_resize(img_dir, img_dim):
    '''Args:
        img_dir: 该批次图像文件夹路径
        img_dim: 调整后的大小
    '''
    img_resized_dir = img_dir + '_resize'  # 调整后图像的保存路径
    os.makedirs(img_resized_dir, exist_ok=True)
    img_list = os.listdir(img_dir)
    for img_name in img_list:
        img_path = os.path.join(img_dir, img_name)
        img = Image.open(img_path)
        x_new = img_dim
        y_new = img_dim
        out = img.resize((x_new, y_new), Image.ANTIALIAS)
        out.save('{}/{}.jpg'.format(img_resized_dir, img_name))
    print('Images in {} are resized as {}×{}.\n'.format(img_dir, img_dim, img_dim))
    return img_resized_dir


def get_filenames_and_labels(img_resized_dir):
    filenames = []
    labels = []
    img_list = os.listdir(img_resized_dir)
    for img_name in img_list:
        filenames.append(img_name.encode('utf-8'))
        img_name_str = img_name.split('.')[0]
        label = int(img_name_str.split('_')[0])
        labels.append(label)
    return filenames, labels


def get_img_data(img_resized_dir):
    imgs = []
    # count = 0
    img_list = os.listdir(img_resized_dir)
    for img_name in img_list:
        img_path = os.path.join(img_resized_dir, img_name)
        img = Image.open(img_path)
        r, g, b = img.split()
        r_array = np.array(r, dtype=np.uint8).flatten()
        g_array = np.array(g, dtype=np.uint8).flatten()
        b_array = np.array(b, dtype=np.uint8).flatten()
        img_array = concatenate((r_array, g_array, b_array))
        # print(img_array.shape)
        imgs.append(img_array)
        # count += 1
        # print('Get {} images of {}'.format(count, img_resized_dir))
    imgs = np.array(imgs, dtype=np.uint8)

    return imgs


if __name__ == '__main__':
    img_dir_names = ['test_batch']  # 1个测试批次
    num_data_batch = 2  # 2个训练批次
    for i in range(1, num_data_batch + 1):
        img_dir_names.append('data_batch_' + str(i))

    count = 0
    for img_dir_name in img_dir_names:
        img_dir = 'images/' + img_dir_name
        filepath = 'images_cifar/' + img_dir_name
        img_resized_dir = img_resize(img_dir, img_dim=256)

        data_batch = {}

        if 'test' in filepath:
            data_batch['batch_label'.encode('utf-8')] = 'testing batch 1 of 1'.encode('utf-8')
        else:
            count += 1
            batch_label = 'training batch ' + str(count) + ' of ' + str(num_data_batch)
            data_batch['batch_label'.encode('utf-8')] = batch_label.encode('utf-8')

        filenames, labels = get_filenames_and_labels(img_resized_dir)
        data = get_img_data(img_resized_dir)

        data_batch['filenames'.encode('utf-8')] = filenames
        data_batch['labels'.encode('utf-8')] = labels
        data_batch['data'.encode('utf-8')] = data

        with open(filepath, 'wb') as f:
            pickle.dump(data_batch, f)

    img_classes = 'images_cifar/batches.meta'
    label_names = {0: 'dog', 1: 'parrot'}
    with open(img_classes, 'wb') as f:
        pickle.dump(label_names, f)


# def unpickle(file):
#     import pickle
#     with open(file, 'rb') as fo:
#         dict = pickle.load(fo, encoding='latin-1')
#     return dict
#
#
# cc = unpickle("C:/Users/lenovo/.keras/datasets/cifar-10-batches-py/data_batch_1")
# print(cc.keys())
# print(cc['filenames'])

 

你可能感兴趣的:(其他)