(1)Bringing Old Photos Back to Life原理和测试
(2)
Bringing Old Photos Back to Life模型代码分析1(数据载入部分)
Bringing Old Photos Back to Life模型代码分析2(模型部分)
(3)Bringing Old Photos Back to Life数据集及其训练
本文分析代码的模型部分,依次包含 base_model.py、models.py、networks.py 和 pix2pixHD_model.py,代码如下。
其中部分注释参考了网上的文章。
base_model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import torch
import sys
## 模型基类
class BaseModel(torch.nn.Module):
    """Base class shared by the models in this package.

    It stores the run options and provides checkpoint helpers (saving and
    loading network weights and optimizer state). Most other methods are
    no-op hooks that subclasses are expected to override.
    """

    def name(self):
        """Return the model's display name."""
        return "BaseModel"

    def initialize(self, opt):
        """Store the options object and derive the checkpoint directory.

        Reads ``opt.gpu_ids``, ``opt.isTrain``, ``opt.checkpoints_dir`` and
        ``opt.name``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        # Checkpoints live in <checkpoints_dir>/<name>.
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        """Remember the current input."""
        self.input = input

    def forward(self):
        """Hook: forward pass (no-op here)."""
        pass

    # used in test time, no backprop
    def test(self):
        """Hook: inference pass (no-op here)."""
        pass

    def get_image_paths(self):
        """Hook: paths of the current images (no-op here)."""
        pass

    def optimize_parameters(self):
        """Hook: one optimisation step (no-op here)."""
        pass

    def get_current_visuals(self):
        """Return the stored input."""
        return self.input

    def get_current_errors(self):
        """Return the current losses; the base class has none."""
        return {}

    def save(self, label):
        """Hook: save the model under ``label`` (no-op here)."""
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s state dict as ``<epoch>_net_<label>.pth``.

        The network is moved to the CPU for saving and moved back to CUDA
        afterwards when ``gpu_ids`` is non-empty and CUDA is available.
        """
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    def save_optimizer(self, optimizer, optimizer_label, epoch_label):
        """Save ``optimizer``'s state dict as ``<epoch>_optimizer_<label>.pth``."""
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""):
        """Restore optimizer state; print a notice when the file is missing.

        ``save_dir`` defaults to the model's own checkpoint directory.
        """
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
        else:
            optimizer.load_state_dict(torch.load(save_path))

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=""):
        """Load weights for ``network`` from ``<epoch>_net_<label>.pth``.

        A missing file only prints a notice. When the checkpoint does not
        match the network exactly, the method falls back to a best-effort
        partial load (see ``_load_partial_state``).
        """
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s not exists yet!" % save_path)
            return
        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
        except (RuntimeError, KeyError):
            # Key or shape mismatch: salvage whatever still fits.
            self._load_partial_state(network, network_label, pretrained_dict)

    def _load_partial_state(self, network, network_label, pretrained_dict):
        """Load the subset of ``pretrained_dict`` that fits ``network``."""
        model_dict = network.state_dict()
        # Drop checkpoint entries the network does not have.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        try:
            network.load_state_dict(pretrained_dict)
        except (RuntimeError, KeyError):
            print(
                "Pretrained network %s has fewer layers; The following are not initialized:"
                % network_label
            )
        else:
            print(
                "Pretrained network %s has excessive layers; Only loading layers that are used"
                % network_label
            )
            return
        # Copy only the tensors whose shapes agree, keep the rest as-is.
        for k, v in pretrained_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v
        not_initialized = set()
        for k, v in model_dict.items():
            if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                # Report by top-level sub-module name.
                not_initialized.add(k.split(".")[0])
        print(sorted(not_initialized))
        network.load_state_dict(model_dict)

    def update_learning_rate(self):
        """Hook: adjust the learning rate (no-op here)."""
        pass
models.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
#创建模型,并返回模型
def create_model(opt):
    """Build, initialise and return the model selected by ``opt``.

    ``opt.model == "pix2pixHD"`` yields the training model when
    ``opt.isTrain`` is set and the inference model otherwise; any other
    value yields the UI model. During training on more than one GPU the
    model is wrapped in ``torch.nn.DataParallel``.
    """
    if opt.model != "pix2pixHD":
        from .ui_model import UIModel

        model = UIModel()
    else:
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel

        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()

    model.initialize(opt)

    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    # Multi-GPU training: replicate the model across the listed devices.
    if opt.isTrain and len(opt.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    return model
def create_da_model(opt):
    """Build, initialise and return the model from ``pix2pixHD_model_DA``.

    Selection mirrors ``create_model``: the ``pix2pixHD`` training or
    inference model depending on ``opt.isTrain``, otherwise the UI model;
    wrapped in ``torch.nn.DataParallel`` for multi-GPU training.
    """
    if opt.model != 'pix2pixHD':
        from .ui_model import UIModel

        model = UIModel()
    else:
        from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel

        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()

    model.initialize(opt)

    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    # Multi-GPU training: replicate the model across the listed devices.
    if opt.isTrain and len(opt.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    return model
networks.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
from torch.nn.utils import spectral_norm
# from util.util import SwitchNorm2d
import torch.nn.functional as F
###############################################################################
# Functions
###############################################################################
def weights_init(m):
    """Initialise a module in place; meant for ``nn.Module.apply``.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    drawn from N(1, 0.02) and zeroed biases. Everything else is untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
    elif "BatchNorm2d" in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
#
def get_norm_layer(norm_type="instance"):
    """Return a factory that builds a normalisation layer from a channel count.

    Supported values:
      * "batch"      -> ``nn.BatchNorm2d`` with ``affine=True``
      * "instance"   -> ``nn.InstanceNorm2d`` with ``affine=False``
      * "SwitchNorm" -> ``SwitchNorm2d``, only if that name has been made
        available in this module (its import is commented out in this file)

    Raises:
        NotImplementedError: for unknown types, for "SwitchNorm" when
            ``SwitchNorm2d`` is not importable, and for "spectral".
            Spectral normalisation wraps an existing module rather than
            being built from a channel count, so it cannot be returned as a
            layer factory; use ``SN(module)`` on the layer instead.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == "spectral":
        # Previously this called spectral_norm() with no module, which
        # always failed with a TypeError.
        raise NotImplementedError(
            "normalization layer [%s] is not found: spectral norm wraps a module, "
            "apply it with SN(module) instead" % norm_type
        )
    if norm_type == "SwitchNorm":
        try:
            return SwitchNorm2d
        except NameError:
            raise NotImplementedError(
                "normalization layer [%s] is not found: SwitchNorm2d is not imported" % norm_type
            ) from None
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
#打印网络和参数量
def print_network(net):
    """Print a network and its total parameter count.

    When ``net`` is a list, only its first element is reported.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print("Total number of parameters: %d" % total)
# input_nc = 3
# output_nc = 3
# ngf = 64 第一层卷积核数
def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):
    """Build, print, initialise and return the generator network.

    Only ``netG == 'global'`` is supported: ``opt.use_v2`` selects
    ``GlobalGenerator_DCDCv2``, otherwise ``GlobalGenerator_v2`` is used.
    ``n_local_enhancers`` and ``n_blocks_local`` are accepted for interface
    compatibility but not used here. ``gpu_ids`` is only read, never
    mutated. ``opt`` must be provided, since its attributes are read.

    Raises:
        NotImplementedError: if ``netG`` is not 'global'. (The original
            ``raise('...')`` raised a plain string, which itself fails with
            a TypeError.)
    """
    # Reject unsupported generators before doing any other work.
    if netG != 'global':
        raise NotImplementedError('generator not implemented!')

    norm_layer = get_norm_layer(norm_type=norm)
    if opt.use_v2:
        generator = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)
    else:
        generator = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt)
    print(generator)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        generator.cuda(gpu_ids[0])
    generator.apply(weights_init)
    return generator
def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build, print, initialise and return the multi-scale discriminator.

    The network is moved to ``gpu_ids[0]`` when ``gpu_ids`` is non-empty and
    then initialised with ``weights_init``.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc,
        opt,
        ndf,
        n_layers_D,
        get_norm_layer(norm_type=norm),
        use_sigmoid,
        num_D,
        getIntermFeat,
    )
    print(discriminator)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        discriminator.cuda(gpu_ids[0])
    discriminator.apply(weights_init)
    return discriminator
#
class GlobalGenerator_DCDCv2(nn.Module):
    """Encoder/decoder generator exposed as ``self.encoder`` and ``self.decoder``.

    Every channel count is capped at ``opt.mc``. The first ``opt.start_r``
    down-sampling stages are plain strided convolutions; later stages are
    each followed by two residual blocks. ``opt.spatio_size`` (32 or 64)
    selects whether the last stage changes resolution or adds a residual
    block. When ``opt.feat_dim > 0`` a 1x1 convolution projects the
    bottleneck to/from ``opt.feat_dim`` channels. The decoder mirrors the
    encoder with transposed convolutions and ends with a 7x7 convolution,
    followed by ``Tanh`` unless ``opt.use_segmentation_model`` is set.

    ``opt`` is required in practice: its attributes are read during
    construction.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        k_size=3,
        n_downsampling=8,  # pix2pixHD's own generator is built with 3 in define_G's default
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        opt=None,
    ):
        super(GlobalGenerator_DCDCv2, self).__init__()
        # One shared in-place ReLU instance is reused throughout (stateless).
        activation = nn.ReLU(True)
        # output_padding that makes the transposed convs exactly double the size.
        o_pad = 0 if k_size == 4 else 1

        def cap(channels):
            """Clamp a channel count to the configured maximum."""
            return min(channels, opt.mc)

        def down(mult):
            """Stride-2 conv stage: ngf*mult -> ngf*mult*2 channels."""
            return [
                nn.Conv2d(cap(ngf * mult), cap(ngf * mult * 2), kernel_size=k_size, stride=2, padding=1),
                norm_layer(cap(ngf * mult * 2)),
                activation,
            ]

        def up(mult):
            """Stride-2 transposed-conv stage: ngf*mult -> ngf*mult/2 channels."""
            return [
                nn.ConvTranspose2d(
                    cap(ngf * mult),
                    cap(int(ngf * mult / 2)),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(cap(int(ngf * mult / 2))),
                activation,
            ]

        def res(channels):
            """Residual block at the given (capped) width."""
            return ResnetBlock(
                cap(channels),
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )

        # ---------------- encoder ----------------
        # Stem: reflection padding of 3 keeps the resolution under the 7x7 conv.
        # The norm width now follows the conv's capped output width; the
        # original used the uncapped ngf, which disagrees when opt.mc < ngf.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, cap(ngf), kernel_size=7, padding=0),
            norm_layer(cap(ngf)),
            activation,
        ]
        # Plain down-sampling stages.
        for i in range(opt.start_r):
            model += down(2 ** i)
        # Down-sampling stages interleaved with residual blocks
        # (the residual blocks keep the resolution).
        for i in range(opt.start_r, n_downsampling - 1):
            mult = 2 ** i
            model += down(mult)
            model += [res(ngf * mult * 2)]
            model += [res(ngf * mult * 2)]

        mult = 2 ** (n_downsampling - 1)
        if opt.spatio_size == 32:
            model += down(mult)
        if opt.spatio_size == 64:
            model += [res(ngf * mult * 2)]
        model += [res(ngf * mult * 2)]
        if opt.feat_dim > 0:
            model += [nn.Conv2d(cap(ngf * mult * 2), opt.feat_dim, 1, 1)]
        self.encoder = nn.Sequential(*model)

        # ---------------- decoder ----------------
        model = []
        if opt.feat_dim > 0:
            model += [nn.Conv2d(opt.feat_dim, cap(ngf * mult * 2), 1, 1)]
        mult = 2 ** n_downsampling
        model += [res(ngf * mult)]
        if opt.spatio_size == 32:
            model += up(mult)
        if opt.spatio_size == 64:
            model += [res(ngf * mult)]
        # Two residual blocks followed by an up-sampling stage.
        for i in range(1, n_downsampling - opt.start_r):
            mult = 2 ** (n_downsampling - i)
            model += [res(ngf * mult)]
            model += [res(ngf * mult)]
            model += up(mult)
        # Plain up-sampling stages.
        for i in range(n_downsampling - opt.start_r, n_downsampling):
            model += up(2 ** (n_downsampling - i))
        if opt.use_segmentation_model:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(cap(ngf), output_nc, kernel_size=7, padding=0)]
        else:
            model += [
                nn.ReflectionPad2d(3),
                nn.Conv2d(cap(ngf), output_nc, kernel_size=7, padding=0),
                nn.Tanh(),
            ]
        self.decoder = nn.Sequential(*model)

    def forward(self, input, flow="enc_dec"):
        """Run the encoder ("enc"), the decoder ("dec") or both ("enc_dec").

        Raises:
            ValueError: for any other ``flow`` value (the original silently
                returned ``None``).
        """
        if flow == "enc":
            return self.encoder(input)
        if flow == "dec":
            return self.decoder(input)
        if flow == "enc_dec":
            return self.decoder(self.encoder(input))
        raise ValueError("unknown flow [%s]" % flow)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convolutions.

    The first convolution uses ``dilation`` (padded by the same amount);
    the second always uses dilation 1 and padding 1.
    """

    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble pad -> conv -> norm -> activation [-> dropout] -> pad -> conv -> norm.

        Raises:
            NotImplementedError: if ``padding_type`` is not "reflect",
                "replicate" or "zero".
        """

        def padding(amount):
            """Return (explicit padding modules, implicit conv padding)."""
            if padding_type == "zero":
                return [], amount
            if padding_type == "reflect":
                return [nn.ReflectionPad2d(amount)], 0
            if padding_type == "replicate":
                return [nn.ReplicationPad2d(amount)], 0
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        # First half: dilated convolution.
        layers, implicit = padding(self.dilation)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=implicit, dilation=self.dilation))
        layers.append(norm_layer(dim))
        layers.append(activation)
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second half: plain convolution.
        explicit, implicit = padding(1)
        layers.extend(explicit)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=implicit, dilation=1))
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the residual branch."""
        return x + self.conv_block(x)
#
class Encoder(nn.Module):
    """Convolutional down/up-sampling network with instance-wise average pooling.

    ``self.model`` is a 7x7 stem, ``n_downsampling`` stride-2 convolutions,
    the same number of stride-2 transposed convolutions, and a final 7x7
    convolution with ``Tanh``. ``forward`` then replaces every output value
    by the mean over the pixels sharing the same id in ``inst``.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc
        # Stem: reflection padding of 3 followed by an unpadded 7x7 conv.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        ### downsample: n_downsampling stride-2 convs, doubling the channels each time
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]
        ### upsample: transposed convs, halving the channels each time
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        """Run the network, then average the output over each instance region.

        ``inst`` holds one integer id per pixel.
        NOTE(review): the indexing below presumably expects ``inst`` to have
        a single channel (so ``indices[:, 1]`` is 0) -- confirm with callers.
        """
        outputs = self.model(input)
        # instance-wise average pooling
        outputs_mean = outputs.clone()
        # Distinct ids present anywhere in the batch.
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                # Positions in sample b whose id equals i.
                indices = (inst[b : b + 1] == int(i)).nonzero()  # n x 4
                for j in range(self.output_nc):
                    # Offsets b and j shift the batch/channel indices so that
                    # channel j of sample b is addressed.
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[
                        indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]
                    ] = mean_feat
        return outputs_mean
#
def SN(module, mode=True):
    """Wrap ``module`` with spectral normalisation when ``mode`` is truthy.

    Otherwise the module is returned unchanged.
    """
    return torch.nn.utils.spectral_norm(module) if mode else module
########## 带mask的model
class NonLocalBlock2D_with_mask_Res(nn.Module):
    """Masked non-local (attention) block followed by three residual blocks.

    Attention weights are computed from 1x1-convolution embeddings
    (``theta``/``phi``), multiplied by a per-position mask, and used to
    aggregate the ``g`` embedding. The result goes through ``W`` and
    ``res_block`` and is blended with the input using the mask.

    Only ``mode == "combine"`` is implemented in ``forward``. The original
    code left the result undefined for every other mode (including the
    default "add") and failed with ``UnboundLocalError``; ``forward`` now
    raises ``NotImplementedError`` instead.
    """

    def __init__(
        self,
        in_channels,
        inter_channels,
        mode="add",
        re_norm=False,
        temperature=1.0,
        use_self=False,
        cosin=False,
    ):
        super(NonLocalBlock2D_with_mask_Res, self).__init__()
        self.cosin = cosin
        self.renorm = re_norm
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.g = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.W = nn.Conv2d(
            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
        )
        # W starts at zero (weights and bias).
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.mode = mode
        self.temperature = temperature
        self.use_self = use_self
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)
        model = []
        for i in range(3):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W
        """Apply masked attention to ``x`` and blend the result with ``x``.

        Raises:
            NotImplementedError: if ``self.mode`` is not "combine".
        """
        # Fail early and clearly instead of after the whole computation.
        if self.mode != "combine":
            raise NotImplementedError("mode [%s] is not implemented" % self.mode)

        batch_size = x.size(0)
        # Flatten the spatial dims: (B, C', H*W), then transpose to (B, H*W, C').
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        if self.cosin:
            # Normalise along the channel axis of each operand so the
            # product below becomes a cosine similarity.
            theta_x = F.normalize(theta_x, dim=2)
            phi_x = F.normalize(phi_x, dim=1)
        # Pairwise similarity between positions: (B, H*W, H*W).
        f = torch.matmul(theta_x, phi_x)
        f /= self.temperature
        f_div_C = F.softmax(f, dim=2)

        # Resize the mask to the feature resolution. A position ends up 1
        # only where both the bilinear resize (any value > 0 counts as
        # masked) and the default-mode resize of the inverted mask agree
        # that it is unmasked.
        tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        mask = 1 - mask
        tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        mask *= tmp

        # Broadcast the per-position mask to every query row: (B, H*W, H*W).
        mask_expand = mask.view(batch_size, 1, -1)
        mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)
        if self.use_self:
            # Always allow a position to attend to itself (the diagonal).
            mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0
        f_div_C = mask_expand * f_div_C
        if self.renorm:
            # Re-normalise each row to sum to one after masking.
            f_div_C = F.normalize(f_div_C, p=1, dim=2)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        W_y = self.res_block(W_y)

        # Keep x where the resized mask is 1, use the attended features elsewhere.
        full_mask = mask.repeat(1, self.inter_channels, 1, 1)
        z = full_mask * x + (1 - full_mask) * W_y
        return z
###########################################################################################################
# 多尺度判别器
###########################################################################################################
class MultiscaleDiscriminator(nn.Module):
    """A set of ``num_D`` PatchGAN discriminators applied at different scales.

    Each discriminator's layers are registered directly on this module:
    as ``scale<i>_layer<j>`` when intermediate features are requested, or as
    ``layer<i>`` otherwise. The input is average-pooled by a factor of two
    between consecutive discriminators.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            single = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, "layer%d" % scale, single.model)
                continue
            # Expose every stage separately so its output can be collected.
            for stage in range(n_layers + 2):
                setattr(self, "scale%d_layer%d" % (scale, stage), getattr(single, "model%d" % stage))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator.

        Returns a one-element list with the final output, or, when
        intermediate features are enabled, the list of every stage's output
        (``model`` is then a sequence of stages).
        """
        if not self.getIntermFeat:
            return [model(input)]
        features = []
        current = input
        for stage in model:
            current = stage(current)
            features.append(current)
        return features

    def forward(self, input):
        """Return one result list per scale.

        The discriminator with the highest index sees the full-resolution
        input; each following one sees a further down-sampled copy.
        """
        last_step = self.num_D - 1
        results = []
        scaled = input
        for step in range(self.num_D):
            d_index = last_step - step
            if self.getIntermFeat:
                model = [
                    getattr(self, "scale%d_layer%d" % (d_index, stage))
                    for stage in range(self.n_layers + 2)
                ]
            else:
                model = getattr(self, "layer%d" % d_index)
            results.append(self.singleD_forward(model, scaled))
            if step != last_step:
                scaled = self.downsample(scaled)
        return results
# Defines the PatchGAN discriminator with the specified arguments.
## 用指定的参数定义PatchGAN鉴别器
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    Layout: one stride-2 conv stage, ``n_layers - 1`` stride-2 conv stages
    with normalisation (channels double up to 512), one stride-1 conv stage,
    and a final 1-channel stride-1 conv, optionally followed by a sigmoid.
    Every convolution is wrapped by ``SN`` according to ``opt.use_SN``.
    With ``getIntermFeat`` each stage is registered as ``model<n>``;
    otherwise all layers form a single ``self.model``.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kernel = 4
        # ceil((4 - 1) / 2) == 2
        pad = int(np.ceil((kernel - 1.0) / 2))

        def conv(c_in, c_out, stride):
            """4x4 convolution, spectrally normalised when opt.use_SN is set."""
            return SN(nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad), opt.use_SN)

        # First stage has no normalisation layer.
        stages = [[conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]]

        width = ndf
        for _ in range(1, n_layers):
            previous, width = width, min(width * 2, 512)
            stages.append([conv(previous, width, 2), norm_layer(width), nn.LeakyReLU(0.2, True)])

        # Stride-1 stage.
        previous, width = width, min(width * 2, 512)
        stages.append([conv(previous, width, 1), norm_layer(width), nn.LeakyReLU(0.2, True)])

        # Single-channel prediction map.
        stages.append([conv(width, 1, 1)])

        if use_sigmoid:
            stages.append([nn.Sigmoid()])

        if getIntermFeat:
            for index, stage in enumerate(stages):
                setattr(self, "model" + str(index), nn.Sequential(*stage))
        else:
            self.model = nn.Sequential(*[layer for stage in stages for layer in stage])

    def forward(self, input):
        """Return the prediction map, or the outputs of the first
        ``n_layers + 2`` stages when intermediate features are enabled."""
        if not self.getIntermFeat:
            return self.model(input)
        features = []
        current = input
        for n in range(self.n_layers + 2):
            current = getattr(self, "model" + str(n))(current)
            features.append(current)
        return features
class Patch_Attention_4(nn.Module):  ## While combine the feature map, use conv and mask
    """Patch-level hard attention that re-composes features and fuses them with a conv.

    The feature map is split into non-overlapping ``patch_size`` patches.
    Each patch is replaced by its most similar patch (cosine similarity)
    among the patches not flagged by the mask, and the re-composed map is
    concatenated with the input and the mask and passed through
    ``F_Combine``.

    NOTE(review): ``in_channels`` is not used; ``F_Combine`` is hard-coded
    to 1025 input / 512 output channels, which presumably means 512-channel
    features twice plus a 1-channel mask -- confirm with callers.
    """

    def __init__(self, in_channels, inter_channels, patch_size):
        super(Patch_Attention_4, self).__init__()
        self.patch_size=patch_size
        # Fuses [z, re-composed features, mask] concatenated along channels.
        self.F_Combine=nn.Conv2d(in_channels=1025,out_channels=512,kernel_size=3,stride=1,padding=1,bias=True)
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)
        model = []
        for i in range(1):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def Hard_Compose(self, input, dim, index):
        """Batched index-select: gather entries of ``input`` along ``dim``.

        ``index`` is reshaped so it has size -1 on ``dim`` and 1 elsewhere
        (batch kept), expanded to ``input``'s other dimensions, and passed
        to ``torch.gather``.
        """
        # batch index select
        # input: [B,C,HW]
        # dim: scalar > 0
        # index: [B, HW]
        views = [input.size(0)] + [1 if i!=dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W
        """Re-compose every patch from its best-matching unflagged patch, then fuse."""
        x=self.res_block(z)
        b,c,h,w=x.shape
        ## mask resize + dilation: any positive value after resizing becomes 1
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        ## 1: mask position 0: non-mask
        mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
        # 1.0 for patches whose mean mask value exceeds 0.6.
        # NOTE(review): despite the name, this flags the patches that ARE
        # mostly masked (they are excluded as keys below).
        non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()
        # NOTE(review): true division yields a float; int() is applied below.
        all_patch_num=h*w/self.patch_size/self.patch_size
        non_mask_region=non_mask_region.repeat(1,int(all_patch_num),1)
        x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
        y_unfold=x_unfold.permute(0,2,1)
        # Normalise each patch vector so the product is a cosine similarity.
        x_unfold_normalized=F.normalize(x_unfold,dim=1)
        y_unfold_normalized=F.normalize(y_unfold,dim=2)
        correlation_matrix=torch.bmm(y_unfold_normalized,x_unfold_normalized)
        # Flagged patches get a very negative score so softmax ignores them.
        correlation_matrix=correlation_matrix.masked_fill(non_mask_region==1.,-1e9)
        correlation_matrix=F.softmax(correlation_matrix,dim=2)
        # Hard attention: keep only the index of the best match per patch.
        R, max_arg=torch.max(correlation_matrix,dim=2)
        composed_unfold=self.Hard_Compose(x_unfold, 2, max_arg)
        composed_fold=F.fold(composed_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)
        concat_1=torch.cat((z,composed_fold,mask),dim=1)
        concat_1=self.F_Combine(concat_1)
        return concat_1

    def inference_forward(self,z,mask):  ## Reduce the extra memory cost
        """Variant of ``forward`` that only matches the flagged patches.

        Only the flagged patches act as queries and only the unflagged ones
        as keys. The flag vector is taken from batch element 0 only
        (``[0,0,:]``).
        NOTE(review): that presumably assumes batch size 1 -- confirm.
        """
        x=self.res_block(z)
        b,c,h,w=x.shape
        ## mask resize + dilation: any positive value after resizing becomes 1
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        ## 1: mask position 0: non-mask
        mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()[0,0,:]  # 1*1*all_patch_num
        all_patch_num=h*w/self.patch_size/self.patch_size
        # Indices of flagged patches (queries).
        mask_index=torch.nonzero(non_mask_region,as_tuple=True)[0]
        if len(mask_index)==0:  ## No mask patch is selected, no attention is needed
            composed_fold=x
        else:
            # Indices of unflagged patches (keys).
            unmask_index=torch.nonzero(non_mask_region!=1,as_tuple=True)[0]
            x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
            Query_Patch=torch.index_select(x_unfold,2,mask_index)
            Key_Patch=torch.index_select(x_unfold,2,unmask_index)
            Query_Patch=Query_Patch.permute(0,2,1)
            Query_Patch_normalized=F.normalize(Query_Patch,dim=2)
            Key_Patch_normalized=F.normalize(Key_Patch,dim=1)
            correlation_matrix=torch.bmm(Query_Patch_normalized,Key_Patch_normalized)
            correlation_matrix=F.softmax(correlation_matrix,dim=2)
            R, max_arg=torch.max(correlation_matrix,dim=2)
            composed_unfold=self.Hard_Compose(Key_Patch, 2, max_arg)
            # Overwrite only the flagged patches with their best matches.
            x_unfold[:,:,mask_index]=composed_unfold
            composed_fold=F.fold(x_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)
        concat_1=torch.cat((z,composed_fold,mask),dim=1)
        concat_1=self.F_Combine(concat_1)
        return concat_1
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    """GAN objective that builds the all-real / all-fake target tensor itself.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is true and ``nn.BCELoss``
    otherwise. Target tensors are created with the ``tensor`` constructor
    and cached until a prediction of a different size arrives.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target tensor with the same size as ``input``.

        The cache is keyed on the full size. The original compared only
        ``numel()``, so two predictions with the same element count but
        different shapes reused a wrongly shaped target.
        """
        if target_is_real:
            cached = self.real_label_var
            if cached is None or cached.size() != input.size():
                self.real_label_var = self.Tensor(input.size()).fill_(self.real_label)
            return self.real_label_var
        cached = self.fake_label_var
        if cached is None or cached.size() != input.size():
            self.fake_label_var = self.Tensor(input.size()).fill_(self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss on the last element of each prediction list.

        ``input`` is either a list of per-scale lists (losses are summed
        over scales) or a single list.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
####################################### VGG Loss###########################
from torchvision import models
class VGG19_torch(torch.nn.Module):
    """Pretrained VGG-19 feature extractor split into five consecutive slices.

    ``forward`` returns the output of every slice. Parameters are frozen
    unless ``requires_grad`` is true.
    """

    def __init__(self, requires_grad=False):
        super(VGG19_torch, self).__init__()
        features = models.vgg19(pretrained=True).features
        # (start, stop) layer indices of vgg.features covered by slice1..slice5.
        boundaries = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for number, (start, stop) in enumerate(boundaries, 1):
            stage = torch.nn.Sequential()
            for index in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(index), features[index])
            setattr(self, "slice%d" % number, stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the list of the five slice outputs, shallowest first."""
        activations = []
        current = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            current = stage(current)
            activations.append(current)
        return activations
class VGGLoss_torch(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG-19 feature maps.

    ``gpu_ids`` is accepted for interface compatibility but not used.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss_torch, self).__init__()
        self.vgg = VGG19_torch()
        # The original called .cuda() unconditionally, which fails on
        # CPU-only hosts; move to the GPU only when one is available.
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
        self.criterion = nn.L1Loss()
        # One weight per VGG slice output, shallowest first.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """Return the weighted sum of L1 distances between the features of x and y.

        The features of ``y`` are detached, so gradients only flow through ``x``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss
pix2pixHD_model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
################################################################################################
# 看到Pix2PixHDModel的类。这个类的内容非常多,有搭建模型,定义优化器和损失函数,导入模型等操作。
################################################################################################
class Pix2PixHDModel(BaseModel): # # 继承自BaseModel类,里面主要有save和load模型函数,BaseModel类继承自torch.nn.module
def name(self):
    """Return the model's display name."""
    model_name = 'Pix2PixHDModel'
    return model_name
# loss过滤器:其中g_gan、d_real、d_fake三个loss值是肯定返回的
# 至于g_gan_feat,g_vgg两个loss值根据train_options的opt.no_ganFeat_loss, not opt.no_vgg_loss而定
# 备注:这个函数只是一个过滤器,不仅可以滤掉loss值,也可以滤掉loss name,主要看是谁在调用,输入什么,就可以滤掉什么。
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1):
    """Return a function that keeps only the enabled entries of a loss tuple.

    The positions are (g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,
    smooth_l1). g_gan, g_kl, d_real and d_fake are always kept; the other
    three follow the given flags. The same filter works on loss values or
    on loss names, whichever is passed in.
    """
    enabled = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, use_smooth_L1)

    def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1):
        candidates = (g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, smooth_l1)
        kept = []
        for value, keep in zip(candidates, enabled):
            if keep:
                kept.append(value)
        return kept

    return loss_filter
#在initialize函数里面看看对Pix2PixHDModel的一些设置。
def initialize(self, opt):
    """Create the networks and, when training, the losses and optimizers.

    All configuration is read from ``opt``.  The generator is always built.
    The discriminator, the loss criteria and the two Adam optimizers are
    built only when ``opt.isTrain`` is set.  Stored weights are reloaded at
    test time, when ``opt.continue_train`` is set, or when
    ``opt.load_pretrain`` is non-empty.

    NOTE(review): ``self.netE`` is used whenever ``self.gen_features`` is
    true but is never created in this method -- confirm that feature
    encoding is meant to stay disabled for this model.
    """
    BaseModel.initialize(self, opt)
    # cudnn autotuning stays off for full-resolution training (it causes OOM).
    if not opt.isTrain or opt.resize_or_crop != 'none':
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat
    self.gen_features = self.use_features and not self.opt.load_features
    # With no label channels the raw image channel count is used instead.
    input_nc = opt.input_nc if opt.label_nc == 0 else opt.label_nc

    # ---- generator ----
    netG_input_nc = input_nc
    if not opt.no_instance:
        netG_input_nc += 1  # one extra channel for the edge map
    if self.use_features:
        netG_input_nc += opt.feat_num
    self.netG = networks.define_G(
        netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
        opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
        opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt,
    )

    # ---- discriminator (training only) ----
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        if opt.no_cgan:
            netD_input_nc = opt.output_nc
        else:
            netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(
            netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
            opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids,
        )

    if self.opt.verbose:
        print('---------- Networks initialized -------------')

    # ---- reload stored weights ----
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = opt.load_pretrain if self.isTrain else ''
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        print("---------- G Networks reloaded -------------")
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            print("---------- D Networks reloaded -------------")
        if self.gen_features:
            self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

    # Everything below is training-only.
    if not self.isTrain:
        return

    if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
        raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
    self.fake_pool = ImagePool(opt.pool_size)
    self.old_lr = opt.lr

    # ---- loss functions ----
    self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1)
    self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
    self.criterionFeat = torch.nn.L1Loss()  # feature-matching term
    if not opt.no_vgg_loss:
        self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)  # perceptual term
    # Names of the enabled losses, in the order forward() returns them.
    self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1')

    # ---- optimizers ----
    adam_betas = (opt.beta1, 0.999)
    generator_params = list(self.netG.parameters())
    if self.gen_features:
        generator_params += list(self.netE.parameters())
    self.optimizer_G = torch.optim.Adam(generator_params, lr=opt.lr, betas=adam_betas)
    self.optimizer_D = torch.optim.Adam(list(self.netD.parameters()), lr=opt.lr, betas=adam_betas)
    print("---------- Optimizers initialized -------------")

    if opt.continue_train:
        # The optimizer state is restored too, so the decayed learning rate carries on.
        self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
        self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
        for group in self.optimizer_D.param_groups:
            self.old_lr = group['lr']
        print("---------- Optimizers reloaded -------------")
        print("---------- Current LR is %.8f -------------" % (self.old_lr))
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
    """Move the inputs to the GPU and assemble the generator's label input.

    * With ``opt.label_nc == 0`` the label map is used as is; otherwise it
      is expanded to a one-hot tensor with ``opt.label_nc`` channels.
    * Unless ``opt.no_instance`` is set, the edge map computed from
      ``inst_map`` is appended as an extra channel.
    * ``real_image`` and, when features are loaded, ``feat_map`` are
      wrapped as CUDA ``Variable`` objects.

    Returns:
        The tuple ``(input_label, inst_map, real_image, feat_map)``.
    """
    opt = self.opt
    if opt.label_nc == 0:
        input_label = label_map.data.cuda()
    else:
        # One-hot encode along dim 1; dims 0, 2 and 3 keep the label map's sizes.
        shape = label_map.size()
        one_hot_shape = (shape[0], opt.label_nc, shape[2], shape[3])
        input_label = torch.cuda.FloatTensor(torch.Size(one_hot_shape)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()

    # Append the instance edge map as an additional channel.
    if not opt.no_instance:
        inst_map = inst_map.data.cuda()
        edges = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edges), dim=1)
    input_label = Variable(input_label, volatile=infer)

    # Real image, only present during training.
    if real_image is not None:
        real_image = Variable(real_image.data.cuda())

    # Inputs for the optional feature encoder.
    if self.use_features:
        if opt.load_features:
            feat_map = Variable(feat_map.data.cuda())
        if opt.label_feat:
            inst_map = label_map.cuda()

    return input_label, inst_map, real_image, feat_map
## 定义判别器
def discriminate(self, input_label, test_image, use_pool=False):
    """Run the discriminator on a detached copy of ``test_image``.

    When ``input_label`` is given it is concatenated in front of the image
    along dim 1.  With ``use_pool`` the discriminator input is first passed
    through ``self.fake_pool.query``.
    """
    detached = test_image.detach()
    if input_label is None:
        d_input = detached
    else:
        d_input = torch.cat((input_label, detached), dim=1)
    if not use_pool:
        return self.netD.forward(d_input)
    return self.netD.forward(self.fake_pool.query(d_input))
def forward(self, label, inst, image, feat, infer=False):
    """Run one training pass and return the enabled losses.

    The inputs are encoded, the generator's encoder produces ``hiddens``,
    Gaussian noise is added and the decoder produces the fake image.  The
    discriminator and generator GAN losses, the latent penalty, the optional
    feature-matching loss and the optional VGG loss are then computed.

    Returns:
        A two-element list: the output of ``self.loss_filter`` and either
        the fake image (when ``infer`` is true) or ``None``.
    """
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

    # Generator input: the label, optionally extended with encoded features.
    input_concat = input_label
    if self.use_features:
        if not self.opt.load_features:
            feat_map = self.netE.forward(real_image, inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)

    # Reduced VAE: the output is treated as a multivariate Gaussian with
    # mean = hiddens and unit standard deviation (following MUNIT,
    # https://github.com/NVlabs/MUNIT/blob/master/networks.py).
    hiddens = self.netG.forward(input_concat, 'enc')
    noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
    fake_image = self.netG.forward(hiddens + noise, 'dec')

    # Without cGAN the discriminator sees the image only; otherwise it is
    # conditioned on the input label.
    condition = None if self.opt.no_cgan else input_label

    # Discriminator losses on the (pooled) fake image and on the real image.
    pred_fake_pool = self.discriminate(condition, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)
    pred_real = self.discriminate(condition, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # Generator GAN loss: the fake image is NOT detached here.
    if self.opt.no_cgan:
        generator_view = fake_image
    else:
        generator_view = torch.cat((input_label, fake_image), dim=1)
    pred_fake = self.netD.forward(generator_view)
    loss_G_GAN = self.criterionGAN(pred_fake, True)

    loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl

    # Feature matching: distance between discriminator features of the fake
    # and the real image, skipping the last entry of every scale.
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for scale in range(self.opt.num_D):
            fake_feats = pred_fake[scale]
            real_feats = pred_real[scale]
            for layer in range(len(fake_feats) - 1):
                layer_loss = self.criterionFeat(fake_feats[layer], real_feats[layer].detach())
                loss_G_GAN_Feat += D_weights * feat_weights * layer_loss * self.opt.lambda_feat

    # VGG feature matching loss.
    loss_G_VGG = 0
    if not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

    # The smooth-L1 slot is always zero in this model.
    smooth_l1_loss = 0

    losses = self.loss_filter(
        loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake, smooth_l1_loss
    )
    return [losses, fake_image if infer else None]
# 前向推理
def inference(self, label, inst, image=None, feat=None):
    """Generate an image from ``label`` / ``inst`` without training losses.

    Args:
        label: label map, wrapped in a ``Variable`` before encoding.
        inst: instance map, wrapped in a ``Variable`` before encoding.
        image: optional real image, needed only when
            ``opt.use_encoded_image`` is set.
        feat: unused; kept for interface compatibility.

    Returns:
        The generator output for the encoded input.
    """
    image = Variable(image) if image is not None else None
    input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

    if self.use_features:
        if self.opt.use_encoded_image:
            # Encode the real image to get the feature map.
            feat_map = self.netE.forward(real_image, inst_map)
        else:
            # Sample clusters from precomputed features.
            feat_map = self.sample_features(inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label

    # The original gate (`torch.__version__.startswith('0.4')`) skipped
    # no_grad on every later release, so inference tracked gradients there.
    # Use no_grad whenever the running PyTorch provides it.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)
    else:
        fake_image = self.netG.forward(input_concat)
    return fake_image
def sample_features(self, inst):
    """Fill a feature map by sampling one stored cluster row per instance id.

    The cluster file is read from
    ``<opt.checkpoints_dir>/<opt.name>/<opt.cluster_path>`` and holds a dict
    mapping a label to an array of feature rows.  For every distinct value
    in ``inst`` a random row of the matching label is written to all pixels
    of that instance.  Pixels whose label has no stored clusters stay zero.

    Returns:
        A tensor of shape ``(N, opt.feat_num, H, W)``, converted to half
        precision when ``opt.data_type == 16``.
    """
    cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
    # The file holds a pickled dict; NumPy >= 1.16.3 refuses to load it
    # unless allow_pickle is enabled explicitly.
    # NOTE(security): unpickling executes arbitrary code -- only load
    # cluster files that come from a trusted source.
    features_clustered = np.load(cluster_path, encoding='latin1', allow_pickle=True).item()

    inst_np = inst.cpu().numpy().astype(int)
    # Zero-initialise: self.Tensor(...) alone returns uninitialised memory,
    # which would leak into pixels of labels that have no clusters.
    feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
    for i in np.unique(inst_np):
        # Ids of 1000 and above are reduced to their label by integer division.
        label = i if i < 1000 else i // 1000
        if label in features_clustered:
            feat = features_clustered[label]
            cluster_idx = np.random.randint(0, feat.shape[0])
            idx = (inst == int(i)).nonzero()
            for k in range(self.opt.feat_num):
                feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = float(feat[cluster_idx, k])
    if self.opt.data_type == 16:
        feat_map = feat_map.half()
    return feat_map
def encode_features(self, image, inst):
    """Encode ``image`` with ``self.netE`` and collect one feature row per instance id.

    For every distinct value in ``inst`` the encoder output is sampled at
    the middle entry of that instance's pixel list.  Each row holds
    ``opt.feat_num`` feature values followed by the instance's pixel count
    divided by ``(h * w // 32)``.

    Returns:
        A dict mapping each label in ``range(opt.label_nc)`` to an array of
        shape ``(rows, opt.feat_num + 1)``.
    """
    feat_num = self.opt.feat_num
    h, w = inst.size()[2], inst.size()[3]
    block_num = 32

    # `volatile=True` has no effect since PyTorch 0.4; no_grad is its
    # replacement.  The Variable form is kept for older releases.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            feat_map = self.netE.forward(image.cuda(), inst.cuda())
    else:
        feat_map = self.netE.forward(Variable(image.cuda(), volatile=True), inst.cuda())

    inst_np = inst.cpu().numpy().astype(int)
    feature = {}
    for i in range(self.opt.label_nc):
        feature[i] = np.zeros((0, feat_num + 1))
    for i in np.unique(inst_np):
        # Ids of 1000 and above are reduced to their label by integer division.
        label = i if i < 1000 else i // 1000
        idx = (inst == int(i)).nonzero()
        num = idx.size()[0]
        idx = idx[num // 2, :]  # middle entry of this instance's pixel list
        val = np.zeros((1, feat_num + 1))
        for k in range(feat_num):
            entry = feat_map[idx[0], idx[1] + k, idx[2], idx[3]]
            # `.data[0]` fails on 0-dim tensors in current PyTorch; `.item()`
            # is the supported accessor, with the old form as a fallback.
            val[0, k] = entry.item() if hasattr(entry, 'item') else entry.data[0]
        val[0, feat_num] = float(num) / (h * w // block_num)
        feature[label] = np.append(feature[label], val, axis=0)
    return feature
def get_edges(self, t):
    """Return a 0/1 edge map marking pixels whose neighbour holds a different value.

    A pixel is marked when it differs from its left, right, upper or lower
    neighbour in the last two dimensions of ``t``.

    Returns:
        A tensor of the same size and on the same device as ``t``, in half
        precision when ``opt.data_type == 16`` and float otherwise.
    """
    # Allocate next to `t` instead of hard-coding a CUDA byte tensor, and
    # cast the comparison results explicitly so the bitwise OR does not
    # depend on how the running PyTorch promotes bool against uint8.
    edge = torch.zeros_like(t, dtype=torch.uint8)
    horizontal = (t[:, :, :, 1:] != t[:, :, :, :-1]).to(torch.uint8)
    vertical = (t[:, :, 1:, :] != t[:, :, :-1, :]).to(torch.uint8)
    # Each difference marks both pixels of the differing pair.
    edge[:, :, :, 1:] = edge[:, :, :, 1:] | horizontal
    edge[:, :, :, :-1] = edge[:, :, :, :-1] | horizontal
    edge[:, :, 1:, :] = edge[:, :, 1:, :] | vertical
    edge[:, :, :-1, :] = edge[:, :, :-1, :] | vertical
    if self.opt.data_type == 16:
        return edge.half()
    return edge.float()
# 保存模型参数
def save(self, which_epoch):
    """Store the networks and optimizers under the ``which_epoch`` label.

    The generator and discriminator weights are written first, then the two
    optimizer states, and finally the encoder weights when
    ``self.gen_features`` is set.
    """
    for tag in ('G', 'D'):
        self.save_network(getattr(self, 'net' + tag), tag, which_epoch, self.gpu_ids)
    for tag in ("G", "D"):
        self.save_optimizer(getattr(self, 'optimizer_' + tag), tag, which_epoch)
    if self.gen_features:
        self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
    """Recreate the generator optimizer over all generator parameters.

    The new Adam optimizer covers every parameter of ``self.netG`` (plus
    ``self.netE`` when ``self.gen_features`` is set) and starts from
    ``opt.lr`` again.
    """
    trainable = list(self.netG.parameters())
    if self.gen_features:
        trainable.extend(self.netE.parameters())
    adam_betas = (self.opt.beta1, 0.999)
    self.optimizer_G = torch.optim.Adam(trainable, lr=self.opt.lr, betas=adam_betas)
    if self.opt.verbose:
        print('------------ Now also finetuning global generator -----------')
# 衰减学习率
def update_learning_rate(self):
    """Lower the learning rate of both optimizers by one linear decay step.

    One step is ``opt.lr / opt.niter_decay``.  The new rate is written to
    every parameter group of ``optimizer_D`` and ``optimizer_G`` and then
    stored in ``self.old_lr``.
    """
    step = self.opt.lr / self.opt.niter_decay
    new_lr = self.old_lr - step
    for attr in ('optimizer_D', 'optimizer_G'):
        for group in getattr(self, attr).param_groups:
            group['lr'] = new_lr
    if self.opt.verbose:
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr
class InferenceModel(Pix2PixHDModel):
    """Forward-only variant: calling the model runs ``inference``."""

    def forward(self, inp):
        """Unpack ``inp`` into ``(label, inst)`` and return the generated image."""
        label_map, inst_map = inp
        return self.inference(label_map, inst_map)
pix2pixHD_model_DA.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
################################################################################################
# 看到Pix2PixHDModel的类。这个类的内容非常多,有搭建模型,定义优化器和损失函数,导入模型等操作。
################################################################################################
class Pix2PixHDModel(BaseModel):
def name(self):
    """Return the identifier string of this model class."""
    return "Pix2PixHDModel"
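# Loss filter: the g_gan, g_kl, d_real, d_fake, g_featd, featd_real and featd_fake slots
# are always returned. g_gan_feat and g_vgg are returned only when
# opt.no_ganFeat_loss / opt.no_vgg_loss are off.
# Note: the filter is purely positional, so the same function can select loss values or
# loss names, depending on what the caller passes in.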
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
    """Build a positional selector that keeps only the enabled loss slots.

    The nine slots are, in order: g_gan, g_gan_feat, g_vgg, g_kl, d_real,
    d_fake, g_featd, featd_real and featd_fake.  Only g_gan_feat and g_vgg
    are optional; every other slot is always kept.

    Returns:
        A function taking the nine slot values and returning a list with
        the values of the enabled slots, in slot order.  It works equally
        for loss tensors and for loss-name strings.
    """
    enabled = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True)

    def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake):
        slots = (g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake)
        kept = []
        for value, keep in zip(slots, enabled):
            if keep:
                kept.append(value)
        return kept

    return loss_filter
## 在initialize函数里面看看对Pix2PixHDModel的一些设置。
def initialize(self, opt):
    """Create the networks and, when training, the losses and optimizers.

    All configuration is read from ``opt``.  The generator is always built.
    The image discriminator ``netD``, the feature discriminator ``feat_D``,
    the loss criteria and the three Adam optimizers are built only when
    ``opt.isTrain`` is set.  Stored weights are reloaded at test time, when
    ``opt.continue_train`` is set, or when ``opt.load_pretrain`` is
    non-empty.

    NOTE(review): ``self.netE`` is used whenever ``self.gen_features`` is
    true but is never created in this method -- confirm that feature
    encoding is meant to stay disabled for this model.
    """
    BaseModel.initialize(self, opt)
    if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat
    self.gen_features = self.use_features and not self.opt.load_features
    # With no label channels the raw image channel count is used instead.
    input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

    ##### define networks
    # Generator network
    netG_input_nc = input_nc
    if not opt.no_instance:
        netG_input_nc += 1  # one extra channel for the edge map
    if self.use_features:
        netG_input_nc += opt.feat_num
    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
                                  opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                  opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)

    # Discriminator networks
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                      opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
        # Single-scale discriminator on the generator's latent features (64 input channels).
        self.feat_D = networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                        1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

    if self.opt.verbose:
        print('---------- Networks initialized -------------')

    # load networks
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        print("---------- G Networks reloaded -------------")
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            # The label must match the one used in save() ('featD'); the
            # previous 'feat_D' pointed at a file that save() never writes.
            self.load_network(self.feat_D, 'featD', opt.which_epoch, pretrained_path)
            print("---------- D Networks reloaded -------------")

    # set loss functions and optimizers
    if self.isTrain:
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
            raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()  # feature-matching term
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)  # perceptual term

        # Names of the enabled losses, in the order forward() returns them.
        self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake',
                                           'G_featD', 'featD_real', 'featD_fake')

        # initialize optimizers
        # optimizer G
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # optimizer D
        params = list(self.netD.parameters())
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # optimizer for the feature discriminator
        params = list(self.feat_D.parameters())
        self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        print("---------- Optimizers initialized -------------")

        if opt.continue_train:
            # The optimizer state is restored too, so the decayed learning rate carries on.
            self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
            self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
            self.load_optimizer(self.optimizer_featD, 'featD', opt.which_epoch)
            for param_groups in self.optimizer_D.param_groups:
                self.old_lr = param_groups['lr']

            print("---------- Optimizers reloaded -------------")
            print("---------- Current LR is %.8f -------------" % (self.old_lr))
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
    """Move the inputs to the GPU and assemble the generator's label input.

    * With ``opt.label_nc == 0`` the label map is used as is; otherwise it
      is expanded to a one-hot tensor with ``opt.label_nc`` channels.
    * Unless ``opt.no_instance`` is set, the edge map computed from
      ``inst_map`` is appended as an extra channel.
    * ``real_image`` and, when features are loaded, ``feat_map`` are
      wrapped as CUDA ``Variable`` objects.

    Returns:
        The tuple ``(input_label, inst_map, real_image, feat_map)``.
    """
    opt = self.opt
    if opt.label_nc == 0:
        input_label = label_map.data.cuda()
    else:
        # One-hot encode along dim 1; dims 0, 2 and 3 keep the label map's sizes.
        shape = label_map.size()
        one_hot_shape = (shape[0], opt.label_nc, shape[2], shape[3])
        input_label = torch.cuda.FloatTensor(torch.Size(one_hot_shape)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()

    # Append the instance edge map as an additional channel.
    if not opt.no_instance:
        inst_map = inst_map.data.cuda()
        edges = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edges), dim=1)
    input_label = Variable(input_label, volatile=infer)

    # Real image, only present during training.
    if real_image is not None:
        real_image = Variable(real_image.data.cuda())

    # Inputs for the optional feature encoder.
    if self.use_features:
        if opt.load_features:
            feat_map = Variable(feat_map.data.cuda())
        if opt.label_feat:
            inst_map = label_map.cuda()

    return input_label, inst_map, real_image, feat_map
# 定义判别器
def discriminate(self, input_label, test_image, use_pool=False):
    """Run the discriminator on a detached copy of ``test_image``.

    When ``input_label`` is given it is concatenated in front of the image
    along dim 1.  With ``use_pool`` the discriminator input is first passed
    through ``self.fake_pool.query``.
    """
    detached = test_image.detach()
    if input_label is None:
        d_input = detached
    else:
        d_input = torch.cat((input_label, detached), dim=1)
    if not use_pool:
        return self.netD.forward(d_input)
    return self.netD.forward(self.fake_pool.query(d_input))
def feat_discriminate(self, input):
    """Run the feature discriminator on a detached copy of ``input``.

    Detaching keeps the discriminator loss from propagating gradients into
    whatever produced ``input``.
    """
    detached = input.detach()
    return self.feat_D.forward(detached)
def forward(self, label, inst, image, feat, infer=False):
    """Run one training pass and return the enabled losses.

    On top of the image GAN, feature-matching, VGG and latent losses, this
    model trains a second GAN on the generator's latent features.  Samples
    whose ``inst`` entry equals 1 are collected as ``real_old_feat``, all
    others as ``syn_feat``, and both groups are cut to the same length.

    NOTE(review): if a batch holds samples of only one group, both lists
    are cut to length zero before ``torch.cat`` -- confirm the data loader
    always mixes both groups.

    Returns:
        A two-element list: the output of ``self.loss_filter`` and either
        the fake image (when ``infer`` is true) or ``None``.
    """
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

    # Generator input: the label, optionally extended with encoded features.
    input_concat = input_label
    if self.use_features:
        if not self.opt.load_features:
            feat_map = self.netE.forward(real_image, inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)

    # Reduced VAE: the output is treated as a multivariate Gaussian with
    # mean = hiddens and unit standard deviation (following MUNIT,
    # https://github.com/NVlabs/MUNIT/blob/master/networks.py).
    hiddens = self.netG.forward(input_concat, 'enc')
    noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
    fake_image = self.netG.forward(hiddens + noise, 'dec')

    # ---- GAN on the intermediate (latent) features ----
    real_old_feat = []
    syn_feat = []
    for index, flag in enumerate(inst):
        bucket = real_old_feat if flag == 1 else syn_feat
        bucket.append(hiddens[index].unsqueeze(0))
    pair_count = min(len(real_old_feat), len(syn_feat))
    real_old_feat = torch.cat(real_old_feat[:pair_count], 0)
    syn_feat = torch.cat(syn_feat[:pair_count], 0)

    # The feature discriminator labels real_old_feat as fake and syn_feat
    # as real; the generator is trained to make real_old_feat pass as real.
    loss_featD_fake = self.criterionGAN(self.feat_discriminate(real_old_feat), False)
    loss_featD_real = self.criterionGAN(self.feat_discriminate(syn_feat), True)
    loss_G_featD = self.criterionGAN(self.feat_D.forward(real_old_feat), True)

    # ---- GAN on the images ----
    # Without cGAN the discriminator sees the image only; otherwise it is
    # conditioned on the input label.
    condition = None if self.opt.no_cgan else input_label

    pred_fake_pool = self.discriminate(condition, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)
    pred_real = self.discriminate(condition, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # Generator GAN loss: the fake image is NOT detached here.
    if self.opt.no_cgan:
        generator_view = fake_image
    else:
        generator_view = torch.cat((input_label, fake_image), dim=1)
    pred_fake = self.netD.forward(generator_view)
    loss_G_GAN = self.criterionGAN(pred_fake, True)

    loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl

    # Feature matching: distance between discriminator features of the fake
    # and the real image, skipping the last entry of every scale.
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for scale in range(self.opt.num_D):
            fake_feats = pred_fake[scale]
            real_feats = pred_real[scale]
            for layer in range(len(fake_feats) - 1):
                layer_loss = self.criterionFeat(fake_feats[layer], real_feats[layer].detach())
                loss_G_GAN_Feat += D_weights * feat_weights * layer_loss * self.opt.lambda_feat

    # VGG feature matching loss.
    loss_G_VGG = 0
    if not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

    losses = self.loss_filter(
        loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,
        loss_G_featD, loss_featD_real, loss_featD_fake,
    )
    # The fake image is returned only on request.
    return [losses, fake_image if infer else None]
# 前向推理
def inference(self, label, inst, image=None, feat=None):
    """Generate an image from ``label`` / ``inst`` without training losses.

    Args:
        label: label map, wrapped in a ``Variable`` before encoding.
        inst: instance map, wrapped in a ``Variable`` before encoding.
        image: optional real image, needed only when
            ``opt.use_encoded_image`` is set.
        feat: unused; kept for interface compatibility.

    Returns:
        The generator output for the encoded input.
    """
    image = Variable(image) if image is not None else None
    input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

    if self.use_features:
        if self.opt.use_encoded_image:
            # Encode the real image to get the feature map.
            feat_map = self.netE.forward(real_image, inst_map)
        else:
            # Sample clusters from precomputed features.
            feat_map = self.sample_features(inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label

    # The original gate (`torch.__version__.startswith('0.4')`) skipped
    # no_grad on every later release, so inference tracked gradients there.
    # Use no_grad whenever the running PyTorch provides it.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)
    else:
        fake_image = self.netG.forward(input_concat)
    return fake_image
def sample_features(self, inst):
    """Fill a feature map by sampling one stored cluster row per instance id.

    The cluster file is read from
    ``<opt.checkpoints_dir>/<opt.name>/<opt.cluster_path>`` and holds a dict
    mapping a label to an array of feature rows.  For every distinct value
    in ``inst`` a random row of the matching label is written to all pixels
    of that instance.  Pixels whose label has no stored clusters stay zero.

    Returns:
        A tensor of shape ``(N, opt.feat_num, H, W)``, converted to half
        precision when ``opt.data_type == 16``.
    """
    cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
    # The file holds a pickled dict; NumPy >= 1.16.3 refuses to load it
    # unless allow_pickle is enabled explicitly.
    # NOTE(security): unpickling executes arbitrary code -- only load
    # cluster files that come from a trusted source.
    features_clustered = np.load(cluster_path, encoding='latin1', allow_pickle=True).item()

    inst_np = inst.cpu().numpy().astype(int)
    # Zero-initialise: self.Tensor(...) alone returns uninitialised memory,
    # which would leak into pixels of labels that have no clusters.
    feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
    for i in np.unique(inst_np):
        # Ids of 1000 and above are reduced to their label by integer division.
        label = i if i < 1000 else i // 1000
        if label in features_clustered:
            feat = features_clustered[label]
            cluster_idx = np.random.randint(0, feat.shape[0])
            idx = (inst == int(i)).nonzero()
            for k in range(self.opt.feat_num):
                feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = float(feat[cluster_idx, k])
    if self.opt.data_type == 16:
        feat_map = feat_map.half()
    return feat_map
def encode_features(self, image, inst):
    """Encode ``image`` with ``self.netE`` and collect one feature row per instance id.

    For every distinct value in ``inst`` the encoder output is sampled at
    the middle entry of that instance's pixel list.  Each row holds
    ``opt.feat_num`` feature values followed by the instance's pixel count
    divided by ``(h * w // 32)``.

    Returns:
        A dict mapping each label in ``range(opt.label_nc)`` to an array of
        shape ``(rows, opt.feat_num + 1)``.
    """
    feat_num = self.opt.feat_num
    h, w = inst.size()[2], inst.size()[3]
    block_num = 32

    # `volatile=True` has no effect since PyTorch 0.4; no_grad is its
    # replacement.  The Variable form is kept for older releases.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            feat_map = self.netE.forward(image.cuda(), inst.cuda())
    else:
        feat_map = self.netE.forward(Variable(image.cuda(), volatile=True), inst.cuda())

    inst_np = inst.cpu().numpy().astype(int)
    feature = {}
    for i in range(self.opt.label_nc):
        feature[i] = np.zeros((0, feat_num + 1))
    for i in np.unique(inst_np):
        # Ids of 1000 and above are reduced to their label by integer division.
        label = i if i < 1000 else i // 1000
        idx = (inst == int(i)).nonzero()
        num = idx.size()[0]
        idx = idx[num // 2, :]  # middle entry of this instance's pixel list
        val = np.zeros((1, feat_num + 1))
        for k in range(feat_num):
            entry = feat_map[idx[0], idx[1] + k, idx[2], idx[3]]
            # `.data[0]` fails on 0-dim tensors in current PyTorch; `.item()`
            # is the supported accessor, with the old form as a fallback.
            val[0, k] = entry.item() if hasattr(entry, 'item') else entry.data[0]
        val[0, feat_num] = float(num) / (h * w // block_num)
        feature[label] = np.append(feature[label], val, axis=0)
    return feature
def get_edges(self, t):
    """Return a 0/1 edge map marking pixels whose neighbour holds a different value.

    A pixel is marked when it differs from its left, right, upper or lower
    neighbour in the last two dimensions of ``t``.

    Returns:
        A tensor of the same size and on the same device as ``t``, in half
        precision when ``opt.data_type == 16`` and float otherwise.
    """
    # Allocate next to `t` instead of hard-coding a CUDA byte tensor, and
    # cast the comparison results explicitly so the bitwise OR does not
    # depend on how the running PyTorch promotes bool against uint8.
    edge = torch.zeros_like(t, dtype=torch.uint8)
    horizontal = (t[:, :, :, 1:] != t[:, :, :, :-1]).to(torch.uint8)
    vertical = (t[:, :, 1:, :] != t[:, :, :-1, :]).to(torch.uint8)
    # Each difference marks both pixels of the differing pair.
    edge[:, :, :, 1:] = edge[:, :, :, 1:] | horizontal
    edge[:, :, :, :-1] = edge[:, :, :, :-1] | horizontal
    edge[:, :, 1:, :] = edge[:, :, 1:, :] | vertical
    edge[:, :, :-1, :] = edge[:, :, :-1, :] | vertical
    if self.opt.data_type == 16:
        return edge.half()
    return edge.float()
# 保存模型参数
def save(self, which_epoch):
    """Store the networks and optimizers under the ``which_epoch`` label.

    The generator, the image discriminator and the feature discriminator
    are written first, then the three optimizer states, and finally the
    encoder weights when ``self.gen_features`` is set.
    """
    for attr, tag in (('netG', 'G'), ('netD', 'D'), ('feat_D', 'featD')):
        self.save_network(getattr(self, attr), tag, which_epoch, self.gpu_ids)
    for tag in ("G", "D", 'featD'):
        self.save_optimizer(getattr(self, 'optimizer_' + tag), tag, which_epoch)
    if self.gen_features:
        self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
    """Recreate the generator optimizer over all generator parameters.

    Intended for the point where the global generator, fixed for a number
    of iterations, starts being fine-tuned as well.  The new Adam optimizer
    covers every parameter of ``self.netG`` (plus ``self.netE`` when
    ``self.gen_features`` is set) and starts from ``opt.lr`` again.
    """
    trainable = list(self.netG.parameters())
    if self.gen_features:
        trainable.extend(self.netE.parameters())
    adam_betas = (self.opt.beta1, 0.999)
    self.optimizer_G = torch.optim.Adam(trainable, lr=self.opt.lr, betas=adam_betas)
    if self.opt.verbose:
        print('------------ Now also finetuning global generator -----------')
# 衰减学习率
def update_learning_rate(self):
    """Lower the learning rate of all three optimizers by one linear decay step.

    One step is ``opt.lr / opt.niter_decay``.  The new rate is written to
    every parameter group of ``optimizer_D``, ``optimizer_G`` and
    ``optimizer_featD`` and then stored in ``self.old_lr``.
    """
    step = self.opt.lr / self.opt.niter_decay
    new_lr = self.old_lr - step
    for attr in ('optimizer_D', 'optimizer_G', 'optimizer_featD'):
        for group in getattr(self, attr).param_groups:
            group['lr'] = new_lr
    if self.opt.verbose:
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr
# 推理模型,仅前向传播
class InferenceModel(Pix2PixHDModel):
    """Forward-only variant: calling the model runs ``inference``."""

    def forward(self, inp):
        """Unpack ``inp`` into ``(label, inst)`` and return the generated image."""
        label_map, inst_map = inp
        return self.inference(label_map, inst_map)
mapping_model.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import functools
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import math
from .NonLocal_feature_mapping_model import *
class Mapping_Model(nn.Module):
    """Convolutional mapping network applied to latent feature maps.

    The network widens the channel count in four 3x3 conv stages (capped at
    ``mc``), applies ``n_blocks`` residual blocks at ``mc`` channels, and
    narrows back to 64 channels.  When ``0 < opt.feat_dim < 64`` a final
    1x1 conv projects to ``opt.feat_dim`` channels.

    Args:
        nc: not used by this class; kept for interface compatibility.
        mc: upper bound for the channel count inside the network.
        n_blocks: number of residual blocks.
        norm: normalisation type passed to ``networks.get_norm_layer``.
        padding_type: padding type passed to ``networks.ResnetBlock``.
        opt: options object; ``mapping_net_dilation`` and ``feat_dim`` are
            read from it, so ``None`` is not actually usable.
    """

    def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None):
        super(Mapping_Model, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []
        tmp_nc = 64
        n_up = 4

        print("Mapping: You are using the mapping model without global restoration.")

        # Widening stages: 64 -> 128 -> 256 -> 512 -> 1024, each capped at mc.
        for i in range(n_up):
            ic = min(tmp_nc * (2 ** i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        # Residual blocks at the widest channel count.
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        # Narrowing stages: 1024 -> 512 -> 256 -> 128, each capped at mc.
        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        # The previous stage emits min(128, mc) channels.  Hard-coding 128
        # here broke every mc below 128 (including the default of 64); the
        # layer is unchanged for mc >= 128, so existing checkpoints still fit.
        model += [nn.Conv2d(min(tmp_nc * 2, mc), tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Apply the mapping network to ``input``."""
        return self.model(input)
class Pix2PixHDModel_Mapping(BaseModel):
def name(self):
    """Return the identifier string of this model class."""
    return 'Pix2PixHDModel_Mapping'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2):
    """Build a positional selector that keeps only the enabled loss slots.

    The eight slots are, in order: g_feat_l2, g_gan, g_gan_feat, g_vgg,
    d_real, d_fake, smooth_l1 and stage_1_feat_l2.  The g_feat_l2, g_gan,
    d_real and d_fake slots are always kept; the other four follow the
    corresponding argument.

    Returns:
        A function taking the eight slot values and returning a list with
        the values of the enabled slots, in slot order.
    """
    enabled = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2)

    def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2):
        slots = (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2)
        kept = []
        for value, keep in zip(slots, enabled):
            if keep:
                kept.append(value)
        return kept

    return loss_filter
def initialize(self, opt):
    """Create the two generators, the mapping network and the training objects.

    ``netG_A`` and ``netG_B`` are two ``GlobalGenerator_DCDCv2`` networks
    built from the same options.  The mapping network class is chosen from
    ``opt.non_local``, ``opt.NL_use_mask`` and ``opt.mapping_exp``.  Unless
    ``opt.no_load_VAE`` is set, both generators are loaded, frozen and put
    in eval mode.  The discriminator, losses and optimizers are built only
    when ``opt.isTrain`` is set.

    NOTE(review): the indentation of the source was lost; the mapping
    optimizer is assumed to be created only inside the
    ``not opt.no_load_VAE`` branch -- confirm against the original file.
    """
    BaseModel.initialize(self, opt)
    if not opt.isTrain or opt.resize_or_crop != "none":
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    # input_nc: channel count of the input (label channels if any, else image channels).
    input_nc = opt.input_nc if opt.label_nc == 0 else opt.label_nc

    # ---- generators ----
    netG_input_nc = input_nc

    def build_generator():
        # Arguments include opt.ngf, opt.k_size, opt.n_downsample_global and
        # the normalisation layer selected by opt.norm.
        return networks.GlobalGenerator_DCDCv2(
            netG_input_nc,
            opt.output_nc,
            opt.ngf,
            opt.k_size,
            opt.n_downsample_global,
            networks.get_norm_layer(norm_type=opt.norm),
            opt=opt,
        )

    self.netG_A = build_generator()
    self.netG_B = build_generator()

    # ---- mapping network ----
    # The masked variants are used for the "Setting_42" non-local setting or
    # when opt.NL_use_mask is set; mapping_exp == 1 selects the *_2 variant.
    if opt.non_local == "Setting_42" or opt.NL_use_mask:
        if opt.mapping_exp == 1:
            mapping_cls = Mapping_Model_with_mask_2
        else:
            mapping_cls = Mapping_Model_with_mask
    else:
        mapping_cls = Mapping_Model
    self.mapping_net = mapping_cls(
        min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc),
        opt.map_mc,
        n_blocks=opt.mapping_n_block,
        opt=opt,
    )
    self.mapping_net.apply(networks.weights_init)

    if opt.load_pretrain != "":
        self.load_network(self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain)

    if not opt.no_load_VAE:
        self.load_network(self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA)
        self.load_network(self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB)
        # Both generators are frozen and switched to eval mode.
        for frozen in (self.netG_A, self.netG_B):
            for param in frozen.parameters():
                param.requires_grad = False
        self.netG_A.eval()
        self.netG_B.eval()

    if opt.gpu_ids:
        first_gpu = opt.gpu_ids[0]
        self.netG_A.cuda(first_gpu)
        self.netG_B.cuda(first_gpu)
        self.mapping_net.cuda(first_gpu)

    if not self.isTrain:
        self.load_network(self.mapping_net, "mapping_net", opt.which_epoch)
        # Everything below is training-only.
        return

    # ---- discriminator ----
    use_sigmoid = opt.no_lsgan
    if opt.feat_gan:
        netD_input_nc = opt.ngf * 2
    else:
        netD_input_nc = input_nc + opt.output_nc
    if not opt.no_instance:
        netD_input_nc += 1
    self.netD = networks.define_D(
        netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
        opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids,
    )

    # ---- loss functions ----
    if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
        raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
    self.fake_pool = ImagePool(opt.pool_size)
    self.old_lr = opt.lr

    self.loss_filter = self.init_loss_filter(
        not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1, opt.use_two_stage_mapping
    )
    self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
    self.criterionFeat = torch.nn.L1Loss()
    if opt.use_l1_feat:
        self.criterionFeat_feat = torch.nn.L1Loss()
    else:
        self.criterionFeat_feat = torch.nn.MSELoss()
    if self.opt.image_L1:
        self.criterionImage = torch.nn.L1Loss()
    else:
        self.criterionImage = torch.nn.SmoothL1Loss()
    print(self.criterionFeat_feat)
    if not opt.no_vgg_loss:
        self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

    # Names of the enabled losses, in filter order.
    self.loss_names = self.loss_filter(
        'G_Feat_L2', 'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake', 'Smooth_L1', 'G_Feat_L2_Stage_1'
    )

    # ---- optimizers ----
    # With TTUR (the default here) the mapping net uses half and the
    # discriminator twice the base learning rate, with betas (0, 0.9).
    if opt.no_TTUR:
        beta1, beta2 = opt.beta1, 0.999
        G_lr, D_lr = opt.lr, opt.lr
    else:
        beta1, beta2 = 0, 0.9
        G_lr, D_lr = opt.lr / 2, opt.lr * 2

    if not opt.no_load_VAE:
        mapping_params = list(self.mapping_net.parameters())
        self.optimizer_mapping = torch.optim.Adam(mapping_params, lr=G_lr, betas=(beta1, beta2))

    discriminator_params = list(self.netD.parameters())
    self.optimizer_D = torch.optim.Adam(discriminator_params, lr=D_lr, betas=(beta1, beta2))

    print("---------- Optimizers initialized -------------")
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
    """Prepare raw batch tensors for the networks.

    Tensors are placed on the current CUDA device only when
    ``self.opt.gpu_ids`` is non-empty (the same rule ``inference`` uses);
    otherwise they stay on the CPU, so the method also works on CPU-only
    hosts.

    Args:
        label_map: Input label tensor. When ``opt.label_nc`` is non-zero it
            is used as a class-index map and indexed as a 4-D tensor
            (``size[0]``, ``size[2]``, ``size[3]``).
        inst_map: Instance map; required unless ``opt.no_instance`` is set.
        real_image: Optional target image; returned detached from any graph.
        feat_map: Passed through untouched.
        infer: Kept for backward compatibility only. It used to be forwarded
            as ``Variable(..., volatile=infer)``, which PyTorch >= 0.4
            ignores; wrap the call in ``torch.no_grad()`` instead.

    Returns:
        Tuple ``(input_label, inst_map, real_image, feat_map)``.

    Raises:
        ValueError: If instance maps are enabled but ``inst_map`` is None.
    """
    # "cuda" without an index resolves to the current CUDA device, which is
    # what the former bare ``.cuda()`` calls targeted.
    device = torch.device("cuda" if len(self.opt.gpu_ids) > 0 else "cpu")

    if self.opt.label_nc == 0:
        input_label = label_map.data.to(device)
    else:
        # Build a one-hot encoding of the label map along the channel axis.
        size = label_map.size()
        one_hot_size = (size[0], self.opt.label_nc, size[2], size[3])
        input_label = torch.zeros(one_hot_size, dtype=torch.float32, device=device)
        input_label = input_label.scatter_(1, label_map.data.long().to(device), 1.0)
        if self.opt.data_type == 16:
            input_label = input_label.half()

    # Append the edge map derived from the instance map as an extra channel.
    if not self.opt.no_instance:
        if inst_map is None:
            raise ValueError("inst_map is required unless opt.no_instance is set")
        inst_map = inst_map.data.to(device)
        edge_map = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edge_map), dim=1)

    # Real images are only supplied during training.
    if real_image is not None:
        real_image = real_image.data.to(device)

    return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
    """Score a (condition, sample) pair with the discriminator.

    ``test_image`` is detached before being concatenated with
    ``input_label`` along dim 1, so nothing computed here back-propagates
    into whatever produced ``test_image``. When ``use_pool`` is true the
    concatenated tensor is first exchanged through ``self.fake_pool.query``.

    Returns:
        Whatever ``self.netD.forward`` returns for the selected input.
    """
    d_input = torch.cat((input_label, test_image.detach()), dim=1)
    if not use_pool:
        return self.netD.forward(d_input)
    # Let the image pool decide which sample the discriminator sees.
    pooled = self.fake_pool.query(d_input)
    return self.netD.forward(pooled)
def forward(self, label, inst, image, feat, pair=True, infer=False, last_label=None, last_image=None):
    """Run one pass of the mapping model and collect its losses.

    The input is encoded by ``netG_A``, the detached code is translated by
    ``mapping_net``, and ``netG_B`` decodes the result into ``fake_image``
    and encodes ``real_image`` into the target feature.

    Args:
        label, inst, image, feat: Raw batch tensors handed to
            ``encode_input``. ``inst`` is also given to ``mapping_net`` when
            ``opt.NL_use_mask`` is set.
        pair: Whether ``image`` is the paired target of ``label``. When
            false, the real discriminator sample is
            ``(last_label, last_image)`` (image-space GAN only) and the
            GAN-feature / VGG losses are zero tensors.
        infer: When true, the generated image is returned as well.
        last_label, last_image: Real sample used when ``pair`` is false.

    Returns:
        ``[self.loss_filter(...), fake_image or None]``. The losses are
        passed to ``loss_filter`` in the order: feature L2, G GAN, G GAN
        feature matching, G VGG, D real, D fake, image L1, stage-1 feature
        L2.
    """
    opt = self.opt

    # Encode inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

    # Fake generation. The encoder output is detached before it enters the
    # mapping network, so no gradient reaches netG_A through this path.
    label_feat = self.netG_A.forward(input_label, flow='enc')
    if opt.NL_use_mask:
        label_feat_map = self.mapping_net(label_feat.detach(), inst)
    else:
        label_feat_map = self.mapping_net(label_feat.detach())

    fake_image = self.netG_B.forward(label_feat_map, flow='dec')
    image_feat = self.netG_B.forward(real_image, flow='enc')

    # Latent-space regression of the mapped code onto the target code.
    loss_feat_l2_stage_1 = 0
    loss_feat_l2 = self.criterionFeat_feat(label_feat_map, image_feat.data) * opt.l2_feat

    # Pick what the discriminator compares: latent codes (feat_gan) or
    # images. Only the image-space variant honours ``pair``.
    if opt.feat_gan:
        cond = label_feat.detach()
        fake_sample = label_feat_map
        real_cond, real_sample = cond, image_feat
    else:
        cond, fake_sample = input_label, fake_image
        if pair:
            real_cond, real_sample = input_label, real_image
        else:
            real_cond, real_sample = last_label, last_image

    # Fake detection and loss
    pred_fake_pool = self.discriminate(cond, fake_sample, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)

    # Real detection and loss
    pred_real = self.discriminate(real_cond, real_sample)
    loss_D_real = self.criterionGAN(pred_real, True)

    # GAN loss (fake passability): the fake sample is NOT detached here.
    pred_fake = self.netD.forward(torch.cat((cond, fake_sample), dim=1))
    loss_G_GAN = self.criterionGAN(pred_fake, True)

    # GAN feature matching loss
    if pair and not opt.no_ganFeat_loss:
        feat_weights = 4.0 / (opt.n_layers_D + 1)
        D_weights = 1.0 / opt.num_D
        scale = D_weights * feat_weights
        loss_G_GAN_Feat = 0
        for d_idx in range(opt.num_D):
            fake_feats = pred_fake[d_idx]
            # The last entry of each discriminator's output is skipped.
            for layer_idx in range(len(fake_feats) - 1):
                layer_loss = self.criterionFeat(
                    fake_feats[layer_idx], pred_real[d_idx][layer_idx].detach()
                ) * opt.lambda_feat
                loss_G_GAN_Feat += scale * layer_loss
    else:
        loss_G_GAN_Feat = torch.zeros(1).to(label.device)

    # VGG feature matching loss
    loss_G_VGG = 0
    if not opt.no_vgg_loss:
        if pair:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * opt.lambda_feat
        else:
            loss_G_VGG = torch.zeros(1).to(label.device)

    # Optional image-space reconstruction loss
    smooth_l1_loss = 0
    if opt.Smooth_L1:
        smooth_l1_loss = self.criterionImage(fake_image, real_image) * opt.L1_weight

    losses = self.loss_filter(
        loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG,
        loss_D_real, loss_D_fake, smooth_l1_loss, loss_feat_l2_stage_1,
    )
    return [losses, fake_image if infer else None]
def inference(self, label, inst):
    """Generate an image for ``label`` without computing any losses.

    Both inputs are moved to CUDA when ``opt.gpu_ids`` is non-empty. The
    input is encoded by ``netG_A``, the detached code is translated by the
    mapping network, and ``netG_B`` decodes the result. ``inst`` is handed
    to the mapping network only when ``opt.NL_use_mask`` is set; in that
    case ``opt.inference_optimize`` selects ``inference_forward`` over the
    regular call.

    Returns:
        The output of ``netG_B.forward(..., flow="dec")``.
    """
    net_input, inst_data = label.data, inst
    if len(self.opt.gpu_ids) > 0:
        net_input = net_input.cuda()
        inst_data = inst_data.cuda()

    latent = self.netG_A.forward(net_input, flow="enc").detach()

    if not self.opt.NL_use_mask:
        mapped = self.mapping_net(latent)
    elif self.opt.inference_optimize:
        mapped = self.mapping_net.inference_forward(latent, inst_data)
    else:
        mapped = self.mapping_net(latent, inst_data)

    return self.netG_B.forward(mapped, flow="dec")
class InferenceModel(Pix2PixHDModel_Mapping):
    """Test-time variant whose ``forward`` delegates to ``inference``."""

    def forward(self, label, inst):
        """Return the image generated by ``self.inference(label, inst)``."""
        generated = self.inference(label, inst)
        return generated
NonLocal_feature_mapping_model.py
略