读取并可视化CIFAR10数据集

前言

想自己重写Dataset类,而不是通过torchvision.datasets.CIFAR10获取数据集。但是从官网下载的数据集是压缩包,解压后得到的是用pickle序列化的批次文件(data_batch_1~data_batch_5、test_batch以及batches.meta),无法直接得到图片和标签。因此参考相关博客,通过反序列化将图片和标签读取出来,并分别保存为文件。

下载数据集

首先可以通过PyTorch(torchvision)下载CIFAR10数据集:将CIFAR10的download参数设为True,数据会被下载并解压到./data/cifar-10-batches-py目录下。后面的读取脚本使用的正是这个目录。

#train.py

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np


# Device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# ToTensor() produces float tensors; Normalize with mean = std = 0.5 per channel
# maps x to (x - 0.5) / 0.5. imshow() below undoes this with img / 2 + 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000 training images. download=True fetches and extracts the archive under
# ./data (torchvision skips the download when the files are already present),
# which is where the extraction scripts later read data/cifar-10-batches-py.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=False, num_workers=0)

# 10000 validation (test-split) images. batch_size=4 decides how many images
# the preview below prints and displays.
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
# Recent PyTorch DataLoader iterators have no .next() method; the builtin
# next() works on every version.
val_image, val_label = next(val_data_iter)
print(val_image.size())
print(train_set.class_to_idx)
# Short display names in the same order as the class indices (not used below;
# cla_dict is used to label the preview instead).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# Reverse lookup for the preview: class index -> class name.
cla_dict = {idx: name for name, idx in train_set.class_to_idx.items()}
def imshow(img):
    """Undo the (0.5, 0.5) normalization and display an image tensor.

    Expects a 3-D tensor with the channel axis first, such as the output of
    ``utils.make_grid``. The channel axis is moved last because that is the
    layout ``plt.imshow`` expects. The resulting array shape is printed
    before the figure is shown.
    """
    # Normalize computed x_n = (x - 0.5) / 0.5, so x = x_n / 2 + 0.5.
    restored = (img / 2 + 0.5).numpy()
    channels_last = restored.transpose((1, 2, 0))
    print(channels_last.shape)
    plt.imshow(channels_last)
    plt.show()

# Print the class name of every image in the batch (5-wide columns), then show
# the whole batch as a single image grid. Iterating val_label (not range(4))
# keeps this correct for any batch size.
print(' '.join('%5s' % cla_dict[label.item()] for label in val_label))
imshow(utils.make_grid(val_image))

数据集可视化

通过pickle反序列化读取数据,并把每张图片保存为.jpg文件、每个标签保存为.txt文件

train

readDataTrain.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Unpickle *filename*, decoding legacy 8-bit strings as latin-1.

    Used for ``batches.meta``. With this encoding the keys come back as text,
    so callers can index with plain strings such as ``'label_names'``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(filename, mode='rb') as handle:
        return pickle.load(handle, encoding='latin1')

def unpickle(file):
    """Load a CIFAR-10 batch file, keeping legacy 8-bit strings as ``bytes``.

    With ``encoding='bytes'`` the returned dictionary has byte-string keys,
    which is why callers index it as ``dic[b'data']`` and ``dic[b'labels']``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(file, mode='rb') as stream:
        batch = pickle.load(stream, encoding='bytes')
    return batch



import os

# Class names from batches.meta (currently not used below).
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

# imsave() and open() do not create missing directories, so make them first.
img_dir = "data/cifar10/train/imges/"
label_dir = "data/cifar10/train/labels/"
os.makedirs(img_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)

# Files are numbered 00001, 00002, ... continuously across all five batches.
img_id = 0
for k in range(1, 6):
    dic = unpickle("data/cifar-10-batches-py/data_batch_" + str(k))
    dict_image_data = dic[b'data']      # one flattened image per row
    dict_image_labels = dic[b'labels']  # one class index per image

    for flat, label in zip(dict_image_data, dict_image_labels):
        img_id += 1
        stem = str(img_id).zfill(5)
        # Rows are stored channel-major (3 planes of 32x32): reshape to
        # (C, H, W), then move channels last for imsave().
        img = np.reshape(flat, (3, 32, 32)).transpose(1, 2, 0)
        # NOTE(review): JPEG is lossy, so saved pixels differ slightly from
        # the originals; switch to '.png' if exact pixels matter.
        imsave(img_dir + stem + '.jpg', img)
        with open(label_dir + stem + '.txt', 'w') as f:
            f.write(str(label))


test

readDataTest.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Unpickle *filename*, decoding legacy 8-bit strings as latin-1.

    Used for ``batches.meta``. With this encoding the keys come back as text,
    so callers can index with plain strings such as ``'label_names'``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(filename, mode='rb') as handle:
        return pickle.load(handle, encoding='latin1')

def unpickle(file):
    """Load a CIFAR-10 batch file, keeping legacy 8-bit strings as ``bytes``.

    With ``encoding='bytes'`` the returned dictionary has byte-string keys,
    which is why callers index it as ``dic[b'data']`` and ``dic[b'labels']``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(file, mode='rb') as stream:
        batch = pickle.load(stream, encoding='bytes')
    return batch



import os

# Class names from batches.meta (currently not used below).
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

# imsave() and open() do not create missing directories, so make them first.
img_dir = "data/cifar10/test/imges/"
label_dir = "data/cifar10/test/labels/"
os.makedirs(img_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)

dic = unpickle("data/cifar-10-batches-py/test_batch")
dict_image_data = dic[b'data']      # one flattened image per row
dict_image_labels = dic[b'labels']  # one class index per image

# Files are numbered 00001, 00002, ... in batch order.
for img_id, (flat, label) in enumerate(zip(dict_image_data, dict_image_labels), start=1):
    stem = str(img_id).zfill(5)
    # Rows are stored channel-major (3 planes of 32x32): reshape to
    # (C, H, W), then move channels last for imsave().
    img = np.reshape(flat, (3, 32, 32)).transpose(1, 2, 0)
    # NOTE(review): JPEG is lossy, so saved pixels differ slightly from
    # the originals; switch to '.png' if exact pixels matter.
    imsave(img_dir + stem + '.jpg', img)
    with open(label_dir + stem + '.txt', 'w') as f:
        f.write(str(label))


效果

读取CIFAR10数据集可视化_第1张图片
读取CIFAR10数据集可视化_第2张图片
读取CIFAR10数据集可视化_第3张图片
读取CIFAR10数据集可视化_第4张图片
每个标签文件中保存的是0-9之间的整数类别编号
读取CIFAR10数据集可视化_第5张图片
标签编号与类别名称的对应关系(即train.py中打印的train_set.class_to_idx):

{'airplane': 0, 'automobile': 1, 'bird': 2, 
 'cat': 3, 'deer': 4, 'dog': 5, 
 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

参考资料

手把手教你CIFAR数据集可视化

你可能感兴趣的:(深度学习,python,深度学习)