使用 Paddle 读取 CIFAR-10 数据集

Paddle 中的数据集以生成器（reader）的方式读取，使用起来比较方便。

import paddle
from paddle import fluid
import numpy as np

# Reader creator for the CIFAR-10 training split; it is called later to get a
# fresh iterator over (image, label) samples.
# NOTE(review): in Paddle 1.x this call may fetch the dataset at import time — confirm.
reader = paddle.dataset.cifar.train10()

def creat_batch(reader, batch_size, shuffle_size=False):
    """Read a whole Paddle reader into memory as a list of NumPy batches.

    Args:
        reader: A Paddle reader creator, i.e. a callable that returns an
            iterator of samples. Each sample is indexed as ``sample[0]``
            (image) and ``sample[1]`` (label).
        batch_size: Number of samples per batch, forwarded to
            ``fluid.io.batch``.
        shuffle_size: Buffer size for ``paddle.reader.shuffle``. Any falsy
            value (``False``, ``0``, ``None``) disables shuffling.

    Returns:
        A tuple ``(image_batches, label_batches)`` of two equally long lists.
        ``image_batches[k]`` is an ``np.array`` built from the images of
        batch ``k`` (first axis = sample). ``label_batches[k]`` holds the
        labels of batch ``k`` as a column vector of shape ``(len(batch), 1)``.
    """
    # Shuffling is the only difference between the two modes, so wrap the
    # reader once instead of duplicating the whole batching loop. A truthiness
    # test also keeps None/"" from reaching paddle.reader.shuffle as a
    # nonsensical buffer size.
    if shuffle_size:
        reader = paddle.reader.shuffle(reader, shuffle_size)
    batched_reader = fluid.io.batch(reader, batch_size)

    image_batches = []
    label_batches = []
    for batch in batched_reader():
        image_batches.append(np.array([sample[0] for sample in batch]))
        # Labels become a (batch_len, 1) column, matching the original layout.
        label_batches.append(np.array([sample[1] for sample in batch]).reshape(-1, 1))
    return image_batches, label_batches
if __name__ == "__main__":
    # Demo: split the training set into batches of 100 samples, shuffled
    # through a 20-sample buffer, then report the shapes of the first batch.
    image, label = creat_batch(reader, batch_size=100, shuffle_size=20)
    first_image_batch = image[0]
    first_label_batch = label[0]
    print("shape of {0}th flatten_data_batch is {1}".format(1, first_image_batch.shape))
    print("shape of {0}th label_batch is {1}".format(1, first_label_batch.shape))

你可能感兴趣的:(paddlepaddle)