手写字母数据集转换为.pickle文件

首先准备数据集，相关资源我已上传：https://download.csdn.net/download/fanzonghao/10566701

转换代码如下：

import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpig
import imageio
import pickle
"""
函数功能:将notMNIST_large和notMNIST_small的图片生成对应的.pickle文件
"""
def load_letter(folder,min_num_images,image_size):
    """Load every image file in ``folder`` into a float32 array (one letter class).

    Args:
        folder: directory whose entries are the image files of one letter.
        min_num_images: presumably the minimum number of images required;
            it is only used in the truncated part of this function below.
        image_size: expected height and width of each image, in pixels.

    NOTE(review): the remainder of this function was lost when the article
    was copied (the code stops at ``if num_image``), so its return value and
    the handling of ``min_num_images`` are not visible here.
    """
    image_files=os.listdir(folder)
    print(folder)
    # Pre-allocate one slot per directory entry; files that fail to load leave
    # trailing slots unfilled, and num_image counts the slots actually used.
    dataset=np.ndarray(shape=(len(image_files),image_size,image_size),dtype=np.float32)
    num_image=0
    for image in image_files:
        image_file=os.path.join(folder,image)
        try:
            # Shift pixel values by -0.5 ("/1" is a no-op). Presumably
            # mpig.imread returns floats in [0, 1] for these PNG files, giving
            # roughly [-0.5, 0.5] -- TODO confirm.
            image_data=(mpig.imread(image_file)-0.5)/1
            # NOTE(review): a shape mismatch raises AssertionError, which the
            # except clause below does not catch, so it aborts the whole load.
            assert image_data.shape==(image_size,image_size)
            dataset[num_image,:,:]=image_data
            num_image+=1
        except(IOError,ValueError)as e:
            print('could not read:',image_file,e,'skipping')
    # Warn when fewer images than required were read.
    # NOTE(review): the source is truncated here; the rest of this condition
    # and of the function body is missing.
    if num_image

（注：上面的代码在转载时被截断，`if num_image` 之后的内容已丢失。）

打印生成的结果：

[图 1：生成 .pickle 文件时打印的结果截图]

将两个数据集中各字母生成的 .pickle 文件合并成一个完整的 .pickle 数据集，这样使用时可以直接调用。代码如下：

import numpy as np
import data_deal
import os
import pickle
"""
函数功能:功能1:调用把图片文件生成pickle文件的
        功能2:通过把生成的pickle文件调用生成train_dataset和valid_dataset和test_dataset
"""
#生成.pickle文件  没有的时候才执行
# data_deal.produce_train_test_pickle()

def make_array(rows,img_size):
    """Allocate uninitialised image and label buffers.

    Args:
        rows: number of samples to allocate; a falsy value (0 or None)
            means "no buffer".
        img_size: side length of each square image.

    Returns:
        ``(dataset, labels)`` where ``dataset`` is float32 with shape
        ``(rows, img_size, img_size)`` and ``labels`` is int32 with shape
        ``(rows,)``; ``(None, None)`` when ``rows`` is falsy. The contents
        are not initialised (``np.ndarray``) and must be filled by the caller.
    """
    if not rows:
        return None, None
    images = np.ndarray(shape=(rows, img_size, img_size), dtype=np.float32)
    targets = np.ndarray(shape=(rows,), dtype=np.int32)
    return images, targets
def produce_train_test_datasets(pickle_files,train_size,valid_size=0,image_size=28):
    """Build an optional validation set and a training set from per-letter pickles.

    Each file in ``pickle_files`` holds the samples of one letter class,
    sliced below as a 3-D array (n, image_size, image_size). The label of a
    file is its index in ``pickle_files``. After the class's samples are
    shuffled in place, the first ``valid_size // num_classes`` go to the
    validation set and the next ``train_size // num_classes`` go to the
    training set, so the two sets never share a sample.

    Args:
        pickle_files: paths of the per-class pickle files, one per label.
        train_size: requested total number of training samples.
        valid_size: requested total number of validation samples; 0 skips
            the validation set.
        image_size: side length of the square images (default 28).

    Returns:
        ``(valid_dataset, valid_label, train_dataset, train_label)``. A pair is
        ``(None, None)`` when no samples are requested for it. Each set has
        ``num_classes * (size // num_classes)`` rows: when a size is not a
        multiple of the class count, the remainder is dropped rather than
        returned as uninitialised rows.

    Raises:
        ValueError: if ``pickle_files`` is empty or a class holds fewer
            samples than its share of the validation + training sets.
    """
    num_classes = len(pickle_files)
    if num_classes == 0:
        raise ValueError('pickle_files must not be empty')
    valid_size_per_class = valid_size // num_classes
    train_size_per_class = train_size // num_classes
    samples_per_class = valid_size_per_class + train_size_per_class
    # make_array allocates with np.ndarray, which does not initialise memory,
    # so size each buffer to exactly the rows that are filled below.
    valid_dataset, valid_label = make_array(valid_size_per_class * num_classes, img_size=image_size)
    train_dataset, train_label = make_array(train_size_per_class * num_classes, img_size=image_size)
    for label, pickle_file in enumerate(pickle_files):
        with open(pickle_file, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            letter_samples = pickle.load(f)
        if len(letter_samples) < samples_per_class:
            raise ValueError('{} holds {} samples but {} are needed per class'.format(
                pickle_file, len(letter_samples), samples_per_class))
        # Shuffle this letter's samples in place along the first axis so the
        # validation/training split is random.
        np.random.shuffle(letter_samples)
        if valid_dataset is not None:
            start_v = label * valid_size_per_class
            end_v = start_v + valid_size_per_class
            valid_dataset[start_v:end_v, :, :] = letter_samples[:valid_size_per_class, :, :]
            valid_label[start_v:end_v] = label
        if train_dataset is not None:
            # Training samples come right after the validation samples.
            start_t = label * train_size_per_class
            end_t = start_t + train_size_per_class
            train_dataset[start_t:end_t, :, :] = letter_samples[valid_size_per_class:samples_per_class, :, :]
            train_label[start_t:end_t] = label
    return valid_dataset, valid_label, train_dataset, train_label
def random_letter(dataset,labels):
    """Shuffle samples and labels with one shared random permutation.

    The merged sets are built class by class (all A's, then all B's, ...);
    this mixes the classes while keeping every image paired with its label.

    Args:
        dataset: array indexed as 3-D, first axis = sample.
        labels: array with one entry per sample along its first axis.

    Returns:
        ``(dataset, labels)`` as new arrays reordered by the same permutation.
    """
    order = np.random.permutation(labels.shape[0])
    return dataset[order, :, :], labels[order]
def notMNIST_pickle(data_dir='./data'):
    """Merge the per-letter pickles into a single ``notMNIST.pickle`` file.

    Training and validation samples come from
    ``<data_dir>/notMNIST_large/Pickles`` and test samples from
    ``<data_dir>/notMNIST_small/Pickles``. In each directory the sorted file
    order defines the labels (first file = label 0). The three sets are
    shuffled and saved to ``<data_dir>/notMNIST.pickle`` as a dict with keys
    ``train_dataset``, ``train_label``, ``valid_dataset``, ``valid_label``,
    ``test_dataset`` and ``test_label``.

    Args:
        data_dir: root directory holding both Pickles folders and receiving
            the output file (default ``'./data'``, the previously hard-coded root).

    Raises:
        Exception: any error while writing the output file is reported and
            re-raised, instead of continuing as if the save had succeeded.
    """
    train_size=200000
    valid_size=1000
    test_size=1000

    train_dir = os.path.join(data_dir, 'notMNIST_large', 'Pickles')
    train_pickle_dir = [os.path.join(train_dir, i) for i in sorted(os.listdir(train_dir))]
    valid_dataset, valid_label, train_dataset, train_label = produce_train_test_datasets(
        train_pickle_dir, train_size, valid_size)

    test_dir = os.path.join(data_dir, 'notMNIST_small', 'Pickles')
    test_pickle_dir = [os.path.join(test_dir, i) for i in sorted(os.listdir(test_dir))]
    # No validation split is requested from the test pickles.
    _, _, test_dataset, test_label = produce_train_test_datasets(test_pickle_dir, test_size)
    print('Training', train_dataset.shape, train_label.shape)
    print('Validing', valid_dataset.shape, valid_label.shape)
    print('Testing', test_dataset.shape, test_label.shape)

    # Mix the class-ordered blocks produced above.
    train_dataset, train_label = random_letter(train_dataset, train_label)
    valid_dataset, valid_label = random_letter(valid_dataset, valid_label)
    test_dataset, test_label = random_letter(test_dataset, test_label)
    print('after shuffle training', train_dataset.shape, train_label.shape)
    print('after shuffle validing', valid_dataset.shape, valid_label.shape)
    print('after shuffle testing', test_dataset.shape, test_label.shape)

    all_pickle_file = os.path.join(data_dir, 'notMNIST.pickle')
    save = {
        'train_dataset': train_dataset,
        'train_label': train_label,
        'valid_dataset': valid_dataset,
        'valid_label': valid_label,
        'test_dataset': test_dataset,
        'test_label': test_label,
    }
    try:
        with open(all_pickle_file, 'wb') as f:
            pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        print('unable to save data', all_pickle_file, e)
        # Re-raise: continuing would stat a missing or partially written file
        # and report its size as if the save had worked.
        raise

    statinfo = os.stat(all_pickle_file)
    print('Compressed pickle size', statinfo.st_size)
if __name__ == "__main__":
    # Build the merged pickle only when run as a script, not when imported.
    notMNIST_pickle()

读取 .pickle 文件：

import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
#对于x变成(samles,pixs),y变成one_hot (samples,10)
"""
one-hot
"""
def reformat(dataset,labels,imgsize,C):
    dataset=dataset.reshape(-1,imgsize*imgsize).astype(np.float32)
    #one_hot两种写法
    #写法一
    labels=np.eye(C)[labels.reshape(-1)].astype(np.float32)

    #写法二
    #labels=(np.arange(10)==labels[:,None]).astype(np.float32)
    return dataset,labels
def pickle_dataset(path='./data/notMNIST.pickle', image_size=28, num_classes=10):
    """Load the merged notMNIST pickle and reformat it for a dense model.

    Args:
        path: location of the merged pickle (default
            ``'./data/notMNIST.pickle'``, the previously hard-coded path).
        image_size: side length of the stored square images (default 28).
        num_classes: number of classes for the one-hot labels (default 10).

    Returns:
        ``(train_dataset, train_label, valid_dataset, valid_label,
        test_dataset, test_label)``: each dataset is float32 of shape
        (samples, image_size ** 2) and each label array is float32 one-hot of
        shape (samples, num_classes).
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        restore = pickle.load(f)
    results = []
    for split in ('train', 'valid', 'test'):
        # pop() drops the dict's reference so each raw split can be freed
        # once its reformatted copy exists.
        images = restore.pop(split + '_dataset')
        labels = restore.pop(split + '_label')
        results.extend(reformat(images, labels, imgsize=image_size, C=num_classes))
    return tuple(results)



# #测试生成的数据正确不
# def test(train_dataset,train_label):
#     print(train_label[:10])
#     #plt.figure(figsize=(50,20))
#     for i in range(10):
#         plt.subplot(5,2,i+1)
#         plt.imshow(train_dataset[i].reshape(28,28))
#     plt.show()


# if __name__ == '__main__':
#     test(train_dataset,train_label)


你可能感兴趣的:(手写字母数据集转换为.pickle文件)