第一步：定义一个 Dataset（将以下代码保存为 dataset.py，第二步会通过 `from dataset import *` 导入它），代码如下：
import torch.utils.data as data
import PIL.Image as Image
from glob import glob
class readImageDataset(data.Dataset):
    """Paired (target, input) image dataset for one split: train, val or test.

    Expected directory layout under ``root``::

        root/{train,val,test}/target_img/*
        root/{train,val,test}/input_img/*

    Targets and inputs are paired by position after sorting by path. The
    i-th file in ``target_img`` is matched with the i-th file in
    ``input_img``. Presumably corresponding files share a name; verify this
    against your data.
    """

    def __init__(self, state, target_transform=None, input_transform=None,
                 root=r"图像文件夹的所在路径"):
        """Create the dataset for one split.

        Args:
            state: which split to load: 'train', 'val' or 'test'.
            target_transform: optional callable applied to each target image.
            input_transform: optional callable applied to each input image.
            root: folder containing the train, val and test sub-folders.
                The default is a placeholder ("path of the image folder")
                that must be replaced with a real path.

        Raises:
            ValueError: if ``state`` is not 'train', 'val' or 'test'.
        """
        self.state = state
        self.root = root
        self.target_transform = target_transform
        self.input_transform = input_transform
        # Declare the per-split path lists BEFORE calling getDataPath(), which
        # fills all six in. Resetting them afterwards would discard its work.
        self.train_target_paths, self.val_target_paths, self.test_target_paths = None, None, None
        self.train_input_paths, self.val_input_paths, self.test_input_paths = None, None, None
        self.targets, self.inputs = self.getDataPath()

    def getDataPath(self):
        """Collect the image paths of every split and return those for ``self.state``.

        All six lists are stored as ``{train,val,test}_{target,input}_paths``.

        Returns:
            (target_paths, input_paths): sorted path lists for ``self.state``.

        Raises:
            ValueError: if ``self.state`` is not 'train', 'val' or 'test'.
        """
        import os  # local import keeps the module's top-level imports unchanged

        # An explicit check rather than `assert`, which is stripped under -O.
        if self.state not in ('train', 'val', 'test'):
            raise ValueError(
                "state must be 'train', 'val' or 'test', got %r" % (self.state,))

        def list_images(split, kind):
            # glob() returns files in arbitrary order. Sort them so that the
            # target and input lists line up index by index.
            # os.path.join makes the pattern portable (not Windows-only).
            return sorted(glob(os.path.join(self.root, split, kind, '*')))

        self.train_target_paths = list_images('train', 'target_img')
        self.train_input_paths = list_images('train', 'input_img')
        self.val_target_paths = list_images('val', 'target_img')
        self.val_input_paths = list_images('val', 'input_img')
        self.test_target_paths = list_images('test', 'target_img')
        self.test_input_paths = list_images('test', 'input_img')

        if self.state == 'train':
            return self.train_target_paths, self.train_input_paths
        if self.state == 'val':
            return self.val_target_paths, self.val_input_paths
        return self.test_target_paths, self.test_input_paths

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` into memory and close its file handle."""
        # Image.open is lazy and keeps the file open. copy() loads the pixel
        # data into a new in-memory image, so the with-block can close the
        # file without invalidating what we return.
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, index):
        """Return ``(target, input, target_path, input_path)`` for ``index``.

        ``target`` and ``input`` are PIL images, or whatever the configured
        transforms return.
        """
        target_path = self.targets[index]
        input_path = self.inputs[index]
        target = self._load_image(target_path)
        input_img = self._load_image(input_path)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.input_transform is not None:
            input_img = self.input_transform(input_img)
        return target, input_img, target_path, input_path

    def __len__(self):
        """Number of samples, i.e. the number of input images in the split."""
        return len(self.inputs)
第二步：定义一个 DataLoader 并调用，代码如下：
import argparse

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from dataset import *
# Command-line argument definitions.
def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def getArgs(argv=None):
    """Parse the command-line options of the data-loading script.

    Args:
        argv: list of argument strings to parse. None (the default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with ``batch_size`` (int), ``shuffle`` (bool),
        ``num_workers`` (int) and ``dataset`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    # Without a converter, "--shuffle False" would produce the non-empty,
    # and therefore truthy, string "False".
    parser.add_argument("--shuffle", type=_str2bool, default=False)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument('--dataset', default='grayImage', help='grayImage&colorImage')
    return parser.parse_args(argv)
def getDataset(args, target_transform=None, input_transform=None):
    """Build the train/val/test DataLoaders for the selected dataset.

    Args:
        args: parsed command-line namespace. Reads ``dataset``,
            ``batch_size`` and ``num_workers``.
        target_transform: transform applied to each target image. If None,
            falls back to a module-level ``target_transform`` (defined by
            the ``__main__`` block of this script), as the original code
            relied on.
        input_transform: like ``target_transform``, for input images.

    Returns:
        (train_dataloaders, val_dataloaders, test_dataloaders). Only
        ``args.dataset == 'grayImage'`` is handled; any other value yields
        (None, None, None).
    """
    # Local import so this function does not depend on the module-level
    # import spelling the name correctly (the original used `Dataloader`).
    from torch.utils.data import DataLoader

    if target_transform is None:
        target_transform = globals().get('target_transform')
    if input_transform is None:
        input_transform = globals().get('input_transform')

    train_dataloaders, val_dataloaders, test_dataloaders = None, None, None
    if args.dataset == 'grayImage':
        def build(state):
            # One dataset per split, all sharing the same transforms.
            return readImageDataset(state, target_transform=target_transform,
                                    input_transform=input_transform)

        train_dataloaders = DataLoader(build('train'), batch_size=args.batch_size,
                                       shuffle=True, num_workers=args.num_workers)
        # NOTE(review): shuffling the validation split is usually unnecessary;
        # kept to preserve the original behaviour.
        val_dataloaders = DataLoader(build('val'), batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.num_workers)
        # The test loader uses one sample per batch in a fixed order.
        test_dataloaders = DataLoader(build('test'), batch_size=1,
                                      shuffle=False, num_workers=args.num_workers)
    return train_dataloaders, val_dataloaders, test_dataloaders
if __name__ == "__main__":
    # Target pipeline: convert to grayscale (num_output_channels=1 gives one
    # channel; 3 gives a 3-channel image), resize to 100x100 using bicubic
    # interpolation, then convert to a tensor.
    # `Image` is PIL.Image, made available here by `from dataset import *`.
    target_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((100, 100), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
    ])
    # Input images go through the same pipeline as the targets.
    input_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((100, 100), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
    ])
    args = getArgs()  # parse command-line arguments
    # getDataset picks up target_transform / input_transform from module scope.
    train_dataloaders, val_dataloaders, test_dataloaders = getDataset(args)  # build the data loaders