深度学习:tf读取图片组成dataset形式

from glob import glob 
path = glob('./data/potato_data/*/*') # 所有的图片路径
label = [i.split('\\')[1] for i in path] # 所有图片对应的标签
label_dict = {'Early_blight':0, 'Late_blight':1, 'healthy':2}
label = [label_dict[i] for i in label ] # 把字符标签转成数值
# 通过构建DataSET的方式读取数据
train = tf.data.Dataset.from_tensor_slices( (path, label) )
for i,j in train:
    print(i,j)
    break
from tensorflow.io import read_file
def process_image(fpath, label):
    img = read_file(fpath)#编码后的数据
    img = tf.image.decode_png(img)/255 # 解码成图像数组
    img = tf.image.resize(img, [256,256]) # 所有图片大小统一
    label = tf.one_hot(label, depth=3) # 独热编码
    return img,label

# 通过映射,对x,y做处理
train = train.map(process_image)
for i,j in train:
    print(i,j)
    break

tarin = train.shuffle(10000) # 打乱数据
train = tarin.batch(32) # 给每个数据加批次
train.cache()  # 数据缓存
train.prefetch(buffer_size=tf.data.AUTOTUNE) # 预取数,增加资源使用效率
# 划分数据
num = tf.data.experimental.cardinality(train) # 所有批次
valdb = train.take(num//20) # 取3个批次,3
traindb = train.skip(num//20) # 跳过3个批次.65

你可能感兴趣的:(深度学习,人工智能)