H5py数据集的制作

import os
import glob
import h5py
import scipy
import scipy.ndimage
import numpy as np
#读取文件到data
def prepare(data_dir):
    data_dir=os.path.join(os.getcwd(),data_dir)
    data=glob.glob(os.path.join(data_dir,'*.bmp'))
    return data

def make_data(data,label,data_dir):
    savepath=os.path.join('.',os.path.join('checkpoint',data_dir,'train.h5'))
    if not os.path.exists(os.path.join('.',os.path.join('checkpoint',data_dir))):
        os.makedirs(os.path.join('.',os.path.join('checkpoint',data_dir)))
    with h5py.File(savepath,'w') as hf:
        hf.create_dataset('data',data=data)
        hf.create_dataset('label',data=label)

def imread(path,is_grayscale=True):#这里的imread可以有很多的读取方式,根据自己的文件特性进行选择
    return scipy.misc.imread(path).astype(np.float)

def setup(data_dir):
    data=prepare(data_dir)
    input=[]
    input_label=[]
    for i in range(len(data)):
        input_=imread(data[i]
        label_=input_
        h,w=input_.shape#这没考虑图片的通道数,只有大小
        for x in range(0,h,stride):
            for y in range(0,w,stride):
                sub_input=input_[x:x+132,y:y+132]
                sub_label=label_[x:x+126,y:y+126]
                #可以添加reshape来强制修改图片以及增加通道
                input.append(sub_input)
                input_label.append(sub_label)
    ar_data=np.asarray(input)
    ar_label=np.asarray(input_label)
    make_data(ar_data,ar_label,data_dir)

hy=setup('data_dir')
#读取h5数据
with h5py.File('./checkpoint/train_ir/xx.h5','r') as hf:
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))

可以在此基础上进行修改得到符合要求的数据集。 

你可能感兴趣的:(python,pycharm)