工作和学习中设计一个神经网络中经常需要设计一个数据载入器。首先第一件事我们要根据我们的任务要求确定一个数据提供的方法。如我们是一个分类任务,我们就需要读取数据和数据本身对应的标签。
1 2
除了分类任务之外当然还有一些图像到图像的任务,如超分辨率重建,图像去噪等任务那么对应的标签就是一张高分辨率的图像或清晰的无噪声图像。
第二件事就是根据我们的数据格式来确定数据的读取方式,以分类为例,每个文件夹下面的图像对应的为一个类别的图像的时候我们可以依次读取每个文件,并将每个文件编码成对应的0到n个类别。可以根据opencv,PIL等库读取图像opencv读取的是BGR格式的numpy数组,而PIL读取的是Image的对象。
import cv2
import PIL.Image as Im
import numpy as np
im=cv2.imread('./data_dir')
#转换成rgb
im=cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
#将数据转换成Image对象
im=Im.fromarray(im).convert('RGB')
#Image 直接读取图片
im=Im.open('./data_dir','rgb')
#将Image的对象转换成numpy数组
im=np.asarray(im)
当然你的文件也可能是mat文件或者npy件或者h5py文件:
import scipy.io as si
import h5py
import numpy as np
#读取npy文件
data=np.load('test.npy')
#保存npy文件
np.save('./test.npy',data)
#读取h5py文件
f=h5py.File('./test.h5','r')#以读的方式打开文件可以根据字典的键值获取数据
data=f['data']
#保存h5文件
f=h5py.File('./test.h5','w')
f['data']=im
f['label']=label
f.cloase()
#读取mat文件mat和h5类似都是字典格式
data=si.loadmat('test.mat')
im=data['x']
label=data['y']
#保存mat文件
si.savemat('test.mat',{'x':im,'y':label})
不论是哪种数据格式我们都要考虑一个问题我们的数据量是一个怎样的数量级,如果数据集过大我们没有那么多的内存就会遇到超内存的问题。如果是小数据集我们可以直接一次性读取。大数据一般按照分批次读取或者特殊的数据格式来读取。
import os
import cv2
import numpy as np
#有时候我们需要将图片随机裁剪
def random_crop(image_ref,image_dis,num_output,size):
h,w=image_ref.shape[:2]
random_h=np.random.randint(h-size,size=num_output)
random_w=np.random.randint(w-size,size=num_output)
patches_dis=[]
patches_ref=[]
for i in range(num_output):
patch_dis=image_dis[random_h[i]:random_h[i]+size,random_w[i]:random_w[i]+size]
patch_ref=image_ref[random_h[i]:random_h[i]+size,random_w[i]:random_w[i]+size]
patches_ref.append(patch_ref)
patches_dis.append(patch_dis)
return patches_ref,patches_dis
def read_data(path):
file_name=os.listdir(path)#获取所有文件的文件名称
data=[]
labels=[]
for idx,fn in enumerate(file_name):#以idx作为标签如果标签是图片则以另外的函数读取
im_dirs=path+'/'+fn
im_path=os.listdir(im_dirs)#读取每个文件夹下所有图像的名称
for n in im_path:
im=cv2.imread(im_dirs+'/'+n)
data.append(im)
labels.append(idx)
return np.asarray(data),np.asarray(labels)
#一次性读取所有的数据
data,labels= read_data(data_dir)
#将数据集乱序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]
#将数据集的80%划分为训练集
s=int(num_example*0.8)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]
#按照批次将数据送入模型中
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
assert len(inputs) == len(targets)
if shuffle:
indices = np.arange(len(inputs))
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
if shuffle:
excerpt = indices[start_idx:start_idx + batch_size]
else:
excerpt = slice(start_idx, start_idx + batch_size)
yield inputs[excerpt], targets[excerpt]
for x,y in minibatches(x_train,y_train,128,shuffle=False):
feed_dict={x1:x,y1:y}
上面的方法是一次性读取所有数据的,我们有时处理大数据的问题时就需要按照批次来读取了,这里推荐两种方法一种是基于tensorflow的tfrecords文件或者pytorch的Imagefolder两种方法:这里我们以这个数据集为例:http://download.tensorflow.org/example_images/flower_photos.tgz
是一个关于花分类的数据集:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
data_dir = 'E:/PytorchData/flower_photos'
def load_split_train_test(data_dir,valid_size = 0.2):
train_trainsforms = transforms.Compose([transforms.Resize((224,224)),
transforms.ToTensor(),])
test_trainsforms = transforms.Compose([transforms.Resize((224,224)),
transforms.ToTensor(),])
train_data = datasets.ImageFolder(datadir,transform=train_trainsforms)
test_data = datasets.ImageFolder(datadir,transform=test_trainsforms)
num_train = len(train_data) # 训练集数量
indices = list(range(num_train)) # 训练集索引
split = int(np.floor(valid_size * num_train)) # 获取20%数据作为验证集
np.random.shuffle(indices) # 打乱数据集
from torch.utils.data.sampler import SubsetRandomSampler
train_idx, test_idx = indices[split:], indices[:split] # 获取训练集,测试集
train_sampler = SubsetRandomSampler(train_idx) # 打乱训练集,测试集
test_sampler = SubsetRandomSampler(test_idx)
#============数据加载器:加载训练集,测试集===================
train_loader = DataLoader(train_data,sampler=train_sampler,batch_size=64)
test_loader = DataLoader(test_data,sampler=test_sampler,batch_size=64)
return train_loader,test_loader
train_loader,test_loader = load_split_train_test(data_dir, 0.2)
for inputs,labels in train_loader:
#这里inputs,和labels输出的Tensor我们想看到输出的结果需要转换成numpy数组
inputs,labels=np.asarray(inputs),np.asarray(labels)
print(inputs.shape)
#在pytorch中我们经常将数据放入到GPU中我们直接打印出来数据时会报错因此,我们需要将数据放入cpu中转换成numpy数组
除了pytorch之外还有tensorflow也提供了专门的数据接口,如常用的tfrecords,首先我们需要将自己的数据集保存成tfrecords文件
import os
import tensorflow as tf
from PIL import Image #注意Image,如果是cv2需要转换成Image对象
import matplotlib.pyplot as plt
import numpy as np
data_dir='E:/PytorchData/flower_photos/'
classes={'1','2','3','4','5'} #将花数据改成1到5个类别
writer= tf.python_io.TFRecordWriter("./flower_classfication.tfrecords") #要生成的文件
for index,name in enumerate(classes):
class_path=data_dir+name+'/'
for img_name in os.listdir(class_path):
img_path=class_path+img_name #每一个图片的地址
img=Image.open(img_path)
img= img.resize((128,128))
img_raw=img.tobytes()#将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
在制作完成我们的数据集后需要读取:
import tensorflow as tf
def read_and_decode(filename): # 读入flower_classfication.tfrecords
filename_queue = tf.train.string_input_producer([filename])#生成一个queue队列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#将image数据和label取出来
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [128, 128, 3]) #reshape为128*128的3通道图片,必须和保存的分辨率一致
#否则出错,此外如果需要resize需要在下面调用tf.image.resize_images()
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中抛出img张量,并归一化减去0.5
label = tf.cast(features['label'], tf.int32) #在流中抛出label张量
return img, label
with tf.Session as sess:
a,b=read_and_decode('./flower_classfication.tfrecords)
for i inn range(100)
img,labels=sess.run([a,b])
之后还有tf.data读取数据我还没用过,用过之后更新!
参考博客:https://blog.csdn.net/wsp_1138886114/article/details/87809544
https://www.cnblogs.com/denny402/p/6931338.html