import os
import glob
import h5py
import scipy
import scipy.ndimage
import numpy as np
#读取文件到data
def prepare(data_dir):
data_dir=os.path.join(os.getcwd(),data_dir)
data=glob.glob(os.path.join(data_dir,'*.bmp'))
return data
def make_data(data,label,data_dir):
savepath=os.path.join('.',os.path.join('checkpoint',data_dir,'train.h5'))
if not os.path.exists(os.path.join('.',os.path.join('checkpoint',data_dir))):
os.makedirs(os.path.join('.',os.path.join('checkpoint',data_dir)))
with h5py.File(savepath,'w') as hf:
hf.create_dataset('data',data=data)
hf.create_dataset('label',data=label)
def imread(path,is_grayscale=True):#这里的imread可以有很多的读取方式,根据自己的文件特性进行选择
return scipy.misc.imread(path).astype(np.float)
def setup(data_dir):
data=prepare(data_dir)
input=[]
input_label=[]
for i in range(len(data)):
input_=imread(data[i]
label_=input_
h,w=input_.shape#这没考虑图片的通道数,只有大小
for x in range(0,h,stride):
for y in range(0,w,stride):
sub_input=input_[x:x+132,y:y+132]
sub_label=label_[x:x+126,y:y+126]
#可以添加reshape来强制修改图片以及增加通道
input.append(sub_input)
input_label.append(sub_label)
ar_data=np.asarray(input)
ar_label=np.asarray(input_label)
make_data(ar_data,ar_label,data_dir)
hy=setup('data_dir')
#读取h5数据
with h5py.File('./checkpoint/train_ir/xx.h5','r') as hf:
data = np.array(hf.get('data'))
label = np.array(hf.get('label'))
可以在此基础上进行修改得到符合要求的数据集。