Pytorch自定义Dataset和使用DataLoader装载数据

本文以lfw人脸数据库为例,用Pytorch自定义Dataset和使用DataLoader装载人脸图像。Dataset主要功能是读取数据源,而DataLoader在Dataset基础上组织数据供给深度算法使用,比如对图像的分批、shuaffle、扩展样本等操作。本文用的图片放在facesBmp目录下面,如下图所示:

Pytorch自定义Dataset和使用DataLoader装载数据_第1张图片

下面是实现代码。代码比较简单,可以看里面注释

# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os,sys
from PIL import Image
import matplotlib.pyplot as plt

class  LfwDataset(Dataset):
    def __init__(self, image_dir, resize_height=64, resize_width=64):
        '''
        :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
        :param resize_height 为图像高,
        :param resize_width  为图像宽        
        '''
        # 所有图片的绝对路径
        imgs=os.listdir(image_dir)
        self.imgs=[os.path.join(image_dir,k) for k in imgs]
      # 相关预处理的初始化
      #  self.transforms=transform
        self.transforms=True
        self.transform= transforms.Compose([
        transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
         transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 标准化至[-1,1]
])
 
    def __getitem__(self, i):
        img_path = self.imgs[i]
        pil_img = Image.open(img_path)
        if self.transforms:
            data =self.transform(pil_img)
        else:
            pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)
        return data
 
    def __len__(self):
        return len(self.imgs)

   

 
if __name__=='__main__':
    image_dir="../facesBmp" #该文件夹下面直接是图像,与原始文件不一样,原始文件是有人名的二级目录
 
    epoch_num=1   #总样本循环次数 反对
    batch_size=1000  #训练时的一组数据的大小
    train_data_nums=13233
    max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数 
    train_data = LfwDataset(image_dir=image_dir)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
 
    # [1]使用epoch方法迭代,LfwDataset的参数repeat=1
    for epoch in range(epoch_num):
        for batch_image in train_loader:
            image=batch_image[0,:]
            image=image.numpy()#       
            # plt.imshow(image)
            # plt.show()
            print("batch_image.shape:{}".format(batch_image.shape))    
 
    '''
    以两种方式实现训练集迭代,退出循环由max_iterate设定
    '''
    train_data = LfwDataset(image_dir=image_dir)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # [2]第2种迭代方法
    print('第2种迭代方法')
    print(enumerate(train_loader))
    for step, batch_image in enumerate(train_loader):
        image=batch_image[0,:]
        image=image.numpy()#image=np.array(image)    
        # plt.imshow(image)
        # plt.show()
        print("step:{},batch_image.shape:{}".format(step,batch_image.shape))
        if step>=max_iterate:
            break

 代码中使用了两种方法迭代训练数据集,输出结果如下图所示:

Pytorch自定义Dataset和使用DataLoader装载数据_第2张图片

 

 

 

你可能感兴趣的:(深度技术学习)