资源路径:
链接:https://pan.baidu.com/s/1aX-PjuubzPfY6-6nG3ayIg
提取码:jhk1
import os
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
%matplotlib inline
print(tf.__version__)
print(np.__version__)
# 读取文件夹文件名与标签
def load_sample(sample_dir):
print("加载图片数据")
file_name_list = []
labels_names = []
for(dir_path, dir_names, file_names) in os.walk(sample_dir):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
# 获取图片路径与文件夹名字(标签)
file_name_list.append(file_path)
labels_names.append(dir_path.split("\\")[-1])
# 文件夹名字 标签去重 排序 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
lab = list(sorted(set(labels_names)))
# 文件夹名字 分类编号 {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
# 字符串与 数字建立关系
labdict = dict(zip(lab, list(range(len(lab)))))
# 根据 文件夹名字 --> 标签(数字)
labels = [labdict[i] for i in labels_names]
return shuffle(np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)
# 读取文件名与标签
data_dir = 'mnist_digits_images\\' # 定义文件路径
(images, labels), labelsnames = load_sample(data_dir) # 载入文件名称与标签
print(len(images), images[:2]) # 文件名
print(len(labels), labels[:2]) # 标签
print(labelsnames[labels[:2]], labelsnames) # 标签字符串
def pre_process(x, y):
# x: 图片的路径,y:图片的数字编码
x = tf.io.read_file(x)
x = tf.image.decode_bmp(x, channels=1)
x = tf.cast(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y)
return x, y
def show_result(subplot, title, thisimg):
"""显示单个图片"""
p = plt.subplot(subplot)
p.axis('off')
# p.imshow(np.asarray(thisimg[0], dtype='uint8'))
p.imshow(np.reshape(thisimg, (28, 28)))
p.set_title(title)
def show_img(index, label, img, ntop):
"""显示图片列表"""
plt.figure(figsize=(20, 10))
plt.axis('off')
ntop = min(ntop, 9)
print("setp:", step.numpy())
for i in range(ntop):
show_result(100 + 10 * ntop + 1 + i, str(label[i].numpy()), img[i])
plt.show()
# 根据图片路径标签读取图片
batch_size = 10
db = tf.data.Dataset.from_tensor_slices((image, label))
db = db.shuffle(1000).map(pre_process).batch(batch_size)
# 循环显示图片
for step, (x, y) in db.enumerate():
print(x.shape, y.shape)
print(y.numpy())
show_img(step, y, x, batch_size) # 显示图片
if step >= 1:
break