http://arxiv.org/abs/1611.07004
A CVPR 2017 paper, and a true classic among image-to-image translation models.
pix2pix is built on the Conditional GAN (CGAN). Where a typical CGAN conditions on a plain label (a one-hot class vector), pix2pix uses an image as the condition.
The CGAN idea, in short: both G and D receive the conditioning input alongside their usual inputs.
Using an image as the condition places far higher demands on the model.
pix2pix is also the foundation of DualGAN (covered earlier) and of CycleGAN (not covered yet).
Unlike those later models, which loosen the requirements and work without paired data, pix2pix does constrain the dataset: training data must come in pairs.
Main contributions of pix2pix:
PatchGAN: instead of outputting a single scalar, D outputs a Patch × Patch matrix, one score per image patch. The loss is then the distance (L2/MSE is common here) between this matrix and an all-ones matrix for real data, or an all-zeros matrix for fake data.
The PatchGAN discriminator is there to capture high-frequency information; low-frequency information is guaranteed by an L1 norm.
Use the L1 norm rather than the L2 norm: this refers to the extra loss added to G that measures the distance between the generated image and the real image. That distance is computed with the L1 norm instead of the usual L2, precisely to capture the low-frequency content (L1 produces far less blurring).
Feed G no noise vector z; add Dropout instead: another idea DualGAN later adopted. Experiments show this works better.
G uses a U-Net rather than a plain Encoder-Decoder: DualGAN's generator design is taken straight from this. The encoder's feature maps are concatenated onto the symmetric decoder layers, so low-level information is not washed out on the way through the network and the output keeps more of the input image's original features (much like style transfer, where the content of the source image has to be preserved). A sketch of the resulting objective follows this list.
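Putting these pieces together, the full training objective from the paper (x is the input image, y the target image, z the noise realized as dropout):

G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\, \lVert y - G(x, z) \rVert_1 \,\big]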
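To make the PatchGAN loss concrete, here is a minimal, self-contained sketch. The 12 × 12 patch map and the tensor names are illustrative only (they happen to match the experiment code below); this is not the paper's reference implementation:

import torch
import torch.nn as nn

batch = 4
# Pretend D has already mapped a (generated image, condition) pair to a
# 12 x 12 PatchGAN score map: one realness score per image patch.
d_out = torch.rand(batch, 1, 12, 12)

real = torch.ones(batch, 1, 12, 12)   # target map for real pairs
fake = torch.zeros(batch, 1, 12, 12)  # target map for generated pairs

mse = nn.MSELoss()  # distance between score map and label map (LSGAN-style)
l1 = nn.L1Loss()    # low-frequency reconstruction term

g_x = torch.rand(batch, 1, 96, 96)    # pretend generator output
x = torch.rand(batch, 1, 96, 96)      # ground-truth target image

lam = 10  # weight of the L1 term (the paper uses 100)
d_loss = mse(d_out, fake)                     # D pushes generated patches toward 0
g_loss = mse(d_out, real) + lam * l1(g_x, x)  # G pushes every patch toward 1, plus L1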
Related reading
CGAN: model theory and a Python implementation
DualGAN: model theory and a Python implementation
Experiments
The experiment code is mostly my DualGAN code with small changes, although historically it was DualGAN that borrowed from pix2pix.
The first column is the sketch generated by the model;
the second column is the ground-truth sketch from the dataset;
the third column is the input from the dataset (a real photo).
dataloader.py
import torch.utils.data as data
import glob
import os
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import piexif
import imghdr
import numbers


class MyCrop(object):
    """Crops the given PIL Image at a fixed (i, j) offset.

    Args:
        i, j (int): Top-left corner of the crop.
        size (sequence or int): Desired output size of the crop. If size is
            an int instead of a sequence like (h, w), a square crop
            (size, size) is made.
    """

    def __init__(self, i, j, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.i, self.j = i, j

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if not isinstance(img, Image.Image):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
        th, tw = self.size
        return img.crop((self.j, self.i, self.j + tw, self.i + th))

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class MyDataset(data.Dataset):
    def __init__(self, path_sketch, path_photo, Train=True, Len=-1, resize=-1,
                 img_type='png', remove_exif=False, default=False):
        self.Train = Train
        self.sketch_dataset = self.init_dataset(path_sketch, Len=Len, resize=resize,
                                                img_type=img_type, remove_exif=remove_exif,
                                                sketch=True, default=default)
        self.photo_dataset = self.init_dataset(path_photo, Len=Len, resize=resize,
                                               img_type=img_type, remove_exif=remove_exif,
                                               sketch=False, default=default)

    def init_dataset(self, path, Len=-1, resize=-1, img_type='png',
                     remove_exif=False, sketch=True, default=False):
        # images are loaded as grayscale ("L") below, so normalize a single channel
        if resize != -1:
            if default:
                transform = transforms.Compose([
                    transforms.Resize(resize),
                    transforms.CenterCrop(resize),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ])
            elif sketch:
                transform = transforms.Compose([
                    transforms.Resize(resize),
                    MyCrop(30, 0, resize),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ])
            else:
                transform = transforms.Compose([
                    transforms.Resize(resize + 20),
                    MyCrop(15, 26, resize),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        img_format = '*.%s' % img_type
        if remove_exif:
            for name in glob.glob(os.path.join(path, img_format)):
                try:
                    piexif.remove(name)  # strip EXIF metadata
                except Exception:
                    continue
        # imghdr.what(name) filters out corrupted image files
        if Len == -1:
            dataset = [np.array(transform(Image.open(name).convert("L")))
                       for name in glob.glob(os.path.join(path, img_format))
                       if imghdr.what(name)]
        else:
            dataset = [np.array(transform(Image.open(name).convert("L")))
                       for name in glob.glob(os.path.join(path, img_format))[:Len]
                       if imghdr.what(name)]
        dataset = np.array(dataset)
        dataset = torch.Tensor(dataset)
        return dataset

    def __len__(self):
        return len(self.photo_dataset)

    def __getitem__(self, idx):
        return self.sketch_dataset[idx], self.photo_dataset[idx]


if __name__ == '__main__':
    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path,
                        resize=96, Len=10, img_type='jpg')
    print(len(dataset))
    for i in range(5):
        plt.imshow(np.squeeze(dataset[i][0].numpy()) * 0.5 + 0.5, cmap='gray')
        plt.show()
        print(dataset[i][0].max(), dataset[i][0].min())
        plt.imshow(np.squeeze(dataset[i][1].numpy()) * 0.5 + 0.5, cmap='gray')
        plt.show()
        print(dataset[i][1].max(), dataset[i][1].min())
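A note on MyCrop: the sketch and photo pipelines use different crop offsets ((30, 0) versus (15, 26) after a slightly larger resize), which, as far as I can tell, is there to roughly align the faces between the sketch and photo halves of the CUHK dataset.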
main.py
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from model import Generator, Discriminator, gp_loss
import torchvision
from dataloader import MyDataset
import matplotlib.pyplot as plt
import itertools
import numpy as np
import torchvision.utils as vutils

if __name__ == '__main__':
    LR = 0.0002
    EPOCH = 100
    BATCH_SIZE = 4
    drop_rate = 0.5
    lam = 10  # weight of the L1 term in the G loss
    TRAINED = False

    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path,
                        resize=96, Len=88, img_type='jpg')
    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
    torch.cuda.empty_cache()

    if not TRAINED:
        G = Generator(1, drop_rate=drop_rate).cuda()
        D = Discriminator(1).cuda()
    else:
        G = torch.load("G.pkl").cuda()
        D = torch.load("D.pkl").cuda()

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)

    l1_c = nn.L1Loss()
    mse_c = nn.MSELoss()

    # the PatchGAN output shape is (1, 12, 12): one score per image patch
    real_label = torch.ones((BATCH_SIZE, 1, 12, 12)).cuda()
    fake_label = torch.zeros((BATCH_SIZE, 1, 12, 12)).cuda()

    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, (x, y) in enumerate(data_loader):
            x = x.cuda()  # sketch (the target)
            y = y.cuda()  # photo (the condition / input)

            G_x = G(y)

            # ---- update D (PatchGAN): real pair -> ones, generated pair -> zeros ----
            D_xy = D(x, y)
            D_gxy = D(G_x.detach(), y)
            D_loss = mse_c(D_xy, real_label) + mse_c(D_gxy, fake_label)
            optimizerD.zero_grad()
            D_loss.backward()
            optimizerD.step()

            # ---- update G: fool D on every patch, plus L1 distance to the target ----
            D_gxy = D(G_x, y)
            G_loss = mse_c(D_gxy, real_label) + lam * l1_c(G_x, x)
            optimizerG.zero_grad()
            G_loss.backward()
            optimizerG.step()

            tmpD += D_loss.item()
            tmpG += G_loss.item()
        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print('epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG))
        if (epoch + 1) % 5 == 0:
            fig = plt.figure(figsize=(10, 10))
            plt.axis("off")
            plt.imshow(np.transpose(
                vutils.make_grid(torch.stack([G_x[0].cpu().detach(),
                                              x[0].cpu().detach(),
                                              y[0].cpu().detach()]),
                                 nrow=3, padding=0, normalize=True, scale_each=True),
                (1, 2, 0)), cmap='gray')
            plt.show()
    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')
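One detail worth flagging: the L1 weight lam is set to 10 here, while the paper recommends λ = 100, so the adversarial term carries relatively more weight in this run.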
model.py
import os
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch.utils.data import DataLoader
from dataloader import MyDataset
import torch.autograd as autograd


class ResidualBlock(nn.Module):
    def __init__(self, in_channel=1, out_channel=1, stride=1):
        super(ResidualBlock, self).__init__()
        self.weight_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                      kernel_size=3, stride=stride, padding=1),
        )
        self.active_layer = nn.Sequential(
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        residual = x
        x = self.weight_layer(x)
        x += residual  # identity shortcut: needs stride 1 and matching channels
        return self.active_layer(x)


class Generator(nn.Module):
    """U-Net-style generator: encoder feature maps are concatenated onto the
    symmetric decoder layers so low-level detail survives the bottleneck."""

    def __init__(self, input_channel=1, drop_rate=0.5):
        super(Generator, self).__init__()
        # encoder
        self.c_e1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channel, out_channels=64,
                      kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            ResidualBlock(in_channel=64, out_channel=64))
        self.c_e2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            ResidualBlock(in_channel=128, out_channel=128))
        self.c_e3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            ResidualBlock(in_channel=256, out_channel=256),
            nn.Dropout2d(drop_rate))
        self.c_e4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            ResidualBlock(in_channel=256, out_channel=256),
            nn.Dropout2d(drop_rate))
        # decoder; in_channels are doubled where encoder features get concatenated in
        self.d_e1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=input_channel,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh())
        self.d_e2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.d_e3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.Dropout2d(drop_rate))
        self.d_e4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),
            nn.BatchNorm2d(256),
            nn.Dropout2d(drop_rate))

    def forward(self, x):
        e1 = self.c_e1(x)
        e2 = self.c_e2(e1)
        e3 = self.c_e3(e2)
        e4 = self.c_e4(e3)
        d4 = self.d_e4(e4)
        d4 = torch.cat([d4, e3], dim=1)  # U-Net skip connection
        d3 = self.d_e3(d4)
        d3 = torch.cat([d3, e2], dim=1)
        d2 = self.d_e2(d3)
        d2 = torch.cat([d2, e1], dim=1)
        d1 = self.d_e1(d2)
        return d1


class Discriminator(nn.Module):
    """PatchGAN discriminator: outputs a score map rather than a scalar."""

    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        strides = [1, 2, 2, 2]
        padding = [0, 1, 1, 1]
        # the concatenated (x, y) pair doubles the input channels;
        # the final 1 gives a single-channel patch score map
        channels = [input_size * 2, 64, 128, 256, 1]
        kernels = [4, 4, 4, 3]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1],
                          stride=stride, kernel_size=kernels[i], padding=padding[i]))
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(nn.LeakyReLU(0.2))
        self.main = nn.Sequential(*model)

    def forward(self, fake_x, real_x):
        x = torch.cat([fake_x, real_x], dim=1)
        return self.main(x)


def gp_loss(D, real_x, fake_x, cond, cuda=False):
    """WGAN-GP gradient penalty, imported by main.py but unused in this pix2pix
    training loop. Since D here is conditional, the interpolated sample has to
    be passed in together with its condition."""
    if cuda:
        alpha = torch.rand((real_x.shape[0], 1, 1, 1)).cuda()
    else:
        alpha = torch.rand((real_x.shape[0], 1, 1, 1))
    x_ = (alpha * real_x + (1 - alpha) * fake_x).requires_grad_(True)
    y_ = D(x_, cond)
    # gradient of D's output with respect to the interpolated input
    grad = autograd.grad(
        outputs=y_,
        inputs=x_,
        grad_outputs=torch.ones_like(y_),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grad = grad.view(x_.shape[0], -1)
    gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()
    return gp


if __name__ == '__main__':
    drop_rate = 0.5
    G = Generator(1, drop_rate)
    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path,
                        resize=96, Len=10, img_type='jpg')
    train_loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    D = Discriminator(1)
    rs = ResidualBlock(1, 1, stride=1)  # the shortcut only works for stride 1
    for step, (x, y) in enumerate(train_loader):
        print(x.shape)
        print(G(x).shape)
        print(D(x, x).shape)
        print(rs(x).shape)
        break
judge.py
import numpy as np
import torch
import matplotlib.pyplot as plt
from model import Generator, Discriminator
from dataloader import MyDataset
from torch.utils.data import Dataset, DataLoader
import itertools
import torchvision.utils as vutils

if __name__ == '__main__':
    BATCH_SIZE = 3
    TIMES = 5

    G = torch.load("G.pkl").cuda()
    D = torch.load("D.pkl").cuda()

    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path,
                        resize=96, Len=BATCH_SIZE * TIMES, img_type='jpg')
    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    for step, (x, y) in enumerate(data_loader):
        x = x.cuda()
        y = y.cuda()
        G_x = G(y)
        fig = plt.figure(figsize=(10, 10))
        plt.axis("off")
        # columns: generated sketch | ground-truth sketch | input photo
        plt.imshow(np.transpose(
            vutils.make_grid(
                torch.stack([G_x.cpu().detach(), x.cpu().detach(), y.cpu().detach()])
                     .transpose(1, 0).contiguous().view(BATCH_SIZE * 3, 1, 96, 96),
                nrow=3, padding=0, normalize=True, scale_each=True),
            (1, 2, 0)), cmap='gray')
        plt.savefig(str(step) + '.png', dpi=300)
        plt.show()