tensorflow 读取文件夹图片(tf2.0)

tensorflow 读取文件夹图片(tf2.0)

资源路径:
链接:https://pan.baidu.com/s/1aX-PjuubzPfY6-6nG3ayIg
提取码:jhk1

import os
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
%matplotlib inline

print(tf.__version__)
print(np.__version__)
# 读取文件夹文件名与标签
def load_sample(sample_dir):
    
    print("加载图片数据")
    file_name_list = []
    labels_names = []
    
    for(dir_path, dir_names, file_names) in os.walk(sample_dir):
        
        for file_name in file_names:
            file_path = os.path.join(dir_path, file_name)
            # 获取图片路径与文件夹名字(标签)
            file_name_list.append(file_path)    
            labels_names.append(dir_path.split("\\")[-1])
    
    # 文件夹名字 标签去重 排序  ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    lab = list(sorted(set(labels_names)))
    
    # 文件夹名字 分类编号 {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
    # 字符串与 数字建立关系
    labdict = dict(zip(lab, list(range(len(lab)))))
    
    # 根据 文件夹名字 --> 标签(数字)
    labels = [labdict[i] for i in labels_names]
    
    return shuffle(np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)

# 读取文件名与标签 
data_dir = 'mnist_digits_images\\'  # 定义文件路径

(images, labels), labelsnames = load_sample(data_dir)  # 载入文件名称与标签
print(len(images), images[:2])  # 文件名 
print(len(labels), labels[:2])  # 标签              
print(labelsnames[labels[:2]], labelsnames)  # 标签字符串
def pre_process(x, y):
    # x: 图片的路径,y:图片的数字编码
    x = tf.io.read_file(x)
    x = tf.image.decode_bmp(x, channels=1) 
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.convert_to_tensor(y)

    return x, y
def show_result(subplot, title, thisimg):  
    """显示单个图片"""
    
    p = plt.subplot(subplot)
    p.axis('off')
    # p.imshow(np.asarray(thisimg[0], dtype='uint8'))
    p.imshow(np.reshape(thisimg, (28, 28)))
    p.set_title(title)


def show_img(index, label, img, ntop):
    """显示图片列表"""
    plt.figure(figsize=(20, 10))
    plt.axis('off')
    ntop = min(ntop, 9)
    print("setp:", step.numpy())
    
    for i in range(ntop):
        show_result(100 + 10 * ntop + 1 + i, str(label[i].numpy()), img[i])
    plt.show()
# 根据图片路径标签读取图片
batch_size = 10
db = tf.data.Dataset.from_tensor_slices((image, label))
db = db.shuffle(1000).map(pre_process).batch(batch_size)

# 循环显示图片
for step, (x, y) in db.enumerate():    
    print(x.shape, y.shape)
    print(y.numpy())
    show_img(step, y, x, batch_size)  # 显示图片
    
    if step >= 1:
        break

tensorflow 读取文件夹图片(tf2.0)_第1张图片

你可能感兴趣的:(TensorFlow)