









提出PatchGAN的思路:简单来讲,判别器D的输出不是一个标量(scalar),而是一个 Patch × Patch 的矩阵。然后分别计算这个矩阵与real data对应的全1矩阵、以及fake data对应的全0矩阵之间的距离(这里常用L2,即MSE)。









import glob
import numbers
import os

import matplotlib.pyplot as plt
import numpy as np
import piexif
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

try:
    import imghdr
except ImportError:  # imghdr was removed from the standard library in Python 3.13
    imghdr = None

try:
    import accimage
except ImportError:
    accimage = None


def _is_valid_image(name):
    """Return True when ``name`` looks like a readable (non-corrupted) image."""
    if imghdr is not None:
        return bool(imghdr.what(name))
    # Fallback when imghdr is unavailable: let PIL check the file.
    try:
        with Image.open(name) as img:
            img.verify()
    except Exception:
        return False
    return True


class MyCrop(object):
    """Crop a fixed window out of a PIL Image.

    The window's top-left corner is given explicitly as ``(i, j)``; it is not
    centred on the image.

    Args:
        i (int): Top edge (row) of the crop window.
        j (int): Left edge (column) of the crop window.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, i, j, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.i, self.j = i, j

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.

        Raises:
            TypeError: If ``img`` is not a PIL Image.
        """
        if not isinstance(img, Image.Image):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
        th, tw = self.size
        # PIL's crop box is (left, upper, right, lower).
        return img.crop((self.j, self.i, self.j + tw, self.i + th))

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class MyDataset(data.Dataset):
    """Paired sketch/photo dataset held fully in memory as grayscale tensors.

    Item ``idx`` is ``(sketch_dataset[idx], photo_dataset[idx])``. Pairing is
    purely positional: the two folders are expected to contain matching files
    that sort into the same order.

    Args:
        path_sketch (str): Folder with the sketch images.
        path_photo (str): Folder with the photo images.
        Train (bool): Stored on the instance; not used by this class itself.
        Len (int): Maximum number of files to read per folder, -1 for all.
        resize (int): Target square size, -1 to skip resize/crop/normalize.
        img_type (str): File extension to glob for (without the dot).
        remove_exif (bool): Strip EXIF data from the files on disk first.
        default (bool): Use Resize + CenterCrop instead of the hand-tuned
            crop offsets.
    """

    def __init__(self, path_sketch, path_photo, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False,
                 default=False):
        self.Train = Train
        self.sketch_dataset = self.init_dataset(path_sketch, Len=Len, resize=resize, img_type=img_type,
                                                remove_exif=remove_exif, sketch=True, default=default)
        self.photo_dataset = self.init_dataset(path_photo, Len=Len, resize=resize, img_type=img_type,
                                               remove_exif=remove_exif, sketch=False, default=default)

    def init_dataset(self, path, Len=-1, resize=-1, img_type='png', remove_exif=False, sketch=True, default=False):
        """Load every image of one folder into a single float tensor.

        Returns:
            torch.Tensor: Stacked single-channel images, or an empty tensor
            when no valid image was found.
        """
        # Images are converted to mode "L" (one channel) below, so mean/std
        # must have exactly one entry. A three-entry tuple only worked on old
        # torchvision releases and raises a broadcast error on current ones.
        normalize = transforms.Normalize((0.5,), (0.5,))

        if resize != -1:
            if default:
                steps = [transforms.Resize(resize), transforms.CenterCrop(resize)]
            elif sketch:
                steps = [transforms.Resize(resize), MyCrop(30, 0, resize)]
            else:
                steps = [transforms.Resize(resize + 20), MyCrop(15, 26, resize)]
            transform = transforms.Compose(steps + [transforms.ToTensor(), normalize])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
            ])

        img_format = '*.%s' % img_type
        # glob order is arbitrary; sort so that sketch and photo folders are
        # enumerated consistently and index-based pairing is reproducible.
        names = sorted(glob.glob(os.path.join(path, img_format)))

        if remove_exif:
            for name in names:
                try:
                    piexif.remove(name)  # strip EXIF metadata in place
                except Exception:
                    # Best effort: files piexif cannot handle are left as-is.
                    continue

        if Len != -1:
            names = names[:Len]

        images = []
        for name in names:
            # Skip corrupted / unrecognised files. NOTE: a file skipped in only
            # one of the two folders shifts the pairing from that index on.
            if not _is_valid_image(name):
                continue
            with Image.open(name) as img:  # close the file handle promptly
                images.append(transform(img.convert("L")))

        if not images:
            return torch.Tensor()
        return torch.stack(images)

    def __len__(self):
        # Only indices valid in both halves can be served as a pair.
        return min(len(self.sketch_dataset), len(self.photo_dataset))

    def __getitem__(self, idx):
        return self.sketch_dataset[idx], self.photo_dataset[idx]


if __name__ == '__main__':
    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=10, img_type='jpg')
    print(len(dataset))
    for i in range(5):
        sketch, photo = dataset[i]
        # Undo the (x - 0.5) / 0.5 normalisation for display.
        plt.imshow(np.squeeze(sketch.numpy()) * 0.5 + 0.5, cmap='gray')
        plt.show()
        print(sketch.max(), sketch.min())
        plt.imshow(np.squeeze(photo.numpy()) * 0.5 + 0.5, cmap='gray')
        plt.show()
        print(photo.max(), photo.min())


import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader

from dataloader import MyDataset
from model import Generator, Discriminator, gp_loss
# Alternative implementations kept for reference:
# from model import gp_loss
# from github_model import Generator, Discriminator

if __name__ == '__main__':
    # Training entry point: LSGAN-style (MSE) PatchGAN loss plus an L1
    # reconstruction term. G maps a photo `y` to a sketch `G(y)`;
    # D scores (sketch, photo) pairs.
    LR = 0.0002
    EPOCH = 100  # 50
    BATCH_SIZE = 4
    drop_rate = 0.5
    lam = 10  # weight of the L1 reconstruction term
    TRAINED = False  # True: resume from the saved G.pkl / D.pkl

    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=88, img_type='jpg')
    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()
    if not TRAINED:
        G = Generator(1, drop_rate=drop_rate).cuda()
        D = Discriminator(1).cuda()
    else:
        # NOTE: torch.load unpickles arbitrary objects; only load checkpoint
        # files that were produced by this script.
        G = torch.load("G.pkl").cuda()
        D = torch.load("D.pkl").cuda()

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)
    l1_c = nn.L1Loss()
    mse_c = nn.MSELoss()

    for epoch in range(EPOCH):
        tmpD, tmpG = 0.0, 0.0
        for step, (x, y) in enumerate(data_loader):
            x = x.cuda()
            y = y.cuda()
            G_x = G(y)

            # ---- discriminator update ----
            # The generated sample is detached so this step does not
            # back-propagate into G.
            D_xy = D(x, y)  # PatchGAN
            D_gxy = D(G_x.detach(), y)
            # Build the targets from the actual patch map so that a smaller
            # last batch or a different input size still works.
            real_label = torch.ones_like(D_xy)
            fake_label = torch.zeros_like(D_gxy)
            D_loss = mse_c(D_xy, real_label) + mse_c(D_gxy, fake_label)
            optimizerD.zero_grad()
            D_loss.backward()
            optimizerD.step()

            # ---- generator update ----
            # Re-score the fake pair with the updated D, this time keeping
            # the graph through G. Gradients that land on D's parameters here
            # are cleared by optimizerD.zero_grad() in the next iteration.
            D_gxy = D(G_x, y)
            G_loss = mse_c(D_gxy, real_label) + lam * l1_c(G_x, x)
            optimizerG.zero_grad()
            G_loss.backward()
            optimizerG.step()

            tmpD += D_loss.item()
            tmpG += G_loss.item()

        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)
        )

        if (epoch + 1) % 5 == 0:
            # Show generated sketch | real sketch | input photo for the first
            # sample of the last batch.
            fig = plt.figure(figsize=(10, 10))
            plt.axis("off")
            plt.imshow(np.transpose(
                vutils.make_grid(torch.stack([G_x[0].cpu().detach(), x[0].cpu().detach(), y[0].cpu().detach(), ]),
                                 nrow=3, padding=0, normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')
            plt.show()
            plt.close(fig)

    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')


import os

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch.utils.data import DataLoader

from dataloader import MyDataset


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with an identity skip connection.

    The skip is a plain addition of the block input, so the output of the
    convolution stack must have a shape compatible with the input. ``stride``
    is applied to both convolutions, so in practice the block is used with
    ``stride=1`` and ``in_channel == out_channel``.
    """

    def __init__(self, in_channel=1, out_channel=1, stride=1):
        super(ResidualBlock, self).__init__()
        self.weight_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),
        )
        self.active_layer = nn.Sequential(
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        residual = x
        x = self.weight_layer(x)
        x += residual
        return self.active_layer(x)


class Generator(nn.Module):
    """U-Net style encoder/decoder with residual blocks and skip concatenation.

    Four stride-2 encoder stages (``c_e1`` .. ``c_e4``) are mirrored by four
    transposed-convolution decoder stages (``d_e4`` .. ``d_e1``). Each decoder
    output is concatenated with the matching encoder output along the channel
    axis, which is why ``d_e3``, ``d_e2`` and ``d_e1`` take twice the channels
    the previous decoder stage emits. The kernel sizes are chosen so the
    spatial sizes line up for the concatenations; the scripts in this project
    feed 96x96 inputs.

    Args:
        input_channel (int): Channels of the input (and of the output) image.
        drop_rate (float): Probability used by the ``Dropout2d`` layers.
    """

    def __init__(self, input_channel=1, drop_rate=0.5):
        super(Generator, self).__init__()
        # ---- encoder ----
        self.c_e1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channel, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            ResidualBlock(in_channel=64, out_channel=64))
        self.c_e2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),
                                  nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
                                  ResidualBlock(in_channel=128, out_channel=128))
        self.c_e3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2),
                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
                                  ResidualBlock(in_channel=256, out_channel=256), nn.Dropout2d(drop_rate))
        self.c_e4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),
                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
                                  ResidualBlock(in_channel=256, out_channel=256), nn.Dropout2d(drop_rate))
        # ---- decoder (input channels include the concatenated skip) ----
        self.d_e1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=input_channel, kernel_size=4, stride=2, padding=1),
            nn.Tanh())
        self.d_e2 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=2),
                                  nn.BatchNorm2d(64), nn.ReLU(), )
        self.d_e3 = nn.Sequential(nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=5, stride=2),
                                  nn.BatchNorm2d(128), nn.Dropout2d(drop_rate))
        self.d_e4 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),
                                  nn.BatchNorm2d(256), nn.Dropout2d(drop_rate))

    def forward(self, x):
        e1 = self.c_e1(x)
        e2 = self.c_e2(e1)
        e3 = self.c_e3(e2)
        e4 = self.c_e4(e3)

        d4 = self.d_e4(e4)
        d4 = torch.cat([d4, e3], dim=1)  # skip connection from encoder stage 3
        d3 = self.d_e3(d4)
        d3 = torch.cat([d3, e2], dim=1)  # skip connection from encoder stage 2
        d2 = self.d_e2(d3)
        d2 = torch.cat([d2, e1], dim=1)  # skip connection from encoder stage 1
        d1 = self.d_e1(d2)
        return d1


class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    The two inputs are concatenated along the channel axis and passed through
    four Conv2d + BatchNorm2d + LeakyReLU stages. The result is a one-channel
    score map (one score per patch) rather than a single scalar.

    Args:
        input_size (int): Number of channels of *each* of the two inputs.
    """

    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        strides = [1, 2, 2, 2]
        padding = [0, 1, 1, 1]
        channels = [input_size * 2,
                    64, 128, 256, 1]  # final 1: a single-channel score map
        kernels = [4, 4, 4, 3]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i]
                )
            )
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(
                nn.LeakyReLU(0.2)
            )
        self.main = nn.Sequential(*model)

    def forward(self, fake_x, real_x):
        # Despite the parameter names, the two tensors are simply stacked on
        # the channel axis; the first is the image being judged, the second
        # the image it is conditioned on.
        x = torch.cat([fake_x, real_x], dim=1)
        x = self.main(x)
        return x


def gp_loss(D, real_x, fake_x, cuda=False, condition=None):
    """Gradient penalty on random interpolations between real and fake samples.

    Computes ``mean((||dD(x_)/dx_||_2 - 1) ** 2)`` where ``x_`` is a per-sample
    random convex combination of ``real_x`` and ``fake_x``.

    Args:
        D: Discriminator / critic callable.
        real_x (Tensor): Batch of real samples.
        fake_x (Tensor): Batch of generated samples, same shape as ``real_x``.
        cuda (bool): Force the interpolation weights onto the CUDA device.
            When False they are created on ``real_x``'s device.
        condition (Tensor, optional): Second input for a conditional
            discriminator such as ``Discriminator`` in this file, which takes
            two arguments. When given, ``D(x_, condition)`` is evaluated
            instead of ``D(x_)``.

    Returns:
        Tensor: Scalar penalty, differentiable w.r.t. ``D``'s parameters.
    """
    device = torch.device('cuda') if cuda else real_x.device
    alpha = torch.rand((real_x.shape[0], 1, 1, 1), device=device)
    x_ = (alpha * real_x + (1 - alpha) * fake_x).requires_grad_(True)
    y_ = D(x_) if condition is None else D(x_, condition)
    # Gradient of the critic output w.r.t. the interpolated input.
    grad = autograd.grad(
        outputs=y_,
        inputs=x_,
        grad_outputs=torch.ones_like(y_),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]
    grad = grad.reshape(x_.shape[0], -1)
    gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()
    return gp


if __name__ == '__main__':
    # Shape smoke test on one small batch.
    drop_rate = 0.5
    G = Generator(1, drop_rate)
    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=10, img_type='jpg')
    train_loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    D = Discriminator(1)
    rs = ResidualBlock(1, 1, stride=1)  # only for stride 1
    for step, (x, y) in enumerate(train_loader):
        print(x.shape)
        print(G(x).shape)
        print(D(x, x).shape)
        print(rs(x).shape)
        break


import itertools

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader

from dataloader import MyDataset
from model import Generator, Discriminator

if __name__ == '__main__':
    # Inference script: runs the saved generator on a few photo/sketch pairs
    # and saves one comparison grid per batch as '<step>.png'.
    BATCH_SIZE = 3
    TIMES = 5  # number of batches to render

    # NOTE: torch.load unpickles arbitrary objects; only load a checkpoint
    # produced by the training script. The Generator class must be importable
    # for unpickling, hence the `model` import above.
    G = torch.load("G.pkl").cuda()

    sketch_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'
    photo_path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'
    dataset = MyDataset(path_sketch=sketch_path, path_photo=photo_path, resize=96, Len=BATCH_SIZE * TIMES,
                        img_type='jpg')
    data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        for step, (x, y) in enumerate(data_loader):
            x = x.cuda()
            y = y.cuda()
            G_x = G(y)

            # (3, B, C, H, W) -> (B, 3, C, H, W) so that every grid row shows
            # generated sketch | real sketch | input photo of one sample.
            triplets = torch.stack([G_x.cpu(), x.cpu(), y.cpu()]).transpose(1, 0).contiguous()
            # Flatten to (B * 3, C, H, W) using the actual batch and image
            # size, so a smaller last batch does not break the reshape.
            tiles = triplets.view(-1, *triplets.shape[2:])
            grid = vutils.make_grid(tiles, nrow=3, padding=0, normalize=True, scale_each=True)

            fig = plt.figure(figsize=(10, 10))
            plt.axis("off")
            plt.imshow(np.transpose(grid, (1, 2, 0)), cmap='gray')
            plt.savefig(str(step) + '.png', dpi=300)
            plt.show()
            plt.close(fig)  # release the figure once it has been saved/shown

