【深度学习】PyTorch读取图像数据集

第一步,定义一个dataset(保存为 dataset.py,第二步会通过 from dataset import * 导入它),代码如下:

import torch.utils.data as data
import PIL.Image as Image
from glob import glob


class readImageDataset(data.Dataset):
    """Paired-image dataset that serves one split of an image folder.

    The root folder must contain ``train``, ``val`` and ``test`` sub-folders,
    each holding a ``target_img`` and an ``input_img`` folder. Targets and
    inputs are paired by position after both file lists are sorted, so the two
    folders must sort into the same order (e.g. matching file names).

    Args:
        state: Split to serve: ``'train'``, ``'val'`` or ``'test'``.
        target_transform: Optional callable applied to every target image.
        input_transform: Optional callable applied to every input image.
        root: Folder containing the three split folders. The default is the
            tutorial's placeholder string; pass a real path instead.

    Raises:
        ValueError: If ``state`` is not one of the three split names.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(self, state, target_transform=None, input_transform=None,
                 root=r"图像文件夹的所在路径"):
        super().__init__()
        self.state = state
        self.root = root
        # Declare the per-split path lists BEFORE scanning the folders so the
        # scan results are not wiped out afterwards.
        self.train_target_paths, self.val_target_paths, self.test_target_paths = None, None, None
        self.train_input_paths, self.val_input_paths, self.test_input_paths = None, None, None
        self.targets, self.inputs = self.getDataPath()
        self.target_transform = target_transform
        self.input_transform = input_transform

    def _list_images(self, split, kind):
        """Return the sorted file paths found in ``<root>/<split>/<kind>``."""
        import os  # local import: keeps this class self-contained

        # glob gives no ordering guarantee; sorting keeps targets and inputs
        # aligned index by index.
        return sorted(glob(os.path.join(self.root, split, kind, "*")))

    def getDataPath(self):
        """Scan every split folder and return the paths of the selected split.

        Returns:
            A ``(target_paths, input_paths)`` tuple for ``self.state``.

        Raises:
            ValueError: If ``self.state`` is not 'train', 'val' or 'test'.
        """
        # Validate with an exception rather than ``assert``, which is removed
        # when Python runs with -O.
        if self.state not in self._SPLITS:
            raise ValueError(
                "state must be 'train', 'val' or 'test', got {!r}".format(self.state)
            )

        self.train_target_paths = self._list_images("train", "target_img")
        self.train_input_paths = self._list_images("train", "input_img")
        self.val_target_paths = self._list_images("val", "target_img")
        self.val_input_paths = self._list_images("val", "input_img")
        self.test_target_paths = self._list_images("test", "target_img")
        self.test_input_paths = self._list_images("test", "input_img")

        paths_by_state = {
            "train": (self.train_target_paths, self.train_input_paths),
            "val": (self.val_target_paths, self.val_input_paths),
            "test": (self.test_target_paths, self.test_input_paths),
        }
        return paths_by_state[self.state]

    def __getitem__(self, index):
        """Load the pair at ``index``.

        Returns:
            ``(target, input, target_path, input_path)`` where the two images
            have been passed through their transform when one was supplied.
        """
        target_path = self.targets[index]
        input_path = self.inputs[index]
        target = Image.open(target_path)
        input_img = Image.open(input_path)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.input_transform is not None:
            input_img = self.input_transform(input_img)
        return target, input_img, target_path, input_path

    def __len__(self):
        """Return the number of input images in the selected split."""
        return len(self.inputs)

第二步,定义一个DataLoader,并调用,代码如下:

import argparse

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from dataset import *


def _str2bool(value):
    """Convert a command-line token such as 'true', 'False' or '1' to a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def getArgs(argv=None):
    """Parse the command-line options of the data-loading script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``batch_size`` (int, default 1),
        ``shuffle`` (bool, default False), ``num_workers`` (int, default 4)
        and ``dataset`` (str, default 'grayImage').
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    # Without an explicit converter, "--shuffle False" would be stored as the
    # non-empty (hence truthy) string "False". A bare "--shuffle" means True.
    parser.add_argument("--shuffle", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument('--dataset', default='grayImage', help='grayImage&colorImage')
    return parser.parse_args(argv)

def getDataset(args):
    """Build the train, validation and test data loaders.

    Relies on the module-level names ``target_transform`` and
    ``input_transform``, which must be defined before this function is called.

    Args:
        args: Parsed options providing ``dataset``, ``batch_size`` and
            ``num_workers``.

    Returns:
        A ``(train_dataloaders, val_dataloaders, test_dataloaders)`` tuple.
        Only ``args.dataset == 'grayImage'`` is handled; for any other value
        all three entries are ``None``.
    """
    train_dataloaders, val_dataloaders, test_dataloaders = None, None, None
    if args.dataset == 'grayImage':
        train_dataset = readImageDataset('train', target_transform=target_transform, input_transform=input_transform)
        train_dataloaders = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        val_dataset = readImageDataset('val', target_transform=target_transform, input_transform=input_transform)
        val_dataloaders = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        test_dataset = readImageDataset('test', target_transform=target_transform, input_transform=input_transform)
        # The test loader always yields one sample at a time, in a fixed order.
        test_dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)
    return train_dataloaders, val_dataloaders, test_dataloaders



if __name__ == "__main__":
    # Preprocessing applied to every target image.
    target_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),                   # 1 = single-channel grayscale, 3 = three-channel output
        transforms.Resize((100, 100), interpolation=Image.BICUBIC),    # interpolation: resampling method used when resizing
        transforms.ToTensor()
    ])
    # Preprocessing applied to every input image (same steps as the targets).
    input_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((100, 100), interpolation=Image.BICUBIC),
        transforms.ToTensor()
    ])

    args = getArgs()      # Parse the command-line options.
    train_dataloaders, val_dataloaders, test_dataloaders = getDataset(args)  # Build the data loaders.

你可能感兴趣的:(pytorch,深度学习,python,dataset)