Implementing Feature Loss (Perceptual Loss) in PyTorch

This post is adapted mainly from: https://github.com/sowmyay/medium/blob/master/CV-LossFunctions.ipynb

First, let's recall the motivation behind feature loss, also known as perceptual loss:

Many loss functions, such as L1 loss, L2 loss, and BCE loss, measure the error by comparing two images pixel by pixel. However, two images that look almost identical to a human (say, image A is just image B shifted by one pixel) can still produce a huge pixel-wise loss. For such tasks a low-level pixel loss is not enough; we need a loss that captures semantic differences.
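
To make this concrete, here is a minimal sketch (using a random tensor as a hypothetical stand-in for a real image) showing that a one-pixel shift already yields a large pixel-wise MSE:

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 64, 64)                 # hypothetical "image"
shifted = torch.roll(img, shifts=1, dims=-1)   # same content, shifted right by one pixel

print(F.mse_loss(img, shifted))                # far from zero despite near-identical appearance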

Since we want to compare semantic differences, we first need a high-level representation of each image. This can be obtained from the outputs of the intermediate layers of a convolutional neural network, which encode increasingly high-level features.

In other words, given two images, instead of comparing them pixel by pixel, we feed both through the same network, take the feature maps from some intermediate layer, and then measure the difference between those feature maps with an ordinary loss. The paper Perceptual Losses for Real-Time Style Transfer and Super-Resolution uses VGG16 for this; other pretrained networks (ResNet, GoogLeNet, VGG19) also work, but VGG16 generally gives the best results.

The code is shown below; MSE is used to compute the loss between the feature maps.

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision.models import vgg16_bn

class FeatureLoss(nn.Module):
    def __init__(self, loss, blocks, weights, device):
        super().__init__()
        self.feature_loss = loss
        assert all(isinstance(w, (int, float)) for w in weights)
        assert len(weights) == len(blocks)

        self.weights = torch.tensor(weights).to(device)
        #VGG16 contains 5 conv blocks (2 or 3 convolutions each) followed by 3 dense layers
        assert len(blocks) <= 5
        assert all(i in range(5) for i in blocks)
        assert sorted(blocks) == blocks

        vgg = vgg16_bn(pretrained=True).features
        vgg.eval()

        for param in vgg.parameters():
            param.requires_grad = False

        vgg = vgg.to(device)

        bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]
        assert all(isinstance(vgg[bn], nn.BatchNorm2d) for bn in bns)

        self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]
        self.features = vgg[0: bns[blocks[-1]] + 1]

    def forward(self, inputs, targets):

        # normalize inputs to the ImageNet statistics expected by the pre-trained VGG
        mean = torch.tensor([0.485, 0.456, 0.406], device=inputs.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=inputs.device).view(1, 3, 1, 1)
        inputs = (inputs - mean) / std
        targets = (targets - mean) / std

        # extract feature maps
        self.features(inputs)
        input_features = [hook.features.clone() for hook in self.hooks]

        self.features(targets)
        target_features = [hook.features for hook in self.hooks]

        loss = 0.0
        
        # compare their weighted loss
        for lhs, rhs, w in zip(input_features, target_features, self.weights):
            lhs = lhs.view(lhs.size(0), -1)
            rhs = rhs.view(rhs.size(0), -1)
            loss += self.feature_loss(lhs, rhs) * w

        return loss

class FeatureHook:
    def __init__(self, module):
        self.features = None
        self.hook = module.register_forward_hook(self.on)

    def on(self, module, inputs, outputs):
        self.features = outputs

    def close(self):
        self.hook.remove()
        
def perceptual_loss(x, y):
    return F.mse_loss(x, y)
    
def PerceptualLoss(blocks, weights, device):
    return FeatureLoss(perceptual_loss, blocks, weights, device)

Parameters (a usage sketch follows this list):

  • blocks: which VGG blocks' outputs to use as intermediate feature maps, e.g. [0, 1, 2] selects the first three blocks.
  • weights: the weight given to each feature map's loss in the final loss.
  • device: the device to run on; you can pass torch.device("cuda" if torch.cuda.is_available() else "cpu") directly.
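
For illustration, a minimal usage sketch (the block indices, weights, and input shapes below are arbitrary assumptions):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = PerceptualLoss(blocks=[0, 1, 2], weights=[1.0, 1.0, 1.0], device=device)

# hypothetical batches of RGB images in [0, 1], e.g. a network's output vs. the ground truth
inputs = torch.rand(2, 3, 224, 224, device=device)
targets = torch.rand(2, 3, 224, 224, device=device)

loss = criterion(inputs, targets)
print(loss.item())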

Key code walk-through

First, load the pretrained vgg16_bn from torchvision and freeze its weights with eval() and requires_grad = False, so that it is only used to extract feature maps:

vgg = vgg16_bn(pretrained=True).features
vgg.eval()
for param in vgg.parameters():
    param.requires_grad = False
vgg = vgg.to(device)

Next we take the outputs of VGG16's five blocks. Each block ends with a max pool, but since the max pool layer and the ReLU just before it arguably add nothing when comparing feature maps (debatable), the batch norm layer two positions before each max pool is used as that block's output:

bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]
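
As a sanity check, a small sketch (assuming torchvision's vgg16_bn layer ordering of Conv2d → BatchNorm2d → ReLU → ... → MaxPool2d) verifying that these indices do point at BatchNorm2d layers:

import torch.nn as nn
from torchvision.models import vgg16_bn

vgg = vgg16_bn(pretrained=True).features
bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]

print(bns)                                                    # [4, 11, 21, 31, 41]
print(all(isinstance(vgg[i], nn.BatchNorm2d) for i in bns))   # True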

Then, for the blocks we specified (i.e. the layers whose outputs we need), a forward hook is registered on the corresponding bn layer via register_forward_hook to capture its output:

self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]
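
For reference, here is a standalone sketch of how register_forward_hook captures a layer's output (the tiny Linear layer is just a hypothetical stand-in):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
captured = {}

def on(module, inputs, outputs):
    # called automatically after every forward pass through `layer`
    captured["out"] = outputs

handle = layer.register_forward_hook(on)
layer(torch.randn(1, 4))
print(captured["out"].shape)  # torch.Size([1, 2])
handle.remove()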

self.features is essentially a trimmed-down VGG16. We only keep the layers up to the last output we need: if we only need the outputs of the first two blocks, the last three blocks can simply be dropped, reducing computation.

self.features = vgg[0: bns[blocks[-1]] + 1]
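
A minimal sketch of this truncation (assuming blocks = [0, 1], i.e. only the first two blocks are needed):

import torch.nn as nn
from torchvision.models import vgg16_bn

vgg = vgg16_bn(pretrained=True).features
bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]

blocks = [0, 1]                         # hypothetical choice: first two blocks only
features = vgg[0: bns[blocks[-1]] + 1]
print(len(features), len(vgg))          # 12 vs. 44: the last three blocks are gone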

Finally, feed the inputs and targets through the network, extract the feature maps via the hooks, and compare them to obtain the feature loss:

self.features(inputs)
input_features = [hook.features.clone() for hook in self.hooks]

self.features(targets)
target_features = [hook.features for hook in self.hooks]
