想自己重写Dataset类,不通过torchvision.datasets.CIFAR10获取数据集。但是从官网下载的数据集是压缩包形式,直接解压无法得到图片和标签信息,因此参考博客将图片和标签读取出来。
首先可以通过 PyTorch(torchvision)下载 CIFAR10 数据集(下面的代码使用 download=False,假定数据已存在于 ./data;首次运行时需将 download 设为 True):
#train.py
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
# Select the compute device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Convert images to tensors, then normalize every channel with mean 0.5 and
# std 0.5 (imshow reverses this with ``img / 2 + 0.5``).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50,000 training images.
# NOTE: download=False expects the dataset to already be present under ./data;
# set download=True on the first run to fetch it.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=False, num_workers=0)

# 10,000 validation images.
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)

# Pull a single batch for visualization. The built-in next() is used because
# the iterator's Python 2 style ``.next()`` method is not available in current
# PyTorch releases.
val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)
print(val_image.size())
print(train_set.class_to_idx)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# To display the images, the validation loader's batch_size must be set to 4
# beforehand.
# Invert the name -> index mapping so a predicted/target index can be turned
# back into its class name.
aaa = train_set.class_to_idx
cla_dict = {val: key for key, val in aaa.items()}
def imshow(img):
    """Undo the 0.5/0.5 normalization of an image tensor and display it.

    The tensor is rescaled with ``img / 2 + 0.5``, converted to a NumPy array,
    and its first axis is moved to the last position before being handed to
    matplotlib (presumably channels-first -> channels-last). The resulting
    array shape is printed as a debugging aid.
    """
    restored = img / 2 + 0.5  # unnormalize
    axes_last = np.transpose(restored.numpy(), (1, 2, 0))
    print(axes_last.shape)
    plt.imshow(axes_last)
    plt.show()
# Print the class name of every image in the fetched batch (right-aligned in
# 5-character columns), then show the whole batch as one image grid.
# Iterating over val_label directly avoids hard-coding the batch size of 4.
print(' '.join('%5s' % cla_dict[label.item()] for label in val_label))
imshow(utils.make_grid(val_image))
通过反序列化将数据读取出来
readDataTrain.py
import pickle
from imageio import imsave
import numpy as np
def load_file(filename):
    """Return the object unpickled from ``filename``.

    The file is opened in binary mode and decoded with ``encoding='latin1'``,
    so strings pickled by Python 2 come back as ``str``.
    NOTE: pickle must only be used on trusted files.
    """
    with open(filename, 'rb') as stream:
        return pickle.load(stream, encoding='latin1')
def unpickle(file):
    """Return the object unpickled from ``file`` using ``encoding='bytes'``.

    With this encoding, strings pickled by Python 2 are returned as ``bytes``,
    which is why callers index the result with keys such as ``b'data'``.
    NOTE: pickle must only be used on trusted files.
    """
    with open(file, 'rb') as stream:
        loaded = pickle.load(stream, encoding='bytes')
    return loaded
import os

# Export the CIFAR-10 training batches as one image file plus one label text
# file per sample.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']  # presumably the class names; not used below

train_image_dir = "data/cifar10/train/imges/"
train_label_dir = "data/cifar10/train/labels/"
# Create the output folders up front so the first write does not fail on a
# fresh checkout.
os.makedirs(train_image_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)

for k in range(1, 6):
    dic = unpickle("data/cifar-10-batches-py/data_batch_" + str(k))
    dict_image_data = dic[b'data']
    dict_image_labels = dic[b'labels']
    num_images = dict_image_data.shape[0]
    for i in range(num_images):
        # 1-based running id across all batches, zero-padded to 5 digits
        # (assumes every batch holds the same number of images).
        image_id = str(num_images * (k - 1) + i + 1).zfill(5)
        # Each row is reshaped to (3, 32, 32) and then reordered to
        # (32, 32, 3) before saving.
        imgs_array = np.reshape(dict_image_data[i], (3, 32, 32))
        imgs_array = imgs_array.transpose(1, 2, 0)
        # NOTE(review): '.jpg' is a lossy format, so saved pixels may differ
        # slightly from the raw data — confirm this is acceptable.
        imsave(train_image_dir + image_id + '.jpg', imgs_array)
        with open(train_label_dir + image_id + '.txt', 'w') as f:
            f.write(str(dict_image_labels[i]))
readDataTest.py
import pickle
from imageio import imsave
import numpy as np
def load_file(filename):
    """Read ``filename`` in binary mode and return the unpickled object.

    ``encoding='latin1'`` is passed to pickle, so strings pickled by Python 2
    are decoded to ``str``.
    NOTE: pickle must only be used on trusted files.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle, encoding='latin1')
    return payload
def unpickle(file):
    """Read ``file`` in binary mode and return the unpickled object.

    ``encoding='bytes'`` is passed to pickle, so strings pickled by Python 2
    stay as ``bytes`` (hence the ``b'data'`` / ``b'labels'`` keys used by the
    caller).
    NOTE: pickle must only be used on trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')
import os

# Export the CIFAR-10 test batch as one image file plus one label text file
# per sample.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']  # presumably the class names; not used below

test_image_dir = "data/cifar10/test/imges/"
test_label_dir = "data/cifar10/test/labels/"
# Create the output folders up front so the first write does not fail on a
# fresh checkout.
os.makedirs(test_image_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)

dic = unpickle("data/cifar-10-batches-py/test_batch")
dict_image_data = dic[b'data']
dict_image_labels = dic[b'labels']
num_images = dict_image_data.shape[0]
for i in range(num_images):
    # 1-based sample id, zero-padded to 5 digits.
    image_id = str(i + 1).zfill(5)
    # Each row is reshaped to (3, 32, 32) and then reordered to (32, 32, 3)
    # before saving.
    imgs_array = np.reshape(dict_image_data[i], (3, 32, 32))
    imgs_array = imgs_array.transpose(1, 2, 0)
    # NOTE(review): '.jpg' is a lossy format, so saved pixels may differ
    # slightly from the raw data — confirm this is acceptable.
    imsave(test_image_dir + image_id + '.jpg', imgs_array)
    with open(test_label_dir + image_id + '.txt', 'w') as f:
        f.write(str(dict_image_labels[i]))
{'airplane': 0, 'automobile': 1, 'bird': 2,
'cat': 3, 'deer': 4, 'dog': 5,
'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
手把手教你CIFAR数据集可视化