This article is largely adapted from: https://github.com/sowmyay/medium/blob/master/CV-LossFunctions.ipynb
First, let's recall the motivation behind feature loss, also known as perceptual loss:
Many loss functions, such as L1 loss, L2 loss, and BCE loss, measure error by comparing images pixel by pixel. However, two images that look almost identical to a human, for example image A being image B shifted by a single pixel, can still produce a huge pixel-wise loss. For such cases a low-level pixel loss is not enough; we need a loss that captures semantic differences.
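As a quick illustration (a minimal sketch, not taken from the referenced notebook; the random tensor merely stands in for a real image), shifting an image by a single pixel already yields a clearly non-zero pixel-wise loss, even though the content is essentially unchanged:
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 224, 224)              # stand-in for a real image in [0, 1]
shifted = torch.roll(img, shifts=1, dims=-1)  # the same image moved right by one pixel

print(F.l1_loss(img, shifted))   # clearly non-zero despite near-identical content
print(F.mse_loss(img, shifted))  # likewise non-zero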
To compare semantic differences, we first need a high-level representation of each image. This can be obtained by feeding the image through a pretrained convolutional neural network and taking the output of one of its intermediate layers, which encodes higher-level features than raw pixels.
In other words, given two images, instead of comparing their pixels directly, we pass both through the same network, take the feature maps from some intermediate layer, and then measure the difference between those feature maps with a conventional loss. The paper Perceptual Losses for Real-Time Style Transfer and Super-Resolution uses VGG16 for this; other pretrained networks (ResNet, GoogLeNet, VGG19) also work, but VGG16 is the usual choice and generally gives good results.
The code is given below; MSE is used here to compute the loss between the feature maps.
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision.models import vgg16_bn

class FeatureLoss(nn.Module):
    def __init__(self, loss, blocks, weights, device):
        super().__init__()
        self.feature_loss = loss
        assert all(isinstance(w, (int, float)) for w in weights)
        assert len(weights) == len(blocks)
        self.weights = torch.tensor(weights).to(device)
        # VGG16 has 5 convolutional blocks (2-3 convolutions each) followed by
        # 3 dense layers; only the convolutional part ("features") is used here
        assert len(blocks) <= 5
        assert all(i in range(5) for i in blocks)
        assert sorted(blocks) == blocks

        vgg = vgg16_bn(pretrained=True).features
        vgg.eval()
        for param in vgg.parameters():
            param.requires_grad = False
        vgg = vgg.to(device)

        # each block ends with a MaxPool2d; two layers before it sits the last
        # BatchNorm2d of that block, which we treat as the block's output
        bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]
        assert all(isinstance(vgg[bn], nn.BatchNorm2d) for bn in bns)

        self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]
        # keep only the layers up to the last block we actually need
        self.features = vgg[0: bns[blocks[-1]] + 1]

    def forward(self, inputs, targets):
        # normalize pixels to ImageNet statistics for the pre-trained VGG
        mean = torch.tensor([0.485, 0.456, 0.406], device=inputs.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=inputs.device).view(1, 3, 1, 1)
        inputs = (inputs - mean) / std
        targets = (targets - mean) / std

        # extract feature maps; the hooks capture the selected block outputs
        self.features(inputs)
        input_features = [hook.features.clone() for hook in self.hooks]
        self.features(targets)
        target_features = [hook.features for hook in self.hooks]

        loss = 0.0
        # accumulate the weighted loss over the selected blocks
        for lhs, rhs, w in zip(input_features, target_features, self.weights):
            lhs = lhs.view(lhs.size(0), -1)
            rhs = rhs.view(rhs.size(0), -1)
            loss += self.feature_loss(lhs, rhs) * w
        return loss

class FeatureHook:
    def __init__(self, module):
        self.features = None
        self.hook = module.register_forward_hook(self.on)

    def on(self, module, inputs, outputs):
        self.features = outputs

    def close(self):
        self.hook.remove()

def perceptual_loss(x, y):
    return F.mse_loss(x, y)

def PerceptualLoss(blocks, weights, device):
    return FeatureLoss(perceptual_loss, blocks, weights, device)
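A minimal usage sketch (the batch size, image size, block choice, and weights below are arbitrary illustrative values): construct the loss once, then call it like any other criterion on image batches scaled to [0, 1]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# compare the outputs of the first three VGG blocks, equally weighted
criterion = PerceptualLoss(blocks=[0, 1, 2], weights=[1.0, 1.0, 1.0], device=device)

prediction = torch.rand(4, 3, 224, 224, device=device)  # e.g. a generator's output
target = torch.rand(4, 3, 224, 224, device=device)      # the reference images

loss = criterion(prediction, target)
print(loss.item())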
The parameters are: loss, the function used to compare the flattened feature maps; blocks, which of VGG16's five blocks (indices 0-4, in ascending order) to tap; weights, the weight of each chosen block in the total loss; and device, where the computation runs.
First we load vgg16_bn from torchvision, freeze its weights with eval() and requires_grad=False, and move it to the target device, so that it serves purely as a fixed feature extractor:
vgg = vgg16_bn(pretrained=True).features
vgg.eval()
for param in vgg.parameters():
    param.requires_grad = False
vgg = vgg.to(device)
Next we need the outputs of VGG16's five blocks. Each block ends with a max-pool layer, but since the max pool and the ReLU directly before it arguably contribute little when comparing feature maps (debatable), we instead take the batch-norm layer two positions before each max pool as the block's output:
bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]
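In torchvision's vgg16_bn, the features module is a sequence of conv / batch-norm / ReLU triplets with a max pool closing each block, so this list should come out as [4, 11, 21, 31, 41], each index pointing at a BatchNorm2d layer (the assertion in the full code above double-checks this). A standalone way to verify the layout (a sketch; the indices could shift if torchvision ever changes the architecture definition):
import torch.nn as nn
from torchvision.models import vgg16_bn

vgg = vgg16_bn().features                    # weights are irrelevant for inspecting the layout
bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]
print(bns)                                   # expected: [4, 11, 21, 31, 41]
print([type(vgg[i]).__name__ for i in bns])  # expected: all 'BatchNorm2d'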
Then, for the blocks we specified (i.e., whose outputs we want), we attach a FeatureHook to the corresponding batch-norm layer; internally it uses register_forward_hook to capture that layer's output:
self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]
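register_forward_hook is standard PyTorch: the registered callback runs after every forward pass of the module and receives its output. A tiny self-contained sketch of the idea behind FeatureHook (the convolution layer here is arbitrary):
import torch
import torch.nn as nn

layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)
captured = {}

def save_output(module, inputs, output):
    captured["features"] = output      # stash the layer's output, like FeatureHook.on

handle = layer.register_forward_hook(save_output)
_ = layer(torch.randn(1, 3, 32, 32))
print(captured["features"].shape)      # torch.Size([1, 8, 32, 32])
handle.remove()                        # analogous to FeatureHook.close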
self.features is simply a truncated VGG16: we keep only the layers up to and including the last block we need. If we only need the first two blocks, the remaining three can be dropped, which saves computation.
self.features = vgg[0: bns[blocks[-1]] + 1]
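For instance (an illustrative sketch reusing the bns computation above): if only the first two blocks are requested, the stored sub-network stops right after the second block's batch-norm layer:
import torch.nn as nn
from torchvision.models import vgg16_bn

vgg = vgg16_bn().features              # weights irrelevant for inspecting the layout
bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]

blocks = [0, 1]                        # only the first two blocks are needed
features = vgg[0: bns[blocks[-1]] + 1]
print(len(features), "of", len(vgg))   # 12 of 44 layers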
Finally, we feed the inputs and the targets through the network, pull the feature maps out of the hooks, and compare them to obtain the feature loss:
self.features(inputs)
input_features = [hook.features.clone() for hook in self.hooks]
self.features(targets)
target_features = [hook.features for hook in self.hooks]
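To see concretely what is being compared, one can run a dummy batch through the loss and inspect the feature maps captured by the hooks (a sketch assuming the classes above are defined; the shapes shown are for 224x224 inputs and the vgg16_bn layout discussed earlier):
import torch

device = torch.device("cpu")
criterion = PerceptualLoss(blocks=[0, 1, 2], weights=[1.0, 1.0, 1.0], device=device)

x = torch.rand(2, 3, 224, 224)
y = torch.rand(2, 3, 224, 224)
print(criterion(x, y))

for hook in criterion.hooks:
    print(hook.features.shape)         # [2, 64, 224, 224], [2, 128, 112, 112], [2, 256, 56, 56]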