李沐d2l《动手学深度学习》第二版——风格迁移源码详解

本文是对李沐Dive to DL《动手学深度学习》第二版13.12节风格迁移的源码详解,整体由Jupyter+VSCode完成,几乎所有重要代码均给出了注释,一看就懂。需要的同学可以在文末链接处下载原文件。

话不多说,我们开始。

图1中的内容图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。
李沐d2l《动手学深度学习》第二版——风格迁移源码详解_第1张图片

图1 风格迁移效果示意图

1.一种简单的风格迁移方法

(1) 初始化合成图像,例如将其初始化为内容图像(content image);
(2) 利用预训练网络(如VGG-19)的某些层抽取内容图像与合成图像的内容特征,再用某些层抽取风格图像与合成图像的风格特征;
(3) 根据抽取出来的content feature map和style feature map计算出内容损失(content loss,使合成图像与内容图像在内容特征上接近)和风格损失(style loss,使合成图像与风格图像在风格特征上接近);
(4) 根据当前的合成图像自身计算出全变分损失(total variation loss,有助于减少合成图像中的噪点);
(5) 将这三个损失按一定比例加权(主观更倾向于合成什么样的图像),计算出最终的总损失;
(6) 根据损失反向传播误差,逐步更新合成图像的参数,降低损失,最终结束训练,图像风格迁移成功。

李沐d2l《动手学深度学习》第二版——风格迁移源码详解_第2张图片

图2 风格迁移的三类损失

2.先观察一下内容图像(content image)和风格图像(style image)

导包:

import torch
import torchvision
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt

读取图像并显示:

content_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\rainier.jpg')
style_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\autumn_oak.jpg')
plt.imshow(content_img)
plt.show()  # Display all open figures.
plt.imshow(style_img)
plt.show()

3.图像预处理和后处理

预处理函数preprocess()对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入 格式。
后处理函数postprocess()则将输出图像反标准化,输出能正常显示的人眼看得懂的图像。

# content_img[1364, 2047, 2]  # imread读取的图像前两个维度是高和宽,第三个维度表示选择RGB三个通道中的哪个通道

# ImageNet先验归一化
# 该均值和标准差来源于ImageNet数据集统计得到,如果建立的数据集分布和ImageNet数据集数据分布类似(来自生活真实场景,例如人像、风景、交通工具等),或者使用PyTorch提供的预训练模型,推荐使用该参数归一化。如果建立的数据集并非是生活真实场景(如生物医学图像),则不推荐使用该参数
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

# 在使用深度学习框架构建训练数据时,通常需要数据归一化,以利于网络的训练
# 输入图像必须是PIL/np.ndaray
def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        # ToTensor()将其图像由(h,w,c)转置为(c,h,w),再把像素值从[0,255]变换到[0,1]
        torchvision.transforms.ToTensor(),
        # 标准化
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0) # unsqueeze将图像升至4维,增加批次数=1,便于后续图像处理可以更好地进行批操作

# 在训练过程可视化中,通常需要反归一化,以显示能用人眼看得懂的正常的图
def postprocess(img):
    # 将图像从训练的GPU环境移至CPU并转为3维(c,h,w)(消去第四维batch)
    img = img[0].to(rgb_std.device)
    # 先将(c,h,w)转化为(h,w,c)实施反归一化,并将输出范围限制到[0,1]
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) 
    # 再将(h,w,c)转化为(c,h,w),ToPILImage()将Tensor的每个元素乘以255;将数据由Tensor转化成Uint8
    # ToPILImage()要求输入图像若是tensor,则shape必须是(c,h,w)形式
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

我们可以写一段测试代码看看postprocess()干了些啥:

# 这里content_img是PIL格式
# img是tensor格式
img = preprocess(content_img, (224,244))
img.shape    # 输出:torch.Size([1, 3, 224, 244])
img = img[0].to(rgb_std.device)
img.shape    # 输出:torch.Size([3, 224, 244])
img = img.permute(1, 2, 0)
img.shape    # 输出:torch.Size([224, 244, 3])
img = torch.clamp(img * rgb_std + rgb_mean, 0, 1)
img_before = torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
plt.imshow(img_before)   # PIL格式,不可索引单个像素,不可输出shape

输出标准化+反标准化后的图像:
李沐d2l《动手学深度学习》第二版——风格迁移源码详解_第3张图片

4. 抽取图像特征

我们使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征 [Gatys et al., 2016]。

pretrained_net = torchvision.models.vgg19(pretrained=True)
pretrained_net # 输出vgg-19的网络结构:5个卷积块,前两个块中有2个卷积层,后三个块中有4个卷积层

为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。 一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节(我们只需要保留一个大概的主题及轮廓即可,细节方面由风格特征把握),我们选择VGG较靠近输出的层来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。 VGG网络使用了5个卷积块,实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。

style_idx = [0, 5, 10, 19, 28] # 各卷积块中的第一个卷积层的索引分别是Sequential中的0, 5, 10, 19, 28
content_idx = [25]             # 第4个卷积块中最后一个卷积层的索引是Sequential中的25

# 构建一个新的网络net,它只保留需要用到的VGG的所有层。
# 因为用到最深的层是第28层,因此我们要保留vgg网络中第28层及前面的所有层,而28层以后的均不要保留
# pretrained_net.features输出vgg网络的feature属性(即平均池化之前的所有层)
layers = []
# max(content_layers + style_layers)求用到的最深层的索引
for i in range(max(style_idx + content_idx) + 1):
    # 逐层加入到所需的layers列表中
    layers.append(pretrained_net.features[i])
# 将layers逐元素加入到Sequential中
net = nn.Sequential(*layers)
net
# 风格特征提取层如下:
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# 内容特征提取层如下:
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。 由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。

def extract_features(x, content_idx, style_idx):
    content_features = []
    style_features = []
    for i in range(len(net)):
        # 计算当前层输出
        temp_layer = net[i]
        x = temp_layer(x)
        # 若当前层索引在风格索引列表中
        if i in style_idx:
            style_features.append(x)
        # 若当前层索引在内容索引列表中
        if i in content_idx:
            content_features.append(x)
    return content_features, style_features

下面定义两个函数:get_contents函数对内容图像抽取内容特征; get_styles函数对风格图像抽取风格特征。 因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。 由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。

# 提取内容图像的内容特征
def get_content_features(content_img, image_shape):
    # 对content_img先进行预处理并移至gpu,便于直接输入网络
    content_x = preprocess(content_img, image_shape).cuda()
    # 提取内容图像的内容特征
    content_features_x, _ = extract_features(content_x, content_idx, style_idx)
    return content_x, content_features_x

# 提取风格图像的风格特征
def get_style_features(style_img, image_shape):
    # 对style_img先进行预处理并移至gpu,便于直接输入网络
    style_x = preprocess(style_img, image_shape).cuda()
    # 提取风格图像的风格特征
    _, style_features_x = extract_features(style_x, content_idx, style_idx)
    return style_x, style_features_x

5.定义损失函数:由内容损失、风格损失和全变分损失3部分组成

5.1. 内容损失
内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异。 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。

def calc_contentloss(Y_hat, Y):
    # 从动态计算梯度的树中分离目标
    # 计算所有通道对应矩阵的差的平方和,再除以所有元素个数
    # 这里把Y detach一下是因为原始内容图像的特征图无需参与反向传播(视为已知常量),所以将它从计算图中分离,否则的话反向更新会影响该值
    return torch.square(Y_hat - Y.detach()).mean()
   

测试一下这函数干了啥:

Y_hat = torch.randn(512,2,2)
Y = torch.randn(512,2,2)
torch.square(Y_hat - Y.detach()).mean() == torch.square(Y_hat - Y.detach()).sum() / 2048
# 返回tensor(True)

5.2. 风格损失
风格损失与内容损失类似,也通过平方误差函数衡量合成图像与风格图像在风格上的差异。 为了表达风格层输出的风格,我们先通过extract_features函数计算风格层的输出。 假设该输出的样本数为1,通道数为 c c c,高和宽分别为 h h h w w w ,我们可以将此输出转换为矩阵 X X X,其有 c c c行和 h w hw hw列(相当把一个通道的矩阵拉成一个行向量)。 这个矩阵可以被看作是由 c c c个长度为 h w hw hw的向量 x 1 , . . . , x c x_{1},...,x_{c} x1,...,xc组合而成的。其中向量 x i x_{i} xi代表了通道 i i i上的风格特征(其实就是该通道的所有像素点)。
在这些向量的格拉姆矩阵 G = X X ⊤ ∈ R c × c G = XX^⊤∈R^{c×c} G=XXRc×c中, i i i j j j列的元素 x i j x_{ij} xij即向量 x i x_{i} xi x j x_{j} xj的内积。它表达了通道 i i i和通道 j j j上风格特征的相关性(emmm…姑且认为一个像素代表一个特征吧)。我们用这样的格拉姆矩阵来表达风格层输出的风格。 需要注意的是,当 h w hw hw的值较大时,格拉姆矩阵中的元素容易出现较大的值。 此外,格拉姆矩阵的高和宽皆为通道数 c 。 为了让风格损失不受这些值的大小影响,下面定义的calc_gram函数将格拉姆矩阵除以了矩阵中元素的个数,即 c h w chw chw

# 输入是vgg某层输出的特征图,尺寸为(c,h,w)
def calc_gram(x):
    c = x.shape[1]    # c是输出的风格特征图的通道数
    hw = x.shape[2] * x.shape[3]    # hw是一张特征图矩阵中所有元素的个数
    x = x.reshape((c, hw))    # 将(c,h,w)变换为(c,h*w)
    return torch.matmul(x, x.T) / (c * hw)    # matmul是矩阵乘法

# 计算风格损失,这里假设风格图像的格拉姆矩阵已经提前计算好了
def calc_styleloss(Y_hat, gram_Y):
    # 这里把gram_Y detach一下是因为原始风格图像的格拉姆矩阵无需参与反向传播(视为已知常量),所以将它从计算图中分离,否则的话反向更新会影响该值
    return torch.square(calc_gram(Y_hat) - gram_Y.detach()).mean()

5.3. 全变分损失
有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 一种常见的去噪方法是全变分去噪(total variation denoising): 假设 x i , j x_{i,j} xi,j表示坐标 ( i , j ) (i,j) (i,j)处的像素值,则全变分损失定义为:
∑ i , j ∣ x i , j − x i + 1 , j ∣ + ∑ i , j ∣ x i , j − x i , j + 1 ∣ \sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \sum_{i, j} \left|x_{i, j} - x_{i, j+1}\right| i,jxi,jxi+1,j+i,jxi,jxi,j+1
这里我们用如下公式计算单幅图像的全变分损失:
l o s s = 1 2 ( l o s s v e r t i c a l + l o s s h o r i z o n t a l ) = 1 2 ( 1 c h w ∑ c ∑ i , j ∣ x i , j − x i + 1 , j ∣ + 1 c h w ∑ c ∑ i , j ∣ x i , j − x i , j + 1 ∣ ) loss = \frac{1}{2}(loss_{vertical} + loss_{horizontal}) = \frac{1}{2}(\frac{1}{chw}\sum_{c}\sum_{i, j}\left|x_{i, j} - x_{i+1, j}\right| + \frac{1}{chw}\sum_{c}\sum_{i, j}\left|x_{i, j} - x_{i, j+1}\right|) loss=21(lossvertical+losshorizontal)=21(chw1ci,jxi,jxi+1,j+chw1ci,jxi,jxi,j+1)
其中 ∑ c \sum_{c} c表示对各通道求和, c c c同时表示通道数。

def calc_tvloss(Y_hat):
    # [:, :, 1:, :]表示取各通道图像矩阵的第1行至最后一行(起始行为0行)
    # [:, :, :-1, :]表示取各通道图像矩阵的第0行至倒数第二行
    # 矩阵相减取绝对值再在各通道上取平均(所有元素加起来除以总元素数)
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

测试一下:

y = torch.randn(1,3,4,4)
torch.abs(y[:,:,1:,:] - y[:,:,:-1,:]).mean()
torch.abs(y[:,:,:,1:] - y[:,:,:,:-1]).mean()
# 输出 tensor(0.9967)

5.4. 损失函数
风格转移任务的损失函数是内容损失、风格损失和全变分损失的加权和。 通过调节三者的权重超参数,我们可以权衡合成图像在保留内容、保留风格及去噪三方面的相对重要性。

content_weight, style_weight, tv_weight = 1, 1000, 10
# 计算总的损失函数值
def compute_loss(X, content_Y, content_Y_hat, style_Y_gram, style_Y_hat):
    # 分别计算内容损失、风格损失和全变分损失
    # 对1对y,y_hat求内容损失,乘以权重后添加到列表中
    content_l = [calc_contentloss(Y_hat, Y) * content_weight for Y_hat, Y in zip(content_Y_hat, content_Y)]
    # 对5对y,y_hat分别求风格损失,乘以权重后添加到列表中
    style_l = [calc_styleloss(Y_hat, Y) * style_weight for Y_hat, Y in zip(style_Y_hat, style_Y_gram)]
    # 求总变差损失,乘以权重
    tv_l = calc_tvloss(X) * tv_weight
    # 对所有损失求和(5个风格损失,1个内容损失,1个总变差损失)
    l = sum(10 * style_l + content_l + [tv_l]) #style_l乘10干啥?
    return content_l, style_l, tv_l, l

6.初始化合成图像

在风格迁移中,合成的图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可。

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape):
        super(SynthesizedImage, self).__init__()
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

下面,我们定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, lr, style_Y):
    # X是内容图像的预处理结果
    gen_img = SynthesizedImage(X.shape).cuda()
    # 将初始化的weight参数改为已有的图像X的参数(即像素)
    gen_img.weight.data.copy_(X.data)
    # 定义优化器
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    # 对各风格特征图计算其格拉姆矩阵,并依次存于列表中
    style_Y_gram = [calc_gram(Y) for Y in style_Y]
    # !!!gen_img()!!!括号
    return gen_img(), style_Y_gram, trainer

7.训练模型

在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数。下面定义了训练循环。

def train(X, content_Y, style_Y, lr, num_epochs, lr_decay_epoch):
    # X是初始化的合成图像,style_Y_gram是原始风格图像的格拉姆矩阵列表
    X, style_Y_gram, trainer = get_inits(X, lr, style_Y)
    # 定义学习率下降调节器
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    for epoch in range(num_epochs):
        trainer.zero_grad()
        # Y_hat是用合成图像计算出的特征图
        content_Y_hat, style_Y_hat = extract_features(X, content_idx, style_idx)
        content_l, style_l, tv_l, l = compute_loss(X, content_Y, content_Y_hat, style_Y_gram, style_Y_hat)
        # 反向传播误差(计算l对合成图像像素矩阵的导数,因为l的唯一自变量是合成图像像素矩阵)
        l.backward()
        # 更新一次合成图像的像素参数
        trainer.step()
        # 更新学习率超参数
        scheduler.step()
        # 每5个epoch记录一次loss信息
        if (epoch + 1) % 5 == 0:
            # 由于风格损失列表有5项,因此算个总损失输出
            print('迭代次数:{}    内容损失:{:.9f}    风格损失:{:.9f}    总变差损失:{:.9f}' .format(epoch+1, sum(content_l).item(), sum(style_l).item(), tv_l.item()))
    # 训练结束后返回合成图像
    return X

8.开始训练+输出风格迁移图像

首先导入内容图像和风格图像并进行预处理,同时事先计算好内容特征和风格特征。

content_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\rainier.jpg')
style_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\autumn_oak.jpg')
image_shape = (300, 450)
net = net.cuda()
# 计算内容图像的预处理结果(因为我们将内容图像作为合成图像的初始化图像作为网络的初始输入)和抽取到的内容特征
X, content_features_Y = get_content_features(content_img, image_shape)
# 计算风格图像抽取到的风格特征
_, style_features_Y = get_style_features(style_img, image_shape)

开始训练:

output = train(X, content_features_Y, style_features_Y, lr = 0.3, num_epochs = 500, lr_decay_epoch = 50)
# 调用后处理函数处理最终的合成图像,将其转换为正常格式的可视化图像
output = postprocess(output)
# 显示图像
plt.imshow(output)
plt.show()

结果:
李沐d2l《动手学深度学习》第二版——风格迁移源码详解_第4张图片

9.Pytorch源码

本文所用的Pytorch源码可直接跑通(环境:Win10+VScode+Cuda11.3 +CuDnn),链接如下:
Pytorch实现风格迁移源码
欢迎各位小伙伴star or fork~

你可能感兴趣的:(深度学习,计算机视觉,深度学习,计算机视觉,风格迁移,图像处理)