将 CIFAR-10 数据集压缩包解析为图片

# -*- coding: utf-8 -*-

import os
import numpy as np
import pickle as pk
from PIL import Image

# Directory holding the unpacked CIFAR-10 python batch files.
# Archive source: http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
cifar10_path = "./cifar_train/cifar-10-batches-py/"
train_batchs = [
    cifar10_path + batch_name
    for batch_name in (
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    )
]
test_batchs = [cifar10_path + "test_batch"]

# Root directory that receives the extracted images.
output_path = "./cifar10/"

# Class names; a numeric label is used as an index into this list.
label_list = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

def reader_cifar10(batchs, phase="train"):
    """Extract pickled CIFAR-10 batches into per-class PNG files and a label list.

    Every image is saved as
    ``<output_path>/<phase>/<class_name>/<class_name>-<label>_<n>.png``, where
    ``n`` is a 1-based running count per class across all given batch files.
    After all batches are processed, ``<output_path>/<phase>_label.txt`` is
    written with one ``"<image path relative to output_path> <label>"`` line
    per image.

    Args:
        batchs: Iterable of paths to pickled batch files. Each file must
            unpickle to a mapping with ``b"data"`` and ``b"labels"`` entries.
        phase: Name of the sub-directory under ``output_path`` and prefix of
            the label file (e.g. ``"train"`` or ``"test"``).

    Returns:
        None. Results are written to disk under ``output_path``.
    """
    list_lines = []
    # Running count of saved images per class; used to number the file names.
    arr_cls_img_num = np.zeros(len(label_list), dtype=np.int32)

    # Create the output root up front so the label file can still be written
    # when ``batchs`` is empty and no class directory gets created.
    os.makedirs(output_path, exist_ok=True)

    for file in batchs:
        # NOTE: unpickling can execute arbitrary code; only load batch files
        # that come from a trusted source.
        with open(file, "rb") as fo:
            data = pk.load(fo, encoding="bytes")
        print(data.keys())

        for img, label in zip(data[b"data"], data[b"labels"]):
            label_name = label_list[label]

            # Rebuild the RGB image: the flat row is viewed as three 32x32
            # planes, taken as the R, G and B channels in that order.
            planes = img.reshape(3, 32, 32)
            r, g, b = (Image.fromarray(plane).convert("L") for plane in planes)
            new_img = Image.merge("RGB", (r, g, b))

            result_path = os.path.join(output_path, phase, label_name)
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(result_path, exist_ok=True)

            arr_cls_img_num[label] += 1
            img_name = "%s-%d_%d.png" % (label_name, label, arr_cls_img_num[label])
            img_filename = os.path.join(result_path, img_name)
            new_img.save(img_filename)

            # relpath strips only the leading output directory, whereas
            # str.replace would remove every occurrence of that substring.
            rel_filename = os.path.relpath(img_filename, output_path)
            list_lines.append("%s %d\n" % (rel_filename, label))

    label_file = os.path.join(output_path, "%s_label.txt" % phase)
    # The context manager closes the file even if writing fails.
    with open(label_file, "wt") as fp_label:
        fp_label.writelines(list_lines)

if __name__ == "__main__":
    # Extract the training batches first, then the test batch.
    for phase_batches, phase_name in ((train_batchs, "train"), (test_batchs, "test")):
        reader_cifar10(phase_batches, phase=phase_name)

参考代码:

如何用python解析cifar10数据集图片 | 闫金钢的Blog

你可能感兴趣的:(编程,模式识别,cifar10)