h5py批量写入文件、读取文件,支持任意维度的数据

1、引入库并创建（打开）h5文件（注意：从 h5py 3.0 起，`h5py.File` 默认以只读模式 `'r'` 打开文件，需要写入时应显式传入 `mode='a'`）

import h5py
import numpy as np

file_name = 'data.h5'
# Open read/write, creating the file if it does not exist. Since h5py 3.0 the
# default mode is read-only ('r'), which would make the writes below fail (and
# raises outright if data.h5 does not exist yet).
h5f = h5py.File(file_name, 'a')

2、批量写入数据的方法（支持任意维度的数据），每次调用都会把数据追加到h5文件中已有数据的后面

def save_h5(h5f, data, target):
    """Append ``data`` to dataset ``target`` in ``h5f``, creating it on first use.

    On the first call the dataset is created from ``data`` with an unlimited
    first axis (``maxshape[0] = None``) and chunked storage, so it can grow
    later. Each subsequent call resizes axis 0 and writes ``data`` after the
    rows already stored.

    Args:
        h5f: An h5py File (or Group) opened for writing.
        data: Array-like with at least one dimension. Anything accepted by
            ``np.asarray`` works (lists, tuples, ndarrays). Rows are appended
            along axis 0.
        target: Name of the dataset inside ``h5f``.

    Raises:
        ValueError: If ``data`` is 0-dimensional, or if its trailing
            dimensions differ from those of the existing dataset.
    """
    # Lists/tuples have no ``.shape``; normalize so any array-like works.
    data = np.asarray(data)
    if data.ndim == 0:
        raise ValueError("data must have at least one dimension to append along axis 0")

    if target not in h5f:
        # None makes axis 0 unlimited (resizable later); resizable datasets
        # require chunked storage, hence chunks=True.
        maxshape = (None,) + tuple(data.shape[1:])
        h5f.create_dataset(target, data=data, maxshape=maxshape, chunks=True)
        return

    dataset = h5f[target]
    if tuple(dataset.shape[1:]) != tuple(data.shape[1:]):
        raise ValueError(
            "cannot append data with trailing shape {} to dataset {!r} "
            "with trailing shape {}".format(data.shape[1:], target, dataset.shape[1:])
        )
    len_old = dataset.shape[0]
    len_new = len_old + data.shape[0]
    # Grow only axis 0; trailing dimensions stay as they are.
    dataset.resize(len_new, axis=0)
    dataset[len_old:len_new] = data

3、调用批量写入的方法（注意：data 一定要转换成 numpy 数组，否则没有 shape 属性）

features = np.arange(100)
# Append a fresh copy of the same 100-element array five times in a row;
# save_h5 extends the 'mnist_features' dataset along axis 0 on every call.
for _ in range(5):
    save_h5(h5f, data=np.array(features), target='mnist_features')

4、批量读取（注意：h5py 3.0 已移除 `Dataset.value`，应直接对数据集切片，例如 `h5f[target][start:end]`，这样也只会读取所需的部分）

def getDataFromH5py(fileName, target, start, length):
    """Read up to ``length`` rows of dataset ``target`` starting at row ``start``.

    Args:
        fileName: Path of an existing HDF5 file.
        target: Name of the dataset inside the file.
        start: Index of the first row (along axis 0) to read.
        length: Maximum number of rows to read.

    Returns:
        A NumPy array with rows ``start`` .. ``min(start + length, n)`` of the
        dataset, where ``n`` is its length along axis 0. Near the end it holds
        fewer than ``length`` rows. Returns an empty list if ``target`` is not
        in the file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist); the
            file is opened read-only and is never created by this function.
    """
    # Explicit read-only mode: old h5py defaulted to 'a' (creating files),
    # h5py >= 3.0 defaults to 'r'. The with-block closes the file even on error.
    with h5py.File(fileName, 'r') as h5f:
        if target not in h5f:
            return []
        dataset = h5f[target]
        end = min(start + length, dataset.shape[0])
        # Slicing the Dataset reads only the requested rows; the removed
        # ``Dataset.value`` loaded the whole dataset into memory first.
        return dataset[start:end]
# Read the appended data back in batches of 5 items. The dataset name must
# match the one written above ('mnist_features'), otherwise every batch is [].
for i in range(10):
    d = getDataFromH5py('data.h5', 'mnist_features', i * 5, 5)  # read 5 items per batch
    print(d)
# Actually call close() (the original `h5f.colse` was a typo and was never invoked).
h5f.close()

5、运行效果

h5py批量写入文件、读取文件,支持任意维度的数据_第1张图片

6、全部代码

#-*- coding: utf-8 -*-
import h5py
import numpy as np
def save_h5(h5f, data, target):
    """Append ``data`` to dataset ``target`` in ``h5f``, creating it on first use.

    On the first call the dataset is created from ``data`` with an unlimited
    first axis (``maxshape[0] = None``) and chunked storage, so it can grow
    later. Each subsequent call resizes axis 0 and writes ``data`` after the
    rows already stored.

    Args:
        h5f: An h5py File (or Group) opened for writing.
        data: Array-like with at least one dimension. Anything accepted by
            ``np.asarray`` works (lists, tuples, ndarrays). Rows are appended
            along axis 0.
        target: Name of the dataset inside ``h5f``.

    Raises:
        ValueError: If ``data`` is 0-dimensional, or if its trailing
            dimensions differ from those of the existing dataset.
    """
    # Lists/tuples have no ``.shape``; normalize so any array-like works.
    data = np.asarray(data)
    if data.ndim == 0:
        raise ValueError("data must have at least one dimension to append along axis 0")

    if target not in h5f:
        # None makes axis 0 unlimited (resizable later); resizable datasets
        # require chunked storage, hence chunks=True.
        maxshape = (None,) + tuple(data.shape[1:])
        h5f.create_dataset(target, data=data, maxshape=maxshape, chunks=True)
        return

    dataset = h5f[target]
    if tuple(dataset.shape[1:]) != tuple(data.shape[1:]):
        raise ValueError(
            "cannot append data with trailing shape {} to dataset {!r} "
            "with trailing shape {}".format(data.shape[1:], target, dataset.shape[1:])
        )
    len_old = dataset.shape[0]
    len_new = len_old + data.shape[0]
    # Grow only axis 0; trailing dimensions stay as they are.
    dataset.resize(len_new, axis=0)
    dataset[len_old:len_new] = data
def getDataFromH5py(fileName, target, start, length):
    """Read up to ``length`` rows of dataset ``target`` starting at row ``start``.

    Args:
        fileName: Path of an existing HDF5 file.
        target: Name of the dataset inside the file.
        start: Index of the first row (along axis 0) to read.
        length: Maximum number of rows to read.

    Returns:
        A NumPy array with rows ``start`` .. ``min(start + length, n)`` of the
        dataset, where ``n`` is its length along axis 0. Near the end it holds
        fewer than ``length`` rows. Returns an empty list if ``target`` is not
        in the file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist); the
            file is opened read-only and is never created by this function.
    """
    # Explicit read-only mode: old h5py defaulted to 'a' (creating files),
    # h5py >= 3.0 defaults to 'r'. The with-block closes the file even on error.
    with h5py.File(fileName, 'r') as h5f:
        if target not in h5f:
            return []
        dataset = h5f[target]
        end = min(start + length, dataset.shape[0])
        # Slicing the Dataset reads only the requested rows; the removed
        # ``Dataset.value`` loaded the whole dataset into memory first.
        return dataset[start:end]

file_name = 'data.h5'
features = np.arange(100)

# Mode 'a' opens read/write and creates the file if missing; h5py >= 3.0
# defaults to read-only 'r', which would make the writes fail. Because of 'a',
# re-running the script keeps appending to the existing dataset.
# The with-block closes the writer before the data is read back below.
with h5py.File(file_name, 'a') as h5f:
    for _ in range(5):
        save_h5(h5f, data=np.array(features), target='mnist_features')

for i in range(10):
    d = getDataFromH5py(file_name, 'mnist_features', i * 5, 5)  # read 5 items per batch
    print(d)

 

你可能感兴趣的:(python)