Python3 转换 cifar100 数据集

CIFAR 官网提供的 Python 版本数据文件都是用 Python 的 cPickle 工具“pickled”(序列化)得到的,官网给出的读取例程如下:

Python 2:

def unpickle(file):
    """Load a pickled CIFAR batch file (Python 2 version, using cPickle).

    Args:
        file: path to the pickled batch file.

    Returns:
        The unpickled object (a dict for the CIFAR batch files).
    """
    import cPickle
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` so the builtin `dict` is not shadowed.
        batch = cPickle.load(fo)
    return batch

Python 3(数据文件是由 Python 2 的 cPickle 生成的,需要指定 encoding='bytes',此时字符串(包括字典的键)会以 bytes 形式读出):

def unpickle(file):
    """Load a pickled CIFAR batch file under Python 3.

    The files were pickled by Python 2's cPickle. ``encoding='bytes'`` makes
    Python 2 ``str`` objects (including dict keys) load as ``bytes``.

    Args:
        file: path to the pickled batch file.

    Returns:
        The unpickled object (a dict with bytes keys for the CIFAR batch files).
    """
    import pickle
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` so the builtin `dict` is not shadowed.
        batch = pickle.load(fo, encoding='bytes')
    return batch

下面给出完整的 Python 3 示例代码,将 CIFAR-100 数据集转换为 Caffe 使用的 LMDB 格式:

import os
import pickle

import numpy as np
import sklearn
import sklearn.linear_model

import lmdb
import caffe


def unpickle(file):
    """Load one pickled CIFAR-100 batch file.

    The CIFAR python batches were written with Python 2's cPickle. Under
    Python 3 they are therefore loaded with ``encoding='bytes'``, which makes
    Python 2 ``str`` values (including every dict key) come back as
    ``bytes``, e.g. ``b'data'``.

    Args:
        file: path to the batch file (e.g. ``train`` or ``test``).

    Returns:
        The unpickled object (a dict with bytes keys for the CIFAR batches).
    """
    # SECURITY: pickle.load can execute arbitrary code; only use this on
    # trusted files such as the official CIFAR download.
    # `with` guarantees the file is closed even if unpickling raises.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def shuffle_data(data, labels, random_state=42):
    """Shuffle ``data`` and ``labels`` together using a fixed seed.

    This previously called ``sklearn.cross_validation.train_test_split`` with
    ``test_size=0.0`` only to obtain a shuffled copy. That module no longer
    exists in current scikit-learn, and newer ``train_test_split`` rejects
    ``test_size=0.0``. A seeded NumPy permutation does the same job.

    Args:
        data: array-like of samples; the first axis indexes samples.
        labels: array-like of labels, same length as ``data``.
        random_state: seed for the permutation (default 42, as before).

    Returns:
        (data, labels) as NumPy arrays, reordered by the same permutation.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if len(data) != len(labels):
        raise ValueError(
            'data and labels must have the same length, got {} and {}'.format(
                len(data), len(labels)
            )
        )
    # Indexing both arrays with one permutation keeps every sample paired
    # with its own label.
    order = np.random.RandomState(random_state).permutation(len(labels))
    return data[order], labels[order]


def load_data(train_file):
    """Load one CIFAR-100 python batch and return shuffled images with fine labels.

    Args:
        train_file: path to a pickled CIFAR-100 batch (e.g. ``train`` or ``test``).

    Returns:
        (images, labels): images reshaped to (N, 3, 32, 32) and the values of
        the batch's ``b'fine_labels'`` entry as a NumPy array, both in the same
        shuffled order.
    """
    batch = unpickle(train_file)
    # Under Python 3 (encoding='bytes') every key carries a b prefix, unlike
    # Python 2: b'batch_label', b'filenames', b'data', b'coarse_labels',
    # b'fine_labels'.
    fine = batch[b'fine_labels']
    num_images = len(fine)

    shuffled_images, shuffled_labels = shuffle_data(batch[b'data'], np.array(fine))
    images = shuffled_images.reshape(num_images, 3, 32, 32)
    return images, shuffled_labels


def _write_lmdb(lmdb_path, images, labels, map_size, batch_size=1000):
    """Write ``images[i]`` / ``labels[i]`` pairs into an LMDB as Caffe Datums.

    Keys are the zero-padded record index ('00000000', '00000001', ...),
    encoded as ASCII bytes because Python 3 needs bytes keys here. Records
    are committed every ``batch_size`` items so no single write transaction
    holds the whole dataset.

    Args:
        lmdb_path: directory of the LMDB environment to write.
        images: indexable sequence of arrays passed to caffe.io.array_to_datum.
        labels: integer class labels, same length as ``images``.
        map_size: ``map_size`` passed to ``lmdb.open``.
        batch_size: number of records per write transaction.
    """
    total = len(images)
    env = lmdb.open(lmdb_path, map_size=map_size)
    try:
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            # The Transaction context manager commits on normal exit and
            # aborts if an exception escapes, so no transaction is left open.
            with env.begin(write=True) as txn:
                for idx in range(start, stop):
                    # int(): some protobuf versions reject NumPy integer types
                    # for the Datum label field.
                    datum = caffe.io.array_to_datum(images[idx], int(labels[idx]))
                    key = '{:08}'.format(idx).encode('ascii')
                    txn.put(key, datum.SerializeToString())
            if stop % batch_size == 0:
                print('already handled with {} pictures'.format(stop))
    finally:
        env.close()


if __name__ == '__main__':
    import sys

    # Path to the extracted cifar-100-python directory. Pass it as the first
    # command-line argument to override the hard-coded default below.
    default_python_directory = r'F:\Software_download\ChromeDownload\cifar-100-python.tar\cifar-100-python'
    cifar_python_directory = os.path.abspath(
        sys.argv[1] if len(sys.argv) > 1 else default_python_directory
    )

    # (batch file name, output LMDB directory, LMDB map_size)
    # map_size values are sized as image count (50000 / 10000) * 5000.
    splits = (
        ('train', os.path.abspath('cifar100_train_lmdb'), 50000 * 1000 * 5),
        ('test', os.path.abspath('cifar100_test_lmdb'), 10000 * 1000 * 5),
    )

    print('Converting...')
    for split_name, lmdb_path, map_size in splits:
        # Check each output on its own, so a missing test LMDB is still
        # created even when the train LMDB already exists.
        if os.path.exists(lmdb_path):
            print('Conversion was already done: {}'.format(lmdb_path))
            continue
        images, labels = load_data(os.path.join(cifar_python_directory, split_name))
        print('Data is fully loaded, now truly converting: {}'.format(split_name))
        _write_lmdb(lmdb_path, images, labels, map_size)

—————————————————————————————

参考博客:http://blog.csdn.net/u010165147/article/details/54176612

你可能感兴趣的:(机器学习,python)