Pytorch-DataLoader笔记

1. DataLoader

用于构建可迭代的数据加载器,训练时每一个iteration就是从DATa Loader中获取一个batch_size大小的数据。

参数

  • dataset: Dataset类。是要自定义编写的,继承自torch.utils.data.Dataset的类。
  • batchsize
  • num_works
  • shuffle
  • drop_last:当样本数不是batchsize的整数倍时,是否舍弃最后一组数据

2. Dataset

所有自定义的Dataset都要继承于torch.utils.data.Dataset,并且必须复写__getitem__()方法。

ImageFolder

ImageFolder假设所有的文件按文件夹保存,每个文件夹下存放同一类数据,文件夹名为类名。

参数
  • root:图片路径
  • transform:对PIL Image进行的转换参数
  • target_transform:对标签进行转换
  • loader:如何读取图片,默认读取为RGB格式的 PIL Image对象

3. 构造数据集示例

  1. 文件夹格式
train_path = r'dataset/train'
  1. 预处理
train_transform = transforms.Compose([
	transforms.Resize((64, 64)),   # 数据格式转换
    transforms.RandomCrop(40,padding=4),  # 随机裁剪
    transforms.ToTensor(),  # 声明为张量,便于pytorch计算
    transforms.Normalize([0.485,0.456,0.406],   
                         [0.229,0.224,0.225],)   # 对数据按通道进行标准化
])

transforms的各种方法大全:https://zhuanlan.zhihu.com/p/53367135

  1. 自定义Dataset
class myData(Data.Dataset): # 继承自抽象类
	def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.data_info = self.get_img_info(path)   # 将path文件夹中的数据以元组的列表形式保存
        self.label = []  
        for i in range(len(self.data_info)):
            self.label.append(list(self.data_info[i])[1])   # 保存标签,与data_info对齐

    def __getitem__(self, idx):
    	# 复写,根据idx返回 图像数据(转为张量), 标签, 索引
        path_img = self.data_info[idx][0]
        label = self.label[idx]
        img = Image.open(path_img).convert('RGB')  # 0~255
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor等等
        return img, label, idx

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    # 声明为静态方法, 即可以实例化调用也可以不实例化直接调用函数。
    def get_img_info(data_dir): 
    	'''
		:return :返回一个列表,列表的每个元素是一个元组(图片地址, 图片标签)
		'''
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = int(sub_dir)
                    data_info.append((path_img, int(label)))
        return data_info
  1. 创建DataLoader
# 数据集
trainset = myData(
	train_path,
	train_transform
)
# 数据发生器
train_loader = Data.DataLoader(
	dataset=trainset, 
	batch_size=4,
	shuffle = True
)

你可能感兴趣的:(pytorch,深度学习,python)