图像风格迁移其实非常好理解:就是将一张图像(风格图像)的“风格”迁移到另一张图像(内容图像)上,生成的图像只在“风格”上与内容图像不同,其“内容”仍要与内容图像保持一致。Luan et al. 和 Gatys et al. 的工作都利用 VGGNet19 作为该项任务的 backbone。由于 VGGNet19 是一种近似“金字塔”型的结构,随着卷积操作的加深,feature maps 的感受野越来越大,提取到的图像特征也从局部扩展到了全局。为了避免合成的图像过多地保留内容信息,我们选取 VGGNet19 中位于金字塔顶部的卷积层作为内容层。整个训练过程为:先将生成图像初始化为内容图像;每次循环分别抽取生成图像和内容图像的内容特征,计算 MSE 并使之最小化,同时抽取生成图像和风格图像的样式特征,计算 MSE 并使之最小化。这里注意损失函数的写法:
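$L_{total} = \alpha \sum_{l} \mathrm{MSE}(F^{l}_{g}, F^{l}_{c}) + \Gamma \cdot \beta \sum_{l} \mathrm{MSE}(G(F^{l}_{g}), G(F^{l}_{s}))$,其中 $F^{l}$ 表示第 $l$ 层的特征图,$G(\cdot)$ 表示格拉姆矩阵,下标 $g$、$c$、$s$ 分别对应生成图像、内容图像和风格图像。总损失由两部分组成:内容损失和样式损失。内容损失即为生成图像和内容图像对应特征图的均方误差,而样式损失需要先分别计算生成图像和风格图像的格拉姆矩阵,再做均方误差。另外,$\alpha$ 和 $\beta$ 分别为内容损失和样式损失的权重,$\Gamma$ 为样式损失的惩罚系数。我通过实验发现 $\beta$ 和 $\Gamma$ 应该取得大一些,使得样式损失被尽可能地“惩罚”,即“放大”样式损失。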
import torch
import numpy as np
from PIL import Image
from torchvision.models import vgg19
from torchvision.transforms import transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.nn.functional import mse_loss
from torch.autograd import Variable
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing: resize, convert to tensor, normalize
def preprocess(img_shape):
    """Build the transform pipeline applied to an image before it enters the network.

    Args:
        img_shape: Target size handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor and then
        normalizes with per-channel mean (0.485, 0.456, 0.406) and
        std (0.229, 0.224, 0.225).
    """
    steps = [
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
class VGGNet19(nn.Module):
    """VGG-19 wrapper that collects activations at selected ``features`` layers.

    Weights are read from the local file ``./vgg19-dcbb9e9d.pth``.

    Attributes (documented here as comments, set in ``__init__``):
        content_layers: indices of ``features`` sub-modules whose output is
            returned as a content feature.
        style_layers: indices of ``features`` sub-modules whose output is
            returned as a style feature.
    """

    def __init__(self):
        super(VGGNet19, self).__init__()
        self.vggnet19 = vgg19(pretrained=False)
        # Load onto the CPU first so the checkpoint can be read on any host,
        # whatever device it was saved from; callers move the module afterwards.
        state_dict = torch.load('./vgg19-dcbb9e9d.pth', map_location='cpu')
        self.vggnet19.load_state_dict(state_dict)
        self.content_layers = [25]
        self.style_layers = [0, 5, 10, 19, 28]

    def forward(self, x):
        """Run ``x`` through every ``features`` layer and collect tapped outputs.

        Args:
            x: Input tensor fed to the first ``features`` layer.

        Returns:
            ``(content_features, style_features)``: two lists holding the
            outputs captured at ``content_layers`` and ``style_layers``,
            in layer order.
        """
        content_features = []
        style_features = []
        # NOTE(review): torchvision's VGG presumably uses in-place ReLU layers,
        # in which case a tensor captured at a conv index is overwritten by the
        # ReLU that follows it. Confirm before adding an early exit or relying
        # on pre-activation values.
        for name, module in self.vggnet19.features.named_children():
            x = module(x)
            layer_index = int(name)
            if layer_index in self.content_layers:
                content_features.append(x)
            if layer_index in self.style_layers:
                style_features.append(x)
        return content_features, style_features
class GenerateImage(nn.Module):
    """Module whose single learnable parameter is the image being synthesized."""

    def __init__(self, img_shape):
        """Create the ``weight`` parameter with uniform random values.

        Args:
            img_shape: Iterable of dimension sizes for the image tensor.
        """
        super().__init__()
        initial_pixels = torch.rand(*img_shape)
        self.weight = nn.Parameter(initial_pixels)

    def forward(self):
        """Return the image parameter itself (no computation is applied)."""
        return self.weight
# Initialize the generated image from the content image
def generate_inits(content, device, lr):
    """Create the trainable generated image and its Adam optimizer.

    Args:
        content: Content image tensor; its shape and values seed the
            generated image.
        device: Device on which the generated image is placed.
        lr: Learning rate for Adam.

    Returns:
        ``(image, optimizer)`` where ``image`` is the ``weight`` parameter of
        a ``GenerateImage`` module and ``optimizer`` updates that parameter.
    """
    g_img = GenerateImage(content.shape).to(device)
    # Copy instead of sharing storage: the optimizer updates the parameter in
    # place, and aliasing would silently overwrite the caller's content tensor.
    # Moving to `device` keeps the parameter on the requested device even when
    # `content` lives elsewhere.
    g_img.weight.data = content.detach().to(device).clone()
    optimizer = torch.optim.Adam(g_img.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    return g_img(), optimizer
# Compute the Gram matrix
def gramMatrix(x):
    """Return the Gram matrix of a feature map, normalized by ``c * h * w``.

    Args:
        x: Feature tensor of shape ``(n, c, h, w)``.

    Returns:
        For ``n == 1`` a ``(c, c)`` tensor. For larger batches a
        ``(n, c, c)`` tensor holding one Gram matrix per sample.
    """
    n, c, h, w = x.shape
    # reshape (not view) so non-contiguous feature maps are accepted, and keep
    # the batch axis so each sample gets its own channel-by-channel product.
    flat = x.reshape(n, c, h * w)
    gram = torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)
    # A batch of one keeps the 2-D shape that existing callers receive.
    return gram[0] if n == 1 else gram
# Compute the total loss: content loss + style loss
def compute_loss(content_g, content_y, style_g, style_y, content_weight, style_weight, gamma):
    """Combine weighted content and style losses into one objective.

    Args:
        content_g: Content features of the generated image.
        content_y: Target content features, paired with ``content_g``.
        style_g: Style features of the generated image.
        style_y: Target style features, paired with ``style_g``.
        content_weight: Multiplier applied to each content MSE term.
        style_weight: Multiplier applied to each style MSE term.
        gamma: Extra multiplier applied to the summed style loss.

    Returns:
        ``(contentlosses, stylelosses, total_loss)``: the per-layer weighted
        loss lists and ``sum(contentlosses) + gamma * sum(stylelosses)``.
    """
    contentlosses = []
    for generated, target in zip(content_g, content_y):
        contentlosses.append(mse_loss(generated, target) * content_weight)

    stylelosses = []
    for generated, target in zip(style_g, style_y):
        # Style is compared through Gram matrices rather than raw features.
        gram_gap = mse_loss(gramMatrix(generated), gramMatrix(target))
        stylelosses.append(gram_gap * style_weight)

    total_loss = sum(contentlosses) + gamma * sum(stylelosses)
    return contentlosses, stylelosses, total_loss
# Post-processing for visualization
def postprocess(img_tensor):
    """Turn the first image of a normalized batch back into a PIL image.

    Args:
        img_tensor: Batched image tensor; only index 0 is used.

    Returns:
        A PIL image obtained by undoing the mean/std normalization and
        clamping the values to ``[0, 1]``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    # Normalizing with (-mean/std, 1/std) inverts normalization by (mean, std).
    undo_normalize = transforms.Normalize(
        mean=-channel_mean / channel_std,
        std=1 / channel_std)
    first_image = img_tensor[0].detach().cpu()
    restored = undo_normalize(first_image).clamp(0, 1)
    return transforms.ToPILImage()(restored)
def train(lr, epoch_num, c_path, s_path, img_shape):
    """Optimize a generated image so it matches content and style targets.

    The generated image starts as the content image and is updated with the
    optimizer returned by ``generate_inits``. Every 100 iterations (except
    iteration 0) the current generated image is written to
    ``itr_<epoch>_total_loss_<loss>.png`` in the working directory.

    Args:
        lr: Learning rate passed to ``generate_inits``.
        epoch_num: Last iteration index; the loop runs ``epoch_num + 1`` times.
        c_path: Path of the content image.
        s_path: Path of the style image.
        img_shape: Size passed to ``preprocess``.

    Returns:
        The final generated image tensor, detached from the autograd graph.
    """
    transform = preprocess(img_shape)
    # Force three channels: preprocess normalizes with 3-element mean/std, so
    # grayscale or RGBA files would otherwise fail.
    with Image.open(c_path) as content_file, Image.open(s_path) as style_file:
        content = transform(content_file.convert('RGB')).unsqueeze(0)
        style = transform(style_file.convert('RGB')).unsqueeze(0)
    content = content.float().to(device)
    style = style.float().to(device)

    net = VGGNet19()
    net.to(device).eval()
    # Only the generated image is optimized; freezing the network avoids
    # computing and accumulating gradients for its weights.
    for param in net.parameters():
        param.requires_grad_(False)

    # Targets are constants, so build them without autograd graphs. This is
    # what makes retain_graph unnecessary in the loop below.
    with torch.no_grad():
        target_content, _ = net(content)
        _, target_style = net(style)

    generated, optimizer = generate_inits(content, device, lr)
    for epoch in range(epoch_num + 1):
        gcontent, gstyle = net(generated)
        contentlosses, stylelosses, total_loss = compute_loss(
            gcontent, target_content, gstyle, target_style, 1, 1e3, 1e2)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        loss_value = total_loss.item()
        print("[epoch: %3d/%3d] content loss: %3f style loss: %3f total loss: %3f" % (epoch, epoch_num, sum(contentlosses).item(), sum(stylelosses).item(), loss_value))
        if epoch % 100 == 0 and epoch != 0:
            # Checkpoint the generated image: it is the only thing being
            # trained, whereas the network weights never change.
            postprocess(generated).save("itr_%d_total_loss_%3f.png" % (epoch, loss_value))
    return generated.detach()
if __name__ == "__main__":
    # Arguments: learning rate, iteration count, content image path,
    # style image path, and the resize shape handed to preprocess.
    train(0.01, 10000, './content.jpg', './s.jpg', (500, 700))
内容图像、风格图像和生成图像(第10000次迭代的可视化)分别如上图所示。本文的代码实现的是 Gatys et al. 的工作。