CIFAR-10 图片文件的读取和写入，以及安装 TensorFlow 踩过的坑

1 读取 CIFAR-10 数据文件并保存为图片

import urllib
import os
import sys
import tarfile
import glob
import pickle
import numpy as np
import cv2

def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Download ``tarball_url`` into ``dataset_dir`` and extract it there.

    Progress is written to stdout while downloading.

    Args:
      tarball_url: The URL of a gzip-compressed tarball (``.tar.gz``).
      dataset_dir: Directory where the archive is saved and extracted.
        It is created (including parents) if it does not exist yet.
    """
    # ``import urllib`` alone does not guarantee that the ``urllib.request``
    # submodule is loaded, so import it explicitly here.
    import urllib.request

    os.makedirs(dataset_dir, exist_ok=True)
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # total_size is -1 (or 0) when the server sends no Content-Length;
        # report the byte count instead of dividing by it.
        if total_size > 0:
            percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, count * block_size))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Close the archive deterministically. The archive comes from the network,
    # so use the 'data' extraction filter (rejects absolute paths, links that
    # escape dataset_dir, etc.) when this Python version provides it.
    with tarfile.open(filepath, 'r:gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dataset_dir, filter='data')
        else:
            tar.extractall(dataset_dir)

# CIFAR-10 class names; the name at position i is used for label id i
# (the image-writing loop looks names up as classification[label]).
classification = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
#默认的图片解压缩形式
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_DIR = 'data'


# Uncomment to download the archive and extract it under DATA_DIR first.
# download_and_uncompress_tarball(DATA_URL, DATA_DIR)

# Directory holding the extracted CIFAR-10 python batch files.
folders = r'E:\zhuomian\tf_read_write\data_manager/data/cifar-10-batches-py'
# Output root: one sub-folder per class name is created under it.
TRAIN_IMAGE_DIR = r"E:\zhuomian\tf_read_write\data_manager/data/image/train"

# Find the training batch files with glob. glob's order is arbitrary, so sort
# to make the image numbering (output file names) reproducible between runs.
trfiles = sorted(glob.glob(folders + "/data_batch*"))
if not trfiles:
    print('No data_batch* files found under', folders)

data = []
labels = []
for file in trfiles:
    dt = unpickle(file)
    # Printing the whole batch dict dumps every pixel value; report the file
    # name instead.
    print('Loaded', file)
    # Batch dicts use bytes keys because unpickle loads with encoding='bytes'.
    data += list(dt[b"data"])
    labels += list(dt[b"labels"])
# labels[i] is the integer class id of image i (an index into classification).

# `data` holds the flattened images; view them as (N, 3, 32, 32), i.e.
# image index, channel, row, column.
imgs = np.reshape(data, [-1, 3, 32, 32])

for i, im_data in enumerate(imgs):
    # (3, 32, 32) -> (32, 32, 3): OpenCV works with channels-last images.
    im_data = np.transpose(im_data, [1, 2, 0])
    # Convert RGB to BGR, the channel order OpenCV uses when writing files.
    im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
    # One folder per class, named after the class of image i.
    f = "{}/{}".format(TRAIN_IMAGE_DIR, classification[labels[i]])
    # makedirs also creates missing parent folders (e.g. data/image/train),
    # which os.mkdir would fail on; exist_ok makes repeat calls a no-op.
    os.makedirs(f, exist_ok=True)
    out_path = "{}/{}.jpg".format(f, str(i))
    # cv2.imwrite signals failure by returning False instead of raising.
    if not cv2.imwrite(out_path, im_data):
        print('Failed to write', out_path)

2 收集图片路径和标签，并写入 TFRecord

import tensorflow as tf
import cv2
import numpy as np
# CIFAR-10 class names in label-id order: a class's label id is its position
# in this list (see the trailing index comments).
classification = [
    'airplane',    # 0
    'automobile',  # 1
    'bird',        # 2
    'cat',         # 3
    'deer',        # 4
    'dog',         # 5
    'frog',        # 6
    'horse',       # 7
    'ship',        # 8
    'truck',       # 9
]

import glob
# Collect every image path under the per-class training folders, together
# with its integer label (the class's position in `classification`).
TRAIN_IMAGE_DIR = r"E:\zhuomian\tf_read_write\data_manager/data/image/train"
# Paths of all images from all class folders.
im_data = []
# Label of each path in im_data (same order, same length).
im_labels = []
for idx, class_name in enumerate(classification):
    # Folders are laid out as <train dir>/<class name>. The "/" separator is
    # required: without it the path becomes ".../trainairplane" and glob
    # silently matches nothing.
    path = "{}/{}".format(TRAIN_IMAGE_DIR, class_name)
    # All images of one class; sorted because glob's order is arbitrary.
    im_list = sorted(glob.glob(path + "/*"))
    if not im_list:
        print('No images found in', path)
    im_data += im_list
    # Every image in this folder shares the folder's label.
    im_labels += [idx] * len(im_list)

# Printing every label/path floods the console; print a summary instead.
print(len(im_labels), 'labels')
print(len(im_data), 'images')

# tfrecord_file = r"E:\zhuomian\tf_read_write\data_manager/data/test.tfrecord"
# writer = tf.python_io.TFRecordWriter(tfrecord_file)
#
# index = [i for i in range(im_data.__len__())]
#
# np.random.shuffle(index)
#
# for i in range(im_data.__len__()):
#     im_d = im_data[index[i]]
#     im_l = im_labels[index[i]]
#     data = cv2.imread(im_d)
#     #data = tf.gfile.FastGFile(im_d, "rb").read()
#     ex = tf.train.Example(
#         features = tf.train.Features(
#             feature = {
#                 "image":tf.train.Feature(
#                     bytes_list=tf.train.BytesList(
#                         value=[data.tobytes()])),
#                 "label": tf.train.Feature(
#                     int64_list=tf.train.Int64List(
#                         value=[im_l])),
#             }
#         )
#     )
#     writer.write(ex.SerializeToString())
#
# writer.close()

3 安装 TensorFlow 时遇到的坑

  • 首先 pip 提示找不到 tensorflow 的包
  • 查了一下，原因是我的 Python 是 32 位的，而 TensorFlow 只提供 64 位 Python 的安装包，所以需要换成 64 位的 Python
  • 我发现自己装的是 32 位的 Anaconda
  • 于是去安装了 64 位的 Anaconda，并修改了环境变量
  • 本以为这下可以顺利安装了，但是下载速度太慢
  • 于是改用豆瓣的 PyPI 镜像源来安装 tensorflow：
 pip install -i https://pypi.doubanio.com/simple/ tensorflow

  • 但是又遇到一个坑：已安装的 wrapt 包版本太低，需要升级，但它是用 distutils 安装的，pip 无法卸载它：
  Found existing installation: wrapt 1.10.11
ERROR: Cannot uninstall 'wrapt'. It is a distutils installed project and thus we cannot accurately determine which files belong to it which would lead to only a partial uninstall.
  • 没办法，只好加上 --ignore-installed 参数，跳过卸载，直接重新安装这些包：
 pip install -U --ignore-installed wrapt enum34 simplejson netaddr
  • 然后本宝宝终于安装成功了

你可能感兴趣的:(机器学习入门)