说明:直接离线加载 CIFAR-10 数据集到 PyTorch
'''
直接将以下 6 个 CIFAR-10 批次文件加载到 PyTorch:
data_batch_1
data_batch_2
data_batch_3
data_batch_4
data_batch_5
test_batch
'''
import os
import cv2
import pickle
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms
#加载cifar10的数据
def load_CIFAR_batch(filename):
    """Load a single CIFAR-10 "python version" batch file.

    Args:
        filename: Path to one pickled batch file (e.g. ``data_batch_1`` or
            ``test_batch``) containing a dict with ``'data'`` and
            ``'labels'`` entries.

    Returns:
        A tuple ``(X, Y)``: ``X`` has shape ``(num_images, 32, 32, 3)``
        (each image channels-last) with the dtype stored in the file, and
        ``Y`` is a 1-D ``np.ndarray`` built from the label list.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only
        load batch files obtained from a trusted source.
    """
    with open(filename, 'rb') as f:
        # encoding='latin1' lets Python 3 unpickle byte strings written by
        # Python 2 pickles without a UnicodeDecodeError.
        datadict = pickle.load(f, encoding='latin1')
    X = datadict['data']
    Y = datadict['labels']
    # Each row is laid out channel-major (3 planes of 32x32). Reshape to
    # (N, C, H, W), then move channels last -> (N, H, W, C). Using -1 infers
    # N from the data instead of assuming exactly 10000 images per file.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    Y = np.array(Y)
    return X, Y
def load_CIFAR10(ROOT):
    """Load the whole CIFAR-10 python dataset from ``ROOT``.

    Reads ``data_batch_1`` .. ``data_batch_5`` (in that order) and joins them
    along the first axis to form the training split, then reads
    ``test_batch`` as the test split.

    Args:
        ROOT: Directory containing the six batch files.

    Returns:
        ``(Xtrain, Ytrain, Xtest, Ytest)``, each an array produced by
        ``load_CIFAR_batch`` (training arrays concatenated with
        ``np.concatenate``).
    """
    train_paths = [os.path.join(ROOT, 'data_batch_%d' % (b)) for b in range(1, 6)]
    batches = [load_CIFAR_batch(path) for path in train_paths]
    # Stack along axis 0 (the image axis); nothing is flattened here.
    Xtrain = np.concatenate([images for images, _ in batches])
    Ytrain = np.concatenate([labels for _, labels in batches])
    Xtest, Ytest = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtrain, Ytrain, Xtest, Ytest
class DealDataset(Data.Dataset):
    """Map-style PyTorch dataset over the offline CIFAR-10 python batches.

    Args:
        root: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.
        train: If True, load the training split into ``train_set`` /
            ``train_labels``; otherwise load ``test_batch`` into
            ``test_set`` / ``test_labels``.
        transform: Optional callable applied to each image in
            ``__getitem__`` (for example ``transforms.ToTensor()``).
    """

    def __init__(self, root, train=True, transform=None):
        if train:
            # Training split: the five data_batch_* files concatenated.
            # (load_CIFAR10 also reads test_batch, which is discarded here.)
            (train_set, train_labels, _, _) = load_CIFAR10(root)
            self.train_set = train_set
            self.train_labels = train_labels
        else:
            # Test split: read only test_batch. Going through load_CIFAR10
            # would unpickle all five training batches just to throw them away.
            test_set, test_labels = load_CIFAR_batch(os.path.join(root, 'test_batch'))
            self.test_set = test_set
            self.test_labels = test_labels
        self.transform = transform
        self.train = train

    def _split(self):
        """Return the ``(images, labels)`` arrays of the active split."""
        if self.train:
            return self.train_set, self.train_labels
        return self.test_set, self.test_labels

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        ``target`` is converted to a Python ``int``; ``img`` is passed
        through ``self.transform`` when one is set.
        """
        images, labels = self._split()
        img, target = images[index], int(labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of images in the active split."""
        images, _ = self._split()
        return len(images)
# Directory holding the extracted CIFAR-10 python batch files. Set the
# CIFAR10_ROOT environment variable to point elsewhere; otherwise the
# original hard-coded path is used.
root = os.environ.get('CIFAR10_ROOT', r'E:\cifar-10-python\cifar-10-batches-py')
batch_size = 8
# Instantiate the Dataset objects; each is then handed to a DataLoader.
# ToTensor converts each HWC image array into a CHW tensor.
trainDataset = DealDataset(root, train=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, transform=transforms.ToTensor())
# Wrap the training and test data in DataLoaders.
# NOTE(review): training loaders are usually shuffled; shuffle=False is kept
# so the demo below shows a deterministic first batch.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size,  # number of images per batch
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)
if __name__ == '__main__':
    # trainDataset exposes train_labels / train_set, both numpy ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')
    # train_loader exposes batch_size (int) and dataset (the DealDataset
    # above), so the same ndarrays are reachable via train_loader.dataset.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    # Visualization option 1, OpenCV (disabled):
    # images, lables = next(iter(train_loader))
    # img = torchvision.utils.make_grid(images, nrow = 10)
    # img = img.numpy().transpose(1, 2, 0)
    # # OpenCV expects BGR but img is RGB, so flip channels with img[:,:,::-1]
    # cv2.imshow('img', img[:,:,::-1])
    # cv2.waitKey(0)

    # Visualization option 2, matplotlib: show the first training batch.
    # Use the next() builtin: recent PyTorch DataLoader iterators no longer
    # provide an iterator.next() method.
    images, labels = next(iter(train_loader))
    images = images.numpy()
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    n_images = len(images)
    # add_subplot requires integer grid sizes; ceil-divide so an odd image
    # count still gets a slot for every image.
    n_cols = (n_images + 1) // 2
    fig = plt.figure(figsize=(4, 4))
    for idx in range(n_images):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        # Batch tensors are CHW; imshow expects HWC.
        ax.imshow(images[idx].transpose(1, 2, 0))
        ax.set_title(classes[int(labels[idx])])
    plt.show()
运行结果:
(此处为运行后显示的图像:训练集第一个批次的 8 张图片,每张图的标题为其类别名称。)