我们想借助一个预训练好的网络（并非 torchvision 等库中自带的模型）来计算 feature loss。预训练网络地址：表情识别网络（在 FER2013 上训练的 VGG19）。
作者已经提供了预训练好的模型参数和模型代码。首先，我们把模型加载进来：
import torch

from Expression.VGG import VGG

# Path to PrivateTest_model.t7, the checkpoint downloaded from the
# pretrained-model page (same default path that define_E uses).
check_pth = 'Expression/FER2013_VGG19/PrivateTest_model.t7'

model = VGG('VGG19')
# Load the tensors onto the CPU first so the checkpoint can be opened no
# matter which device it was saved from; the model is moved to the GPU below.
checkpoint = torch.load(check_pth, map_location='cpu')
# The checkpoint is a dict whose 'net' entry holds the model state_dict.
model.load_state_dict(checkpoint['net'])
model.cuda()
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
我们可以把该模型的结构打印出来看一下：
# Re-wrap the model's top-level submodules in a Sequential purely for display.
print(torch.nn.Sequential(*model.children()))
输出结果为：
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace)
(23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(25): ReLU(inplace)
(26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace)
(33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(35): ReLU(inplace)
(36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(38): ReLU(inplace)
(39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace)
(43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(45): ReLU(inplace)
(46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(48): ReLU(inplace)
(49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(51): ReLU(inplace)
(52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(53): AvgPool2d(kernel_size=1, stride=1, padding=0)
)
(1): Linear(in_features=512, out_features=7, bias=True)
)
class FeatureExtractor(nn.Module):
    """Truncated, frozen copy of a network's ``features`` stack.

    Keeps the first ``feature_layer + 1`` modules of ``model.features`` and
    disables gradients for their parameters, so the result can be used as a
    fixed feature network when computing a feature loss.

    Args:
        model: Network exposing a ``features`` attribute whose children are
            the layers to keep. The layers are shared with ``model`` (not
            copied), so freezing them here also freezes them in ``model``.
        feature_layer: Inclusive index of the last child of
            ``model.features`` to keep. In the VGG19 layer list printed above,
            index 50 is the BatchNorm2d after the last conv, i.e. before its
            ReLU.
        device: Currently unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``feature_layer`` is not a valid index into the
            children of ``model.features``.
    """

    def __init__(self, model, feature_layer=50, device=torch.device('cpu')):
        super(FeatureExtractor, self).__init__()
        layers = list(model.features.children())
        # Slicing never raises for negative or too-large indices; it silently
        # keeps a different set of layers (feature_layer=-1 would even give an
        # empty, identity network), so validate explicitly.
        if not 0 <= feature_layer < len(layers):
            raise ValueError(
                'feature_layer must be in [0, {}], got {}'.format(
                    len(layers) - 1, feature_layer))
        self.features = nn.Sequential(*layers[:feature_layer + 1])
        # The extractor is only a fixed loss network: never train it.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the activations of the kept layers for input ``x``."""
        # NOTE(review): the original author assumed inputs in [0, 1]; confirm
        # this matches the preprocessing the checkpoint was trained with.
        return self.features(x)
def define_E(opt, check_pth='Expression/FER2013_VGG19/PrivateTest_model.t7', device=None):
    """Build the frozen expression-feature network used by the expression loss.

    Loads the pretrained VGG19 weights from ``check_pth`` and wraps the
    network in ``FeatureExtractor`` truncated after layer 50, in eval mode.

    Args:
        opt: Training options; currently unused, kept for interface
            compatibility.
        check_pth: Path to the checkpoint; it must be a dict whose ``'net'``
            entry is the VGG state_dict.
        device: Device to place the network on. ``None`` (default) keeps the
            original behaviour of moving it to the current CUDA device.

    Returns:
        The ``FeatureExtractor`` wrapping the loaded network, in eval mode.
    """
    from Expression.VGG import VGG

    netE = VGG('VGG19')
    # Load the tensors onto the CPU first so the checkpoint opens regardless
    # of the device it was saved from; the network is moved below.
    checkpoint = torch.load(check_pth, map_location='cpu')
    netE.load_state_dict(checkpoint['net'])
    if device is None:
        netE.cuda()
    else:
        netE.to(device)
    netE.eval()
    # Use the FeatureExtractor defined in this file instead of importing one
    # from Expression.VGG, which may not define it (or may hold a diverging
    # copy).
    extract_E = FeatureExtractor(model=netE, feature_layer=50)
    return extract_E.eval()
def loss(train_opt, device):
    """Compute the weighted expression (feature) loss.

    Selects an L1 or L2 criterion from ``train_opt.expression_criterion``,
    runs the real and fake images through the frozen expression network
    built by ``define_E`` and compares the resulting features.

    Args:
        train_opt: Options object exposing ``expression_weight`` and
            ``expression_criterion`` (``'l1'`` or ``'l2'``).
        device: Device for the criterion and the feature network.

    Returns:
        ``expression_weight * criterion(E(fake_img), E(real_img))`` when
        ``expression_weight > 0``, otherwise ``0``.

    Raises:
        NotImplementedError: If ``expression_criterion`` is neither ``'l1'``
            nor ``'l2'``.
    """
    l_g_exp = 0
    if train_opt.expression_weight > 0:
        l_exp_type = train_opt.expression_criterion
        if l_exp_type == 'l1':
            cri_exp = nn.L1Loss().to(device)
        elif l_exp_type == 'l2':
            cri_exp = nn.MSELoss().to(device)
        else:
            # '{}' rather than '{:s}': a non-str value (e.g. None) must still
            # produce this NotImplementedError, not a TypeError from format().
            raise NotImplementedError('Loss type [{}] not recognized.'.format(l_exp_type))
        l_exp_w = train_opt.expression_weight
    else:
        logger.info('Remove expression loss.')
        cri_exp = None
    if cri_exp is not None:
        # NOTE(review): this reloads the checkpoint on every call; in a
        # training loop netE should be built once and reused.
        netE = define_E(train_opt).to(device)
        netE = DataParallel(netE)
        # NOTE(review): real_img / fake_img are not parameters of this
        # function; they are resolved from names defined elsewhere — confirm.
        # The real-image features are the target, so no gradient flows there.
        real_exp = netE(real_img).detach()
        fake_exp = netE(fake_img)
        l_g_exp += l_exp_w * cri_exp(fake_exp, real_exp)
    return l_g_exp