(1) Bringing Old Photos Back to Life: Principles and Testing
(2)
Bringing Old Photos Back to Life Code Analysis 1 (Data Loading)
Bringing Old Photos Back to Life Code Analysis 2 (Model)
(3) Bringing Old Photos Back to Life: Dataset and Training
This part covers the data preprocessing code.
The files live under Global/data and are listed below.
base_dataset.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
#
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
#Get the cropping parameters
# This function computes a suitable input size by resizing and/or cropping according to the mode the user specifies.
# size: size of the input image
def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        # opt.loadSize is the user-specified size; the image is scaled to this size
        new_h = new_w = opt.loadSize # width and height are set to the same value
    if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256
        if w < h:
            new_w = opt.loadSize
            new_h = opt.loadSize * h // w
        else:
            new_h = opt.loadSize
            new_w = opt.loadSize * w // h
    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    flip = random.random() > 0.5 # flip is a bool: whether a random number exceeds 0.5, i.e. True or False at random
    return {'crop_pos': (x, y), 'flip': flip} # final return value; it is passed as params into get_transform() below (e.g. data.aligned_dataset, line 45)
# Image transforms
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
    #Resize
    if 'resize' in opt.resize_or_crop: # if 'resize' appears in opt.resize_or_crop
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
#
elif 'scale_width' in opt.resize_or_crop:
# transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it.
        # resize the input PIL.Image so that the shorter side matches the given size, i.e. 256
        transform_list.append(transforms.Scale(256,method))
    #Crop
if 'crop' in opt.resize_or_crop:
if opt.isTrain:
            # wrap __crop with transforms.Lambda so it can be used as a transform
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
else:
if opt.test_random_crop:
transform_list.append(transforms.RandomCrop(opt.fineSize))
else:
transform_list.append(transforms.CenterCrop(opt.fineSize))
## when testing, for ablation study, choose center_crop directly.
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
        # mean and std are both 0.5
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
#Normalize to (-1, 1)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#Resize the image so that both sides are integer multiples of base
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
#Scale the image to the target width, keeping the aspect ratio
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
#Crop the image at the given (randomly shifted) position
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
    tw = th = size # the target size, opt.fineSize
    if (ow > tw or oh > th):
        # Image.crop((left, upper, right, lower)): left/right are distances from the left edge, upper/lower are distances from the top edge
        return img.crop((x1, y1, x1 + tw, y1 + th)) # random crop: the crop size is the same every time, but the starting position differs
return img
#Horizontal flip
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
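To make get_params() and get_transform() concrete, here is a minimal usage sketch (not part of the repository). The opt object is a stand-in carrying only the fields used above, 'example.jpg' is an illustrative path, and, like the code above, it assumes an older torchvision in which transforms.Scale still exists.
from types import SimpleNamespace
from PIL import Image

opt = SimpleNamespace(resize_or_crop='resize_and_crop', loadSize=286, fineSize=256,
                      isTrain=True, no_flip=False)
img = Image.open('example.jpg').convert('RGB')   # illustrative input image
params = get_params(opt, img.size)               # random crop position + random flip flag
transform = get_transform(opt, params)           # Scale -> crop -> flip -> ToTensor -> Normalize
tensor = transform(img)                          # 3 x 256 x 256 tensor with values in [-1, 1]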
Create_Bigfile.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import struct
from PIL import Image
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
# Check whether a file name ends with one of the image extensions above
def is_image_file(filename):
    #any() returns True if at least one element of the iterable is truthy
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
#Collect the image paths into a list and return it
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
    # os.walk(top[, topdown=True[, onerror=None[, followlinks=False]]]) walks the directory tree and yields, for each directory, a triple (root, dirs, files):
    # root is the path of the directory currently being visited; dirs is the list of sub-directory names in it; files is the list of file names in it (sub-directories excluded)
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
#print(fname)
                #join the full image path and append it to the images list
path = os.path.join(root, fname)
images.append(path)
return images
### Modify these 3 lines in your own environment
#The following three variables need to be modified:
#Variable 1: parent directory that holds the training dataset folders
indir="/home/ziyuwan/workspace/data/temp_old"
#Variable 2: the dataset folders to be packed; there are three target folders: the VOC dataset (used to synthesize fake old photos), real grayscale old photos, and real color old photos
target_folders=['VOC','Real_L_old','Real_RGB_old']
#Variable 3: output folder for the generated bigfiles
out_dir ="/home/ziyuwan/workspace/data/temp_old"
###
if os.path.exists(out_dir) is False:
os.makedirs(out_dir)
#Iterate over the dataset folders
for target_folder in target_folders:
    #build the path of the current dataset folder
curr_indir = os.path.join(indir, target_folder)
    #path of the generated bigfile (including the file name)
curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile' % (target_folder)))
image_lists = make_dataset(curr_indir)
image_lists.sort()
with open(curr_out_file, 'wb') as wfid:
# write total image number
wfid.write(struct.pack('i', len(image_lists)))
for i, img_path in enumerate(image_lists):
# write file name first
img_name = os.path.basename(img_path)
img_name_bytes = img_name.encode('utf-8')
wfid.write(struct.pack('i', len(img_name_bytes)))
wfid.write(img_name_bytes)
#
            # write image data in
with open(img_path, 'rb') as img_fid:
img_bytes = img_fid.read()
wfid.write(struct.pack('i', len(img_bytes)))
wfid.write(img_bytes)
if i % 1000 == 0:
print('write %d images done' % i)
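The resulting .bigfile is a flat binary stream: a 4-byte image count, then for every image a 4-byte name length, the UTF-8 file name, a 4-byte payload length and the raw encoded image bytes. A minimal sketch for sanity-checking the header of a generated file (the path simply reuses the out_dir above and is illustrative):
import struct

with open('/home/ziyuwan/workspace/data/temp_old/VOC.bigfile', 'rb') as fid:
    total = struct.unpack('i', fid.read(4))[0]        # total number of packed images
    name_len = struct.unpack('i', fid.read(4))[0]     # length of the first file name
    first_name = fid.read(name_len).decode('utf-8')   # first file name
    data_len = struct.unpack('i', fid.read(4))[0]     # byte length of the first image payload
    print(total, first_name, data_len)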
custom_dataset_data_loader.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile
#Return the dataset that matches the stage being trained
def CreateDataset(opt):
dataset = None
    # When training domain A or domain B (the two VAEs), use the unpaired dataset
if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':
dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    # When training the mapping network, load the paired dataset
if opt.training_dataset=='mapping':
if opt.random_hole:
dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
else:
dataset = dts_ray_bigfile.PairOldPhotos()
print("dataset [%s] was created" % (dataset.name())) # 打印数据集名字为‘
dataset.initialize(opt) # 初始化数据集参数
return dataset # 返回创建好的数据集
##创建数据载入器# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
        BaseDataLoader.initialize(self, opt) # initialize options
        #create the dataset
        self.dataset = CreateDataset(opt)
        #create the DataLoader
        self.dataloader = torch.utils.data.DataLoader( # wrap the created dataset with the given loader parameters
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads),
drop_last=True)
def load_data(self):
        return self.dataloader # return the DataLoader
    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size) # the smaller of the dataset length and the maximum number of samples allowed per epoch
data_loader.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#Create the data loader
##########################################################################
# Main entry point for creating the dataset loader
##########################################################################
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
    print(data_loader.name()) # prints "CustomDatasetDataLoader"
    data_loader.initialize(opt) # initialize options
return data_loader
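A minimal sketch of how CreateDataLoader is typically consumed in the training scripts (opt is assumed to be the parsed training options, e.g. with training_dataset='domain_A'; this is an illustration, not the scripts' exact code):
from data.data_loader import CreateDataLoader

data_loader = CreateDataLoader(opt)   # builds CustomDatasetDataLoader and the underlying dataset
dataset = data_loader.load_data()     # the wrapped torch DataLoader
print('#training images = %d' % len(data_loader))
for data in dataset:
    label, image = data['label'], data['image']   # batched tensors produced by the datasets below
    break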
image_folder.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]
def is_image_file(filename):
    ### any() returns False if every element of the given iterable is falsy, and True as soon as one element is truthy.
    # Elements other than 0, empty values and False all count as True.
    # The function is equivalent to:
    # def any(iterable):
    #     for element in iterable:
    #         if element:
    #             return True
    #     return False
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
# Build the dataset: get the list of image paths
def make_dataset(dir): # dir is the dataset folder path
    images = [] # empty list for the image paths
    assert os.path.isdir(dir), '%s is not a valid directory' % dir # make sure the path is a directory
    ### os.walk() is a simple directory-tree walker that makes it easy to process files and directories.
    # top -- the directory to traverse; for each directory it yields a triple (root, dirs, files).
    # root is the path of the directory currently being visited (the first one equals the dir passed to os.walk(dir))
    # dirs is a list of the names of the sub-directories in that directory (empty if there are none)
    # files is a list of the names of the files in that directory (sub-directories excluded; empty if there are none)
    for root, _, fnames in sorted(os.walk(dir)): # fnames are the image file names found in each folder
for fname in fnames:
if is_image_file(fname):
                path = os.path.join(root, fname) # join the folder path and the image file name
                images.append(path) # append the image path to the images list
    return images # return the list of image paths
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root) # imgs为root目录下图片路径列表
        if len(imgs) == 0: # raise an error if no image is found
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
        path = self.imgs[index] # path of the requested image
        img = self.loader(path) # load the image
        if self.transform is not None:
            img = self.transform(img) # apply the transform to the image
        if self.return_paths:
            return img, path # return both the image and its path
        else:
            return img # return only the image
    def __len__(self):
        return len(self.imgs) # number of images under the given directory
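A minimal usage sketch of ImageFolder (the directory path and the transform are illustrative):
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

dataset = ImageFolder(root='/path/to/images',          # any folder that contains images
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]                                 # first image tensor and its file path
print(len(dataset), path, img.shape)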
Load_Bigfile.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import struct
from PIL import Image
#Load the packed bigfile into memory
class BigFileMemoryLoader(object):
def __load_bigfile(self):
print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
with open(self.file_path, 'rb') as fid:
self.img_num = struct.unpack('i', fid.read(4))[0]
self.img_names = []
self.img_bytes = []
print('find total %d images' % self.img_num)
for i in range(self.img_num):
img_name_len = struct.unpack('i', fid.read(4))[0]
img_name = fid.read(img_name_len).decode('utf-8')
self.img_names.append(img_name)
img_bytes_len = struct.unpack('i', fid.read(4))[0]
self.img_bytes.append(fid.read(img_bytes_len))
if i % 5000 == 0:
print('load %d images done' % i)
print('load all %d images done' % self.img_num)
    #Initialization
def __init__(self, file_path):
super(BigFileMemoryLoader, self).__init__()
self.file_path = file_path
self.__load_bigfile()
    #Return the image name and the decoded PIL image
def __getitem__(self, index):
try:
img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
return self.img_names[index], img
except Exception:
print('Image read error for index %d: %s' % (index, self.img_names[index]))
return self.__getitem__((index+1)%self.img_num)
    #Number of images
def __len__(self):
return self.img_num
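A minimal sketch of loading a packed file with BigFileMemoryLoader (the path is the VOC bigfile produced by Create_Bigfile.py above and is illustrative):
from data.Load_Bigfile import BigFileMemoryLoader

loader = BigFileMemoryLoader('/home/ziyuwan/workspace/data/temp_old/VOC.bigfile')
name, img = loader[0]                 # file name and decoded PIL image of the first record
print(len(loader), name, img.size)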
online_dataset_for_old_photos.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from data.Load_Bigfile import BigFileMemoryLoader
import random
import cv2
from io import BytesIO
#PIL image to numpy array
def pil_to_np(img_PIL):
'''Converts image in PIL format to np.array.
From W x H x C [0...255] to C x W x H [0..1]
'''
ar = np.array(img_PIL)
if len(ar.shape) == 3:
ar = ar.transpose(2, 0, 1)
else:
ar = ar[None, ...]
return ar.astype(np.float32) / 255.
#numpy array to PIL image
def np_to_pil(img_np):
'''Converts image in np.array format to PIL image.
From C x W x H [0..1] to W x H x C [0...255]
'''
ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
if img_np.shape[0] == 1:
ar = ar[0]
else:
ar = ar.transpose(1, 2, 0)
return Image.fromarray(ar)
##
#The functions below synthesize noisy/degraded images
##
def synthesize_salt_pepper(image,amount,salt_vs_pepper):
## Give PIL, return the noisy PIL
img_pil=pil_to_np(image)
out = img_pil.copy()
p = amount
q = salt_vs_pepper
flipped = np.random.choice([True, False], size=img_pil.shape,
p=[p, 1 - p])
salted = np.random.choice([True, False], size=img_pil.shape,
p=[q, 1 - q])
peppered = ~salted
out[flipped & salted] = 1
out[flipped & peppered] = 0.
noisy = np.clip(out, 0, 1).astype(np.float32)
return np_to_pil(noisy)
def synthesize_gaussian(image,std_l,std_r):
## Give PIL, return the noisy PIL
img_pil=pil_to_np(image)
mean=0
std=random.uniform(std_l/255.,std_r/255.)
gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
noisy=img_pil+gauss
noisy=np.clip(noisy,0,1).astype(np.float32)
return np_to_pil(noisy)
def synthesize_speckle(image,std_l,std_r):
## Give PIL, return the noisy PIL
img_pil=pil_to_np(image)
mean=0
std=random.uniform(std_l/255.,std_r/255.)
gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
noisy=img_pil+gauss*img_pil
noisy=np.clip(noisy,0,1).astype(np.float32)
return np_to_pil(noisy)
#Randomly shrink the image and resize it back (low-resolution simulation)
def synthesize_low_resolution(img):
w,h=img.size
new_w=random.randint(int(w/2),w)
new_h=random.randint(int(h/2),h)
img=img.resize((new_w,new_h),Image.BICUBIC)
if random.uniform(0,1)<0.5:
img=img.resize((w,h),Image.NEAREST)
else:
img = img.resize((w, h), Image.BILINEAR)
return img
#Re-encode the image as JPEG with the given quality (compression artifacts)
def convertToJpeg(im,quality):
    #read/write bytes in memory
with BytesIO() as f:
im.save(f, format='JPEG',quality=quality)
f.seek(0)
        #read the image back with Image.open and convert it to RGB, dropping any alpha channel
return Image.open(f).convert('RGB')
#Apply Gaussian blur with a random kernel size and standard deviation
def blur_image_v2(img):
x=np.array(img)
kernel_size_candidate=[(3,3),(5,5),(7,7)]
kernel_size=random.sample(kernel_size_candidate,1)[0]
std=random.uniform(1.,5.)
#print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
blur=cv2.GaussianBlur(x,kernel_size,std)
return Image.fromarray(blur.astype(np.uint8))
#Randomly combine the degradation functions above to produce a degraded image
def online_add_degradation_v2(img):
task_id=np.random.permutation(4)
for x in task_id:
if x==0 and random.uniform(0,1)<0.7:
img = blur_image_v2(img)
if x==1 and random.uniform(0,1)<0.7:
flag = random.choice([1, 2, 3])
if flag == 1:
img = synthesize_gaussian(img, 5, 50)
if flag == 2:
img = synthesize_speckle(img, 5, 50)
if flag == 3:
img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
if x==2 and random.uniform(0,1)<0.7:
img=synthesize_low_resolution(img)
if x==3 and random.uniform(0,1)<0.7:
img=convertToJpeg(img,random.randint(40,100))
return img
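A minimal sketch of running the full synthetic-aging pipeline on a single clean image (the file names are illustrative):
from PIL import Image

clean = Image.open('clean_example.jpg').convert('RGB')   # illustrative clean input
degraded = online_add_degradation_v2(clean)              # blur / noise / low-res / JPEG, each with probability 0.7
degraded.save('degraded_example.jpg')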
#Synthesize a scratched image from a mask
#For some complex scratches the original paper's results are poor; improvements can be made here instead of simply pasting the mask.
def irregular_hole_synthesize(img,mask):
img_np=np.array(img).astype('uint8')
mask_np=np.array(mask).astype('uint8')
mask_np=mask_np/255
img_new=img_np*(1-mask_np)+mask_np*255
hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")
    #"L" means a single-channel grayscale image
return hole_img,mask.convert("L")
#Generate an all-black three-channel mask
def zero_mask(size):
x=np.zeros((size,size,3)).astype('uint8')
mask=Image.fromarray(x).convert("RGB")
return mask
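A minimal sketch of synthesizing a scratched photo from a mask (mask_example.png is an illustrative scratch mask whose white pixels mark the scratches; it must have the same size as the image):
from PIL import Image

img = Image.open('clean_example.jpg').convert('RGB')
mask = Image.open('mask_example.png').convert('RGB').resize(img.size, Image.NEAREST)
hole_img, mask_L = irregular_hole_synthesize(img, mask)   # scratched RGB image + single-channel mask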
#Unpaired old-photo loader (synthetic old photos and real old photos; they do not need to correspond. The synthetic old photos are generated from the VOC dataset.)
class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old
def initialize(self, opt):
self.opt = opt
self.isImage = 'domainA' in opt.name
self.task = 'old_photo_restoration_training_vae'
self.dir_AB = opt.dataroot
        # Load VOC as well as the real grayscale and real color old photos
if self.isImage:
self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)
self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
else:
# self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
####
print("-------------Filter the imgs whose size <256 in VOC-------------")
self.filtered_imgs_clean=[]
        # filter out the VOC images smaller than 256
for i in range(len(self.loaded_imgs_clean)):
img_name,img=self.loaded_imgs_clean[i]
h,w=img.size
if h<256 or w<256:
continue
self.filtered_imgs_clean.append((img_name,img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
## Filter these images whose size is less than 256
# self.img_list=os.listdir(load_img_dir)
self.pid = os.getpid()
def __getitem__(self, index):
is_real_old=0
sampled_dataset=None
degradation=None
        #Randomly pick an image (from the synthetic old photos or the real old photos)
if self.isImage:
P=random.uniform(0,2)
if P>=0 and P<1:
if random.uniform(0,1)<0.5:
sampled_dataset=self.loaded_imgs_L_old
self.load_img_dir=self.load_img_dir_L_old
else:
sampled_dataset=self.loaded_imgs_RGB_old
self.load_img_dir=self.load_img_dir_RGB_old
is_real_old=1
if P>=1 and P<2:
sampled_dataset=self.filtered_imgs_clean
self.load_img_dir=self.load_img_dir_clean
degradation=1
else:
            #load the filtered clean images (those with size >= 256)
sampled_dataset=self.filtered_imgs_clean
self.load_img_dir=self.load_img_dir_clean
sampled_dataset_len=len(sampled_dataset)
index=random.randint(0,sampled_dataset_len-1)
img_name,img = sampled_dataset[index]
if degradation is not None:
            #apply degradation to artificially age the image
img=online_add_degradation_v2(img)
path=os.path.join(self.load_img_dir,img_name)
# AB = Image.open(path).convert('RGB')
# split AB image into A and B
# apply the same transform to both A and B
        #with a small probability convert the image to grayscale
if random.uniform(0,1) <0.1:
img=img.convert("L")
img=img.convert("RGB")
## Give a probability P, we convert the RGB image into L
        #resize if needed
A=img
w,h=A.size
if w<256 or h<256:
A=transforms.Scale(256,Image.BICUBIC)(A)
        # crop the image to 256*256; old photos smaller than 256 are resized first
## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
        #pack everything into a dict
input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number
def name(self):
return 'UnPairOldPhotos_SR'
#Paired loader (clean image and its synthesized old/degraded counterpart)
class PairOldPhotos(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.isImage = 'imagegan' in opt.name
self.task = 'old_photo_restoration_training_mapping'
self.dir_AB = opt.dataroot
        #Training mode: load VOC
if opt.isTrain:
self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
print("-------------Filter the imgs whose size <256 in VOC-------------")
            #filter out the VOC images smaller than 256
self.filtered_imgs_clean = []
for i in range(len(self.loaded_imgs_clean)):
img_name, img = self.loaded_imgs_clean[i]
h, w = img.size
if h < 256 or w < 256:
continue
self.filtered_imgs_clean.append((img_name, img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        #Test mode: only load the test set
else:
self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
self.pid = os.getpid()
def __getitem__(self, index):
        #training mode
if self.opt.isTrain:
            #(B is the clean VOC image)
img_name_clean,B = self.filtered_imgs_clean[index]
path = os.path.join(self.load_img_dir_clean, img_name_clean)
            #build the pair (B is the clean VOC image, A is the corresponding degraded image)
if self.opt.use_v2_degradation:
A=online_add_degradation_v2(B)
### Remind: A is the input and B is corresponding GT
else:
            #test mode
            #(when testing on synthetic data, B is the clean image and A is its degraded version)
if self.opt.test_on_synthetic:
img_name_B,B=self.loaded_imgs[index]
A=online_add_degradation_v2(B)
img_name_A=img_name_B
path = os.path.join(self.load_img_dir, img_name_A)
else:
img_name_A,A=self.loaded_imgs[index]
img_name_B,B=self.loaded_imgs[index]
path = os.path.join(self.load_img_dir, img_name_A)
        #with a small probability convert both A and B to grayscale during training
if random.uniform(0,1)<0.1 and self.opt.isTrain:
A=A.convert("L")
B=B.convert("L")
A=A.convert("RGB")
B=B.convert("RGB")
## In P, we convert the RGB into L
##test on L
# split AB image into A and B
# w, h = img.size
# w2 = int(w / 2)
# A = img.crop((0, 0, w2, h))
# B = img.crop((w2, 0, w, h))
w,h=A.size
if w<256 or h<256:
A=transforms.Scale(256,Image.BICUBIC)(A)
B=transforms.Scale(256, Image.BICUBIC)(B)
# apply the same transform to both A and B
        #get the transform parameters
transform_params = get_params(self.opt, A.size)
        #build the transforms (data augmentation)
A_transform = get_transform(self.opt, transform_params)
B_transform = get_transform(self.opt, transform_params)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
B_tensor = B_transform(B)
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
if self.opt.isTrain:
return len(self.filtered_imgs_clean)
else:
return len(self.loaded_imgs)
def name(self):
return 'PairOldPhotos'
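A minimal sketch of what PairOldPhotos yields during training (opt is assumed to be the parsed training options with isTrain=True and use_v2_degradation enabled, and VOC_RGB_JPEGImages.bigfile must already exist under opt.dataroot):
dataset = PairOldPhotos()
dataset.initialize(opt)        # loads and filters VOC_RGB_JPEGImages.bigfile
sample = dataset[0]
# sample['label']: degraded input A, a 3 x fineSize x fineSize tensor in [-1, 1]
# sample['image']: clean ground truth B with the same crop/flip applied
# sample['path'] : path of the underlying VOC image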
#Paired loader with synthesized scratches/holes
class PairOldPhotos_with_hole(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.isImage = 'imagegan' in opt.name
self.task = 'old_photo_restoration_training_mapping'
self.dir_AB = opt.dataroot
        #Training mode: load the clean VOC images used to synthesize scratched pairs
if opt.isTrain:
self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
print("-------------Filter the imgs whose size <256 in VOC-------------")
            #filter out the images smaller than 256
self.filtered_imgs_clean = []
for i in range(len(self.loaded_imgs_clean)):
img_name, img = self.loaded_imgs_clean[i]
h, w = img.size
if h < 256 or w < 256:
continue
self.filtered_imgs_clean.append((img_name, img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
else:
self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
        #load the irregular scratch masks
self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)
self.pid = os.getpid()
def __getitem__(self, index):
if self.opt.isTrain:
img_name_clean,B = self.filtered_imgs_clean[index]
path = os.path.join(self.load_img_dir_clean, img_name_clean)
B=transforms.RandomCrop(256)(B)
A=online_add_degradation_v2(B)
### Remind: A is the input and B is corresponding GT
else:
img_name_A,A=self.loaded_imgs[index]
img_name_B,B=self.loaded_imgs[index]
path = os.path.join(self.load_img_dir, img_name_A)
#A=A.resize((256,256))
A=transforms.CenterCrop(256)(A)
B=A
if random.uniform(0,1)<0.1 and self.opt.isTrain:
A=A.convert("L")
B=B.convert("L")
A=A.convert("RGB")
B=B.convert("RGB")
## In P, we convert the RGB into L
if self.opt.isTrain:
            #randomly sample a mask
mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
else:
            # pick a mask deterministically by index
mask_name, mask = self.loaded_masks[index%100]
        #resize the mask
mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)
if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
mask=zero_mask(256)
if self.opt.no_hole:
mask=zero_mask(256)
        #synthesize the scratched image from the mask
A,_=irregular_hole_synthesize(A,mask)
if not self.opt.isTrain and self.opt.hole_image_no_mask:
mask=zero_mask(256)
        #get the crop/flip transform parameters
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params)
B_transform = get_transform(self.opt, transform_params)
        #apply the same horizontal flip to the mask
if transform_params['flip'] and self.opt.isTrain:
mask=mask.transpose(Image.FLIP_LEFT_RIGHT)
        #convert the mask to a tensor in [0, 1]
mask_tensor = transforms.ToTensor()(mask)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
B_tensor = B_transform(B)
input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
if self.opt.isTrain:
return len(self.filtered_imgs_clean)
else:
return len(self.loaded_imgs)
def name(self):
return 'PairOldPhotos_with_hole'