pytorch实战-7图像风格迁移

1 什么是风格迁移

how to:还是cnn,输入是图像,输出和上一章相比,不是数字,而是图像。

意义:给一张图像输入,可以输出指定风格化处理的图像

2 风格迁移发展简史

早期针对图像局部特征(纹理生成)或特定风格/场景建立模型,迁移时通过套用模型提取图片纹理或转化风格。缺点是特征/风格单一,无法通用。

2015年,Leon Gatys 等人尝试用神经网络做风格迁移,效果很好,并成为了主流。在神经网络用于风格迁移之前,主要有纹理生成、特定风格实现等技术

2.1 纹理生成

此种风格迁移是将物体表面的纹理特征作为风格,赋予其他图像

早期纹理生成方法可分为3类:纹理映射,过程纹理合成,基于样图的纹理生成

纹理映射:将已有的二维纹理图像按照映射关系贴到物体表面

过程纹理合成:计算机模拟物体表面纹理直接生成

基于样图纹理生成:基于小区域样图,按表面几何形状,拼接生成整个图像。适用于给图像填补缺陷

2.2 特定风格实现

针对每种风格设计一种算法,对待处理的图像应用算法,将特定风格赋予图像。缺点是麻烦,每种风格需设计一种算法

3 神经网络风格迁移

3.1 优势

1 特征提取维度更多,比人为设算法可能覆盖面更大 2 无需为每种风格设置特定算法,一个网络可以提取很多风格

3.2 基本思想

风格迁移追求的不是逐像素的一致,而是整体特征上的一致。整体特征可由特征图体现,特征图又由卷积核决定,所以风格迁移中影响风格的主要因素是卷积核。总体来说,在迁移风格的同时,还要保证图像内容有一定的相似度

3.3 卷积神经网络的选取

目标:选取适合做风格迁移的神经网络

如何选取:因为衡量尺度有多个(内容准确度和风格相似度),可据此选取在分类问题表现良好的大型网络,一种常用的网络是牛津视觉几何组提出的VGG网络。

特点:和前述章节cnn相比,VGG卷积核较小(3x3)且网络较深。对于VGG,当固定其他参数时,逐步加深网络,网络的识别能力也会逐步提高,且有良好泛化能力。

举例:VGG19有16个卷积层和3个全连接层,16个卷积层分为5组,每组是2个或4个连续的卷积层。同一组内特征图的尺寸保持不变,是因为3x3卷积核配合使用了padding=1

3.4 内容损失

第四组的第一层卷积层的输出对内容准确度可以有较好的体现。因此,将内容图像和待生成图像分别输入网络,取两者在第四组第一层卷积层输出之间的均方误差作为loss,用来表示内容准确度

3.5 风格损失

风格损失可理解为特征图之间的相关性:既包括同一层内特征图之间的相关性,也包含不同层里的特征图的相关性,可用Gram矩阵表示。Gram矩阵的每个元素定义为同一层中两个特征图(展平为向量后)的内积,即 $G_{ij} = \sum_k F_{ik} F_{jk}$,其中 $F_i$ 表示第 $i$ 个特征图。风格损失可用Gram矩阵的均方误差表示

3.6 风格损失原理分析

如果两个图像内容相似但风格不同,那么两者的gram矩阵会相差很大,但第四组第一层特征图的均方误差很小

如果两个图像风格相同但内容很不一样,那么两者的gram矩阵会很接近,但第四组第一层特征图的均方误差很大

3.7 风格迁移损失函数优化

综合loss由内容损失和风格损失加权求和得到,两个权重是人为设定的超参数(本例中风格权重为1000,内容权重为1)。训练过程不会改变神经网络参数,被优化的是输入图像的像素值

4 案例

4.1 准备工作

主要包括导入模块,准备数据,数据预处理等操作

from __future__ import print_function

import os.path
import copy
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as pyplot

import torchvision.transforms as transforms
import torchvision.models as models

# Use the GPU when CUDA is available; dtype is the float tensor type matching that choice.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Image locations; the relative paths are resolved against the current working directory.
origin_image_path = os.path.realpath('./pytorch/jizhi/feature_transform/origin_image.jpg')
styled_image_path = os.path.realpath('./pytorch/jizhi/feature_transform/styled_image.jpg')

# Weighting constants for the style and content loss terms.
style_loss_weight = 1000
content_loss_weight = 1

# Size handed to transforms.Resize in the loader below.
img_size = 128

# loader: resize a PIL image and convert it to a tensor; unloader: tensor back to a PIL image.
loader = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
unloader = transforms.ToPILImage()
# Turn on matplotlib interactive mode.
pyplot.ion()

def image_loader(img_name):
    """Load an image file as a batched tensor ready for the network.

    Args:
        img_name: Path of the image file to open.

    Returns:
        The output of the module-level ``loader`` pipeline with a leading
        batch dimension of size 1 added and gradient tracking enabled.
    """
    # Close the file handle deterministically, and force 3-channel RGB so
    # grayscale / palette / RGBA files do not yield 1- or 4-channel tensors.
    with Image.open(img_name) as source:
        image = source.convert('RGB')
    # requires_grad_ is applied before unsqueeze, exactly as before, so the
    # returned tensor is derived from a gradient-tracking leaf.
    image = loader(image).clone().detach().requires_grad_(True)
    return image.unsqueeze(0)

def img_show(image, title=None):
    """Display an image tensor in the current matplotlib figure.

    Args:
        image: Image tensor shaped (3, H, W), optionally with a leading
            batch dimension of size 1.
        title: Optional caption drawn above the image.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    image = image.clone().cpu()
    # Drop the batch dimension rather than forcing (3, img_size, img_size):
    # the forced view raised an error for any image that is not exactly
    # img_size x img_size (e.g. non-square inputs).
    image = image.squeeze(0)
    image = unloader(image)
    pyplot.imshow(image)
    # The original condition was inverted (`if not title`), so a supplied
    # caption was never drawn.
    if title is not None:
        pyplot.title(title)
    # Short pause so the figure gets a chance to refresh.
    pyplot.pause(0.001)

def main():
    """Load and preview both images, then fetch the pretrained VGG19 feature layers."""
    styled_image = image_loader(styled_image_path).type(dtype)
    origin_image = image_loader(origin_image_path).type(dtype)

    # Both inputs have to share one shape.
    if origin_image.size() != styled_image.size():
        raise Exception('origin image size is not equal to styled image')

    # Preview each input in a figure of its own.
    previews = ((styled_image, 'Styled Image'), (origin_image, 'Content Image'))
    for picture, caption in previews:
        pyplot.figure()
        img_show(picture.data, title=caption)

    # Convolutional part of the pretrained VGG19, moved to the GPU when one is used.
    cnn = models.vgg19(pretrained=True).features
    if use_cuda:
        cnn = cnn.cuda()
    
    
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()

4.2 定义内容损失和风格损失

from __future__ import print_function

import os.path
import copy
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as pyplot

import torchvision.transforms as transforms
import torchvision.models as models

# Use the GPU when CUDA is available; dtype is the float tensor type matching that choice.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Image locations; the relative paths are resolved against the current working directory.
origin_image_path = os.path.realpath('./pytorch/jizhi/feature_transform/origin_image.jpg')
styled_image_path = os.path.realpath('./pytorch/jizhi/feature_transform/styled_image.jpg')

# Weighting constants for the style and content loss terms.
style_loss_weight = 1000
content_loss_weight = 1

# Size handed to transforms.Resize in the loader below.
img_size = 128

# loader: resize a PIL image and convert it to a tensor; unloader: tensor back to a PIL image.
loader = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
unloader = transforms.ToPILImage()
# Turn on matplotlib interactive mode.
pyplot.ion()

class ContentLoss(nn.Module):
    """Pass-through layer that records a weighted MSE against a fixed target.

    Args:
        target: Reference tensor; it is detached and pre-multiplied by ``weight``.
        weight: Scale applied to both the target and every input before comparison.
    """

    def __init__(self, target, weight):
        super().__init__()
        self.weight = weight
        self.criterion = nn.MSELoss()
        # Detach so that no gradient ever flows back into the target.
        self.target = weight * target.detach()

    def forward(self, input):
        """Store the loss for ``input`` in ``self.loss`` and return ``input`` unchanged."""
        scaled_input = input * self.weight
        self.loss = self.criterion(scaled_input, self.target)
        self.output = input
        return input

    def backward(self, retain_graph=True):
        """Back-propagate the loss stored by the last ``forward`` call and return it."""
        stored_loss = self.loss
        stored_loss.backward(retain_graph=retain_graph)
        return stored_loss
    
class StyledLoss(nn.Module):
    """Pass-through layer that records a weighted Gram-matrix MSE (style loss).

    Args:
        target: Gram matrix to compare against; it is detached and
            pre-multiplied by ``weight``.
        weight: Scale applied to both the target and the Gram matrix of
            every input.
    """

    def __init__(self, target, weight):
        super(StyledLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        # The original instantiated an undefined ``GramMatrix`` class here,
        # which raised NameError on construction. The module-level ``Gram``
        # function is the Gram-matrix implementation this file provides.
        self.gram = Gram
        self.criterion = nn.MSELoss()

    def forward(self, input):
        """Store the style loss for ``input`` in ``self.loss`` and return a copy of ``input``."""
        self.output = input.clone()
        input = input.cuda() if use_cuda else input
        # Scale out of place rather than with mul_, so no tensor tracked by
        # autograd is mutated.
        weighted_gram = self.gram(input) * self.weight
        self.loss = self.criterion(weighted_gram, self.target)
        return self.output

    def backward(self, retain_graph=True):
        """Back-propagate the loss stored by the last ``forward`` call and return it.

        ``retain_graph`` is forwarded to ``Tensor.backward``.
        """
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
        
def Gram(input):
    """Return the normalised Gram matrix of a 4-D tensor.

    The tensor is flattened to (dim0 * dim1) rows by (dim2 * dim3) columns,
    multiplied by its own transpose, and divided by its total element count.
    """
    n_batch, n_maps, height, width = input.size()
    row_count = n_batch * n_maps
    col_count = height * width
    flat = input.view(row_count, col_count)
    products = torch.mm(flat, flat.t())
    return products.div(row_count * col_count)

def image_loader(img_name):
    """Load an image file as a batched tensor ready for the network.

    Args:
        img_name: Path of the image file to open.

    Returns:
        The output of the module-level ``loader`` pipeline with a leading
        batch dimension of size 1 added and gradient tracking enabled.
    """
    # Close the file handle deterministically, and force 3-channel RGB so
    # grayscale / palette / RGBA files do not yield 1- or 4-channel tensors.
    with Image.open(img_name) as source:
        image = source.convert('RGB')
    # requires_grad_ is applied before unsqueeze, exactly as before, so the
    # returned tensor is derived from a gradient-tracking leaf.
    image = loader(image).clone().detach().requires_grad_(True)
    return image.unsqueeze(0)

def img_show(image, title=None):
    """Display an image tensor in the current matplotlib figure.

    Args:
        image: Image tensor shaped (3, H, W), optionally with a leading
            batch dimension of size 1.
        title: Optional caption drawn above the image.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    image = image.clone().cpu()
    # Drop the batch dimension rather than forcing (3, img_size, img_size):
    # the forced view raised an error for any image that is not exactly
    # img_size x img_size (e.g. non-square inputs).
    image = image.squeeze(0)
    image = unloader(image)
    pyplot.imshow(image)
    # The original condition was inverted (`if not title`), so a supplied
    # caption was never drawn.
    if title is not None:
        pyplot.title(title)
    # Short pause so the figure gets a chance to refresh.
    pyplot.pause(0.001)

def main():
    """Load and preview both images, fetch VGG19 features, and name the loss layers."""
    styled_image = image_loader(styled_image_path).type(dtype)
    origin_image = image_loader(origin_image_path).type(dtype)

    # Both inputs have to share one shape.
    if origin_image.size() != styled_image.size():
        raise Exception('origin image size is not equal to styled image')

    # Preview each input in a figure of its own.
    previews = ((styled_image, 'Styled Image'), (origin_image, 'Content Image'))
    for picture, caption in previews:
        pyplot.figure()
        img_show(picture.data, title=caption)

    # Convolutional part of the pretrained VGG19, moved to the GPU when one is used.
    cnn = models.vgg19(pretrained=True).features
    if use_cuda:
        cnn = cnn.cuda()

    # Names of the convolution layers after which loss modules are attached.
    styled_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    content_layers = ['conv_4']
    
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()

detach会返回一个与计算图解耦的tensor:它不再参与梯度计算,backward时梯度不会经由它继续向前回传,因此用它保存的目标值在训练中保持不变

4.3 建立计算图

    # Loss modules are collected here so the training code can call
    # backward() on each of them.
    origin_losses = []
    style_losses = []

    # Rebuild the VGG feature stack layer by layer, inserting loss modules
    # right after the convolutions named in content_layers / styled_layers.
    model = nn.Sequential()
    if use_cuda:
        model = model.cuda()
    i = 1  # index shared by a conv/relu pair; advanced after each ReLU
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            name = 'conv_' + str(i)
            model.add_module(name, layer)
            if name in content_layers:
                # Features of the content image at this depth are the target.
                target = model(origin_image).clone()
                origin_image_loss = ContentLoss(target, content_loss_weight)
                origin_image_loss = origin_image_loss.cuda() if use_cuda else origin_image_loss
                # Fix: the content loss has to be part of the model. It was
                # only appended to the list before, so its forward() never
                # ran and its backward() had no stored loss to propagate.
                model.add_module('content_loss_' + str(i), origin_image_loss)
                origin_losses.append(origin_image_loss)
            if name in styled_layers:
                # Gram matrix of the style image's features is the target.
                target_feature = model(styled_image).clone()
                target_feature = target_feature.cuda() if use_cuda else target_feature
                target_feature_gram = Gram(target_feature)
                styled_loss = StyledLoss(target_feature_gram, style_loss_weight)
                styled_loss = styled_loss.cuda() if use_cuda else styled_loss
                model.add_module('styled_loss_' + str(i), styled_loss)
                style_losses.append(styled_loss)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_' + str(i)
            model.add_module(name, layer)
            i += 1
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_' + str(i)
            model.add_module(name, layer)

4.4 训练

    # Start from random noise with the same shape as the content image.
    random_input_img = torch.randn(origin_image.data.size()).requires_grad_(True)
    if use_cuda:
        random_input_img = random_input_img.cuda()
        origin_image = origin_image.cuda()
        styled_image = styled_image.cuda()
    pyplot.figure()
    img_show(random_input_img.data, title='input image')
    # train begin
    # The pixels of the input image are the only values handed to the optimizer.
    input_param = nn.Parameter(random_input_img.data)
    optimizer = optim.LBFGS([input_param])
    epoch_num = 30
    # Most recent loss values seen by the closure, kept as plain floats for logging.
    latest_scores = {'style': 0.0, 'content': 0.0}

    def closure():
        # Fix: L-BFGS may evaluate the objective several times within one
        # step(), so the forward pass and the gradients have to be recomputed
        # on every call. The original closure only returned values computed
        # once outside of it.
        with torch.no_grad():
            input_param.clamp_(0, 1)
        optimizer.zero_grad()
        # The forward pass makes every loss module store a fresh loss.
        model(input_param)
        style_score, content_score = 0, 0
        for loss in style_losses:
            style_score += loss.backward()
        for loss in origin_losses:
            content_score += loss.backward()
        latest_scores['style'] = float(style_score)
        latest_scores['content'] = float(content_score)
        return style_score + content_score

    for epoch in range(epoch_num):
        optimizer.step(closure)
        if epoch % 10 == 0:
            style_score = latest_scores['style']
            content_score = latest_scores['content']
            print(f'no {epoch} train, style loss:{style_score}, content loss:{content_score}')
    # Clamp once more so the displayed result is a valid image in [0, 1].
    output = input_param.data.clamp_(0, 1)
    pyplot.figure()
    img_show(output, title='output image')
    pyplot.ioff()
    pyplot.show()

用L-BFGS优化,是因为它是一种拟牛顿法,利用近似的二阶信息,在这类直接优化图像像素的问题上通常比普通梯度下降收敛更快;使用时需要给optimizer.step传入一个重新计算loss和梯度的closure

5 小结

思考 咋学的

你可能感兴趣的:(pytorch,人工智能,python)