pytorch 实现 Style Transfer

pytorch 实现 Style Transfer


设 CNN 中第 $l$ 层风格图片、内容图片、生成图片的 feature map 分别为 $S^l$、$C^l$、$G^l$。

第 $l$ 层内容损失函数定义为 $L_C^l=\sum\limits_{ij}(C^l_{ij}-G^l_{ij})^2$。

对于 feature map $F$,定义 $F^{(kk')}$ 为第 $k$ 通道和第 $k'$ 通道 feature map 的内积(即 Gram 矩阵的元素),则第 $l$ 层风格损失函数定义为 $L_S^l=\frac{1}{(n_Hn_Wn_C)^2}\sum\limits_{kk'}(S^{l(kk')}-G^{l(kk')})^2$,其中 $n_H$、$n_W$、$n_C$ 分别为该层 feature map 的高、宽和通道数。

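总损失为各层内容损失与风格损失的加权和 $L=\alpha\sum_l L_C^l+\beta\sum_l L_S^l$,代码中 $\alpha$ 对应 `content_weight`,$\beta$ 对应 `style_weight`(代码用 `F.mse_loss` 取均值,因此归一化常数与上式相差一个常数因子)。CNN 框架采用 VGG19,优化器采用 L-BFGS,被优化的变量是生成图片的像素本身。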

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

# Pick the GPU when available; on CPU use a smaller working resolution so the
# optimization stays tractable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 512 if torch.cuda.is_available() else 128

# Resize the shorter side to ``imsize``, center-crop to an ``imsize x imsize``
# square, then convert to a float tensor. Resizing to ``imsize`` (instead of a
# fixed 512) keeps the whole scene at the smaller CPU resolution rather than
# cropping out a small central patch.
transform = transforms.Compose([
    transforms.Resize(imsize),
    transforms.CenterCrop(imsize),
    transforms.ToTensor(),
])

def image_loader(image_name):
    """Load an image file as a batched float tensor on ``device``.

    The image is converted to RGB, passed through the module-level
    ``transform`` and given a leading batch dimension of size 1.

    Args:
        image_name: path of the image file to open.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` on ``device``.
    """
    # Context manager closes the file handle promptly; ``convert`` loads the
    # pixel data into a new image, so it stays valid after the file is closed.
    with Image.open(image_name) as image:
        rgb = image.convert('RGB')
    batch = transform(rgb).unsqueeze(0)
    return batch.to(device, torch.float)

style_img = image_loader("./data/images/the_starry_night.jpg")
content_img = image_loader("./data/images/STJU.jpg")
# The rest of the script expects both images to share one shape (``transform``
# center-crops to a fixed square, so this should always hold). Raise instead
# of ``assert`` so the check is not stripped under ``python -O``.
if style_img.size() != content_img.size():
    raise ValueError(
        "style and content images must have the same size, got {} and {}".format(
            tuple(style_img.size()), tuple(content_img.size())))

def trans(tensor):
    """Turn a ``(1, C, H, W)`` image tensor into an ``(H, W, C)`` NumPy array.

    The data is moved to the CPU and cloned first, so the returned array never
    shares memory with ``tensor``.
    """
    chw = tensor.cpu().clone().squeeze(0)
    # Channels-last layout is what matplotlib's ``imshow`` expects.
    return chw.permute(1, 2, 0).numpy()

class ContentLoss(nn.Module):
    """Pass-through probe that records the MSE against a fixed content target.

    ``forward`` returns its input unchanged, so the module can sit anywhere in
    an ``nn.Sequential``; the most recent loss is kept in ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # Detached: the target is a constant, not part of the autograd graph.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    ``input`` is unpacked as ``(a, b, c, d)`` and flattened to an
    ``(a*b, c*d)`` matrix ``M``; the result is ``M @ M.T / (a*b*c*d)``.
    """
    batch, channels, height, width = input.size()
    rows = batch * channels
    flat = input.view(rows, height * width)
    # Entry (k, k') is the inner product of flattened feature maps k and k'.
    gram = flat @ flat.t()
    return gram / (rows * height * width)

class StyleLoss(nn.Module):
    """Pass-through probe that records the Gram-matrix MSE to a style target.

    The target Gram matrix is computed once from ``target`` and detached.
    ``forward`` stores the loss for its input in ``self.loss`` and returns the
    input unchanged.
    """

    def __init__(self, target):
        super().__init__()
        target_gram = gram_matrix(target)
        self.target = target_gram.detach()

    def forward(self, input):
        current_gram = gram_matrix(input)
        self.loss = F.mse_loss(current_gram, self.target)
        return input

# Convolutional part of a pretrained VGG19 (``.features``), in evaluation mode
# and on ``device``.
cnn = models.vgg19(pretrained=True).features.eval().to(device)

# Per-channel RGB mean/std used to normalize inputs for torchvision's
# pretrained models.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device)

class Normalization(nn.Module):
    """Standardize an image batch with fixed per-channel mean and std.

    ``mean`` and ``std`` are viewed as ``(C, 1, 1)`` so they broadcast over the
    spatial dimensions of the input.
    """

    def __init__(self, mean, std):
        super().__init__()
        # (C,) -> (C, 1, 1) for broadcasting against (..., C, H, W).
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        centered = img - self.mean
        return centered / self.std

# Default probe positions: the content loss follows the 4th conv layer, the
# style loss follows each of the first five conv layers.
content_layers_default = ['conv_4']
style_layers_default = ['conv_{}'.format(k) for k in range(1, 6)]

def get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img, content_layers=content_layers_default, style_layers=style_layers_default):
    """Build a truncated copy of ``cnn`` with loss probes after selected layers.

    The returned ``nn.Sequential`` starts with a ``Normalization`` layer,
    followed by the children of a deep copy of ``cnn`` renamed
    ``conv_i``/``relu_i``/``pool_i``/``bn_i`` (``i`` counts the conv layers seen
    so far). After each layer named in ``content_layers`` a ``ContentLoss`` is
    inserted, whose target is that layer's activation for ``content_img``.
    After each layer named in ``style_layers`` a ``StyleLoss`` is inserted,
    whose target is built from that layer's activation for ``style_img``.
    Layers after the last probe are dropped.

    Returns:
        ``(model, style_losses, content_losses)``; the two lists hold the
        inserted probe modules in network order.

    Raises:
        RuntimeError: if ``cnn`` contains a child that is not a Conv2d, ReLU,
            MaxPool2d or BatchNorm2d layer.
    """
    # Work on a copy so the caller's network is never modified.
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # number of conv layers seen so far; used in every layer name
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # An in-place ReLU would overwrite the activation that a preceding
            # loss probe used, breaking autograd's saved tensors.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # Target = activation of the model built so far on the content image.
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # Layers after the last probe cannot influence any loss, so cut them off.
    # With no probes at all, only the normalization layer (index 0) is kept.
    last = 0
    for idx, module in enumerate(model):
        if isinstance(module, (ContentLoss, StyleLoss)):
            last = idx
    model = model[:last + 1]

    return model, style_losses, content_losses

# The optimization starts from a copy of the content image.
input_img = torch.clone(content_img)


def get_input_optimizer(input_img):
    """Return an L-BFGS optimizer whose only parameter is ``input_img``.

    ``input_img`` is switched to ``requires_grad=True`` in place, so the
    optimizer updates the image pixels themselves.
    """
    params = [input_img.requires_grad_()]
    return optim.LBFGS(params)

def run_style_tranfer(cnn, normalization_mean, normalization_std, content_img, style_img, input_img, num_steps=300, style_weight=10000000000, content_weight=1):
    """Optimize ``input_img`` in place to match content and style targets.

    Args:
        cnn: feature extractor passed to ``get_style_model_and_losses``.
        normalization_mean, normalization_std: per-channel input statistics.
        content_img, style_img: images providing the content / style targets.
        input_img: starting image; its pixels are the optimized variables.
        num_steps: number of loss evaluations (closure calls) to run. L-BFGS
            may call the closure several times per ``step``, so the final count
            can exceed ``num_steps`` by up to one step's worth of evaluations.
        style_weight, content_weight: weights of the style and content terms.

    Returns:
        ``input_img`` itself, clamped to [0, 1].
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_img, content_img)

    # Only the image is optimized. Freezing the network weights avoids
    # computing (and accumulating) gradients for them on every backward pass.
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = 0

    def closure():
        nonlocal run
        # Keep pixels in the valid range; no_grad so autograd ignores the clamp.
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass fills ``.loss`` on every probe module.
        model(input_img)

        # Start from tensors so ``.item()`` works even if a probe list is empty.
        style_score = torch.zeros((), device=input_img.device)
        content_score = torch.zeros((), device=input_img.device)
        for sl in style_losses:
            style_score = style_score + sl.loss
        for cl in content_losses:
            content_score = content_score + cl.loss

        style_score = style_score * style_weight
        content_score = content_score * content_weight

        loss = style_score + content_score
        loss.backward()

        run += 1
        if run % 50 == 0:
            print("run {}:".format(run))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_score.item(), content_score.item()))
            print()

        return loss

    while run < num_steps:
        optimizer.step(closure)

    # The last L-BFGS update may have pushed pixels out of range.
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

output = run_style_tranfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img)

# Convert to channels-last NumPy arrays for matplotlib under new names, so the
# original tensors remain usable after plotting.
style_np = trans(style_img)
content_np = trans(content_img)
output_np = trans(output.detach())

f, ax = plt.subplots(1, 3)
ax[0].imshow(style_np)
ax[0].set_title('Style Image')
ax[1].imshow(content_np)
ax[1].set_title('Content Image')
ax[2].imshow(output_np)
# The third panel is the stylized result, not a transposed image.
ax[2].set_title('Output Image')
plt.show()

迭代300次效果如下:
pytorch 实现 Style Transfer_第1张图片

你可能感兴趣的:(pytorch 实现 Style Transfer)