PyTorch CIFAR-10 离线加载二进制文件

 

说明:直接离线加载 CIFAR-10 到 PyTorch。

'''
直接加载6个文件到pytorch
    data_batch_1
    data_batch_2
    data_batch_3
    data_batch_4
    data_batch_5
    test_batch

'''


import os
import cv2
import pickle
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms



# Load the CIFAR-10 data
def load_CIFAR_batch(filename):
    """Load a single pickled CIFAR-10 batch file.

    Args:
        filename: Path to one batch file (e.g. ``data_batch_1`` or
            ``test_batch``).

    Returns:
        A tuple ``(X, Y)``: ``X`` is the image array reshaped to
        ``(N, 32, 32, 3)`` (channels last) and ``Y`` is a 1-D array holding
        the ``N`` labels. ``N`` is inferred from the file, not assumed.
    """
    # NOTE: unpickling can execute arbitrary code; only open batch files that
    # come from a trusted source.
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')

    # Each row is one flattened image laid out channel-first (3 x 32 x 32).
    # -1 lets NumPy infer the image count instead of hard-coding 10000, then
    # the channel axis is moved last.
    X = np.asarray(datadict['data']).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    Y = np.array(datadict['labels'])
    return X, Y

def load_CIFAR10(ROOT):
    """Load every CIFAR-10 batch file found under the directory ``ROOT``.

    Reads ``data_batch_1`` through ``data_batch_5`` and joins them along the
    first axis to form the training split, then reads ``test_batch``.

    Returns:
        ``(Xtrain, Ytrain, Xtest, Ytest)``.
    """
    train_paths = [os.path.join(ROOT, 'data_batch_%d' % (idx)) for idx in range(1, 6)]
    batches = [load_CIFAR_batch(path) for path in train_paths]

    # Stack the per-batch arrays end to end (concatenation along axis 0).
    Xtrain = np.concatenate([images for images, _ in batches])
    Ytrain = np.concatenate([labels for _, labels in batches])

    Xtest, Ytest = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))

    return Xtrain, Ytrain, Xtest, Ytest


class DealDataset(Data.Dataset):
    """Dataset over one split of CIFAR-10 read from local batch files.

    Args:
        root: Directory containing the CIFAR-10 batch files.
        train: If True, serve the training split; otherwise the test split.
        transform: Optional callable applied to each image in
            ``__getitem__``.

    The loaded arrays are stored as ``train_set`` / ``train_labels`` when
    ``train`` is True, and as ``test_set`` / ``test_labels`` otherwise.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        if train:
            (train_set, train_labels, _, _) = load_CIFAR10(root)
            self.train_set = train_set
            self.train_labels = train_labels
        else:
            # Read only test_batch: going through load_CIFAR10 would also
            # load and concatenate the five training batches just to
            # throw them away.
            test_set, test_labels = load_CIFAR_batch(
                os.path.join(root, 'test_batch'))
            self.test_set = test_set
            self.test_labels = test_labels

        self.transform = transform
        self.train = train

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``; ``target`` is a plain int."""
        if self.train:
            img, target = self.train_set[index], int(self.train_labels[index])
        else:
            img, target = self.test_set[index], int(self.test_labels[index])

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of images in the active split."""
        if self.train:
            return len(self.train_set)
        return len(self.test_set)

# Directory holding the extracted CIFAR-10 batch files, and the mini-batch size.
root = r'E:\cifar-10-python\cifar-10-batches-py'
batch_size = 8

# Instantiate the Dataset for each split; these objects are then handed to
# DataLoader below.
trainDataset = DealDataset(root, transform=transforms.ToTensor(), train=True)
testDataset = DealDataset(root, transform=transforms.ToTensor(), train=False)

# Loaders for the training and test data. Each batch holds `batch_size`
# images and is served in dataset order (no shuffling).
train_loader = Data.DataLoader(trainDataset, batch_size=batch_size, shuffle=False)
test_loader = Data.DataLoader(testDataset, batch_size=batch_size, shuffle=False)

if __name__ == '__main__':

    # trainDataset exposes train_labels and train_set; both are ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader exposes batch_size (int) and dataset (DealDataset);
    # the dataset in turn carries the train_labels / train_set ndarrays.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    # # Visualization 1: OpenCV
    # images, labels = next(iter(train_loader))
    # img = torchvision.utils.make_grid(images, nrow = 10)
    # img = img.numpy().transpose(1, 2, 0)
    # # OpenCV expects BGR while img is RGB, so reverse the channel axis.
    # cv2.imshow('img', img[:,:,::-1])
    # cv2.waitKey(0)

    # Visualization 2: matplotlib
    dataiter = iter(train_loader)
    # Use the built-in next(); the iterator's .next() method is not available
    # on current PyTorch releases.
    images, labels = next(dataiter)
    images = images.numpy()

    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    fig = plt.figure(figsize=(4, 4))

    # Lay the batch out on two rows. The column count must be an int
    # (add_subplot rejects floats such as batch_size/2), and rounding up
    # keeps enough cells when the batch holds an odd number of images.
    n_shown = len(images)
    n_cols = max(1, (n_shown + 1) // 2)

    for idx in range(n_shown):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        # Loader output is channel-first; imshow wants channels last.
        ax.imshow(images[idx].transpose(1, 2, 0))
        ax.set_title(classes[int(labels[idx])])

    plt.show()

 

运行结果

Pytorch cifar10离线加载二进制文件_第1张图片

显示图

Pytorch cifar10离线加载二进制文件_第2张图片

你可能感兴趣的:(Pytorch)