pytorch的数据集生成方法:基于h5py

自学pytorch生成对抗网络编程一书源码标注

  • 生成文件的方法
hdf5_file = './celeba_aligned_small.h5py'

total_images = 20000

with h5py.File(hdf5_file, 'w') as hf:  # 打开h5py文件,文件不存在则会创建文件

    count = 0

    with zipfile.ZipFile('celeba/img_align_celeba.zip', 'r') as zf:
    # 这个压缩文件里是一个文件夹img_align_celeba文件夹中有200000多张图片
        for i in zf.namelist():  # zf.namelist()返还压缩文件中的文件列表名
        # zf.namelist()[0]是'img_align_celeba/' 即文件夹路径
        # zf.namelist()[1]是'img_align_celeba/000001.jpg' 即文件夹下的文件路径
            if i[-4:] == '.jpg':
                ofile = zf.extract(i)  # 解压单个文件至ofile中
                # 默认解压在当前文件夹即在'./'路径下创建img_align_celeba文件夹,把图片(i)放入文件夹中
                # ofile是解压后图片(i)的相对地址是一个字符串
                img = imageio.imread(ofile)
                # 使用imageio.imread读取图片,此时img打印出来是一个数组
                os.remove(ofile)  # 用完即弃
                # 删除图片,不占存储空间

                hf.create_dataset('img_align_celeba/'+str(count)+'.jpg',
                                  data=img, compression='gzip', compression_opts=9)
                # compression是压缩方式, compression_opts是压缩程度的参数
                # 在celeba_aligned_small.h5py文件中生成组img_align_celeba,在组中保存img数组

                count = count + 1
                if count % 1000 == 0:
                    print('images done ...', count)

                if count == total_images:  # 只取前20000张图片
                    break
  • 使用文件
with h5py.File('./celeba_aligned_small.h5py', 'r') as file_object:
    # h5py文件只读打开创建文件对象file_object,此对象为可遍历对象
    for group in file_object:  # 对文件对象进行遍历得到组名称的字符串
        print(group)  # 'img_align_celeba'

# 从群组中导出数据集,并以索引的形式展示
with h5py.File('./celeba_aligned_small.h5py', 'r') as file_object:
    dataset = file_object['img_align_celeba']  # dataset即为h5py文件中的一个组
    # 对文件对象用组名称进行索引得到组对象
    image = np.array(dataset['8.jpg'])  # 在组中以文件名索引具体文件
    # 组对象用组中的文件名进行索引得到其中的数组,使用numpy转化为numpy数组
    print(image.shape)
    plt.imshow(image, interpolation='none')
    plt.show()

# 这就是从h5py中取出numpy图片张量的方法

  • 利用h5py生成自定义Dataset类
from torch.utils.data import Dataset

class CelebADataset(Dataset):  # 继承pytorch的Dataset类

    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')  # 文件对象
        self.dataset = self.file_object['img_align_celeba']  # 组对象

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):  # 从组对象中提取图片输出归一化的torch张量,图片格式为(h,w,3)
        if item >= len(self.dataset):  # 索引值大于等于长度报错
            raise IndexError()
        img = np.array(self.dataset[str(item)+'.jpg'])
        return torch.cuda.FloatTensor(img) / 255.
# 至此可以实现单张图片处理
dataset = CelebADataset(hdf5_file)  # 实例化
print(dataset[0])  # 使用索引直接调用__getitem__函数获得torch张量(h, w, 3)
# 进行批处理
train_iter = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32)
for i in train_iter:
    # 小bug如果声明torch.set_default_tensor_type()
    # 这里会报错
    print(i.shape)  # torch.Size([32, 218, 178, 3])
    break
  • 使用h5py的优势
    图片不会全部加载到内存当中,节约内存开支,在训练时对数据即取即用,相较于普通的在文件夹中存储图片数据的方式,将图片文件压缩进入h5py文件中,即取即用的效率更高

你可能感兴趣的:(h5py,pytorch,人工智能)