Code link: GitHub - taylover-pei/SSDG-CVPR2020: Single-Side Domain Generalization for Face Anti-Spoofing, CVPR2020
The overall structure of the SSDG model is shown in the figure below:
Reading data from the three different source domains
import os
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils.dataset import YunpeiDataset
from utils.utils import sample_frames
def get_dataset(src1_data, src1_train_num_frames, src2_data, src2_train_num_frames, src3_data, src3_train_num_frames,
                tgt_data, tgt_test_num_frames, batch_size):
    print('Load Source Data')
    print('Source Data: ', src1_data)
    # flag=0 samples fake (attack) frames, flag=1 real frames, flag=2 test frames
    src1_train_data_fake = sample_frames(flag=0, num_frames=src1_train_num_frames, dataset_name=src1_data)
    src1_train_data_real = sample_frames(flag=1, num_frames=src1_train_num_frames, dataset_name=src1_data)
    print('Source Data: ', src2_data)
    src2_train_data_fake = sample_frames(flag=0, num_frames=src2_train_num_frames, dataset_name=src2_data)
    src2_train_data_real = sample_frames(flag=1, num_frames=src2_train_num_frames, dataset_name=src2_data)
    print('Source Data: ', src3_data)
    src3_train_data_fake = sample_frames(flag=0, num_frames=src3_train_num_frames, dataset_name=src3_data)
    src3_train_data_real = sample_frames(flag=1, num_frames=src3_train_num_frames, dataset_name=src3_data)
    print('Load Target Data')
    print('Target Data: ', tgt_data)
    tgt_test_data = sample_frames(flag=2, num_frames=tgt_test_num_frames, dataset_name=tgt_data)
    # a separate dataloader per (domain, class): six training loaders in total
    src1_train_dataloader_fake = DataLoader(YunpeiDataset(src1_train_data_fake, train=True),
                                            batch_size=batch_size, shuffle=True)
    src1_train_dataloader_real = DataLoader(YunpeiDataset(src1_train_data_real, train=True),
                                            batch_size=batch_size, shuffle=True)
    src2_train_dataloader_fake = DataLoader(YunpeiDataset(src2_train_data_fake, train=True),
                                            batch_size=batch_size, shuffle=True)
    src2_train_dataloader_real = DataLoader(YunpeiDataset(src2_train_data_real, train=True),
                                            batch_size=batch_size, shuffle=True)
    src3_train_dataloader_fake = DataLoader(YunpeiDataset(src3_train_data_fake, train=True),
                                            batch_size=batch_size, shuffle=True)
    src3_train_dataloader_real = DataLoader(YunpeiDataset(src3_train_data_real, train=True),
                                            batch_size=batch_size, shuffle=True)
    tgt_dataloader = DataLoader(YunpeiDataset(tgt_test_data, train=False), batch_size=batch_size, shuffle=False)
    return src1_train_dataloader_fake, src1_train_dataloader_real, \
           src2_train_dataloader_fake, src2_train_dataloader_real, \
           src3_train_dataloader_fake, src3_train_dataloader_real, \
           tgt_dataloader
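The six training dataloaders are consumed through plain Python iterators in the training loop. A minimal sketch of the wiring, with iterator names chosen to match the snippet below (an exhausted iterator then needs to be re-created the same way so all six streams stay live):
src1_train_dataloader_fake, src1_train_dataloader_real, \
src2_train_dataloader_fake, src2_train_dataloader_real, \
src3_train_dataloader_fake, src3_train_dataloader_real, \
tgt_dataloader = get_dataset(config.src1_data, config.src1_train_num_frames,
                             config.src2_data, config.src2_train_num_frames,
                             config.src3_data, config.src3_train_num_frames,
                             config.tgt_data, config.tgt_test_num_frames, config.batch_size)
# one iterator per (domain, class) stream
src1_train_iter_real = iter(src1_train_dataloader_real)
src1_train_iter_fake = iter(src1_train_dataloader_fake)
src2_train_iter_real = iter(src2_train_dataloader_real)
src2_train_iter_fake = iter(src2_train_dataloader_fake)
src3_train_iter_real = iter(src3_train_dataloader_real)
src3_train_iter_fake = iter(src3_train_dataloader_fake)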
######### data prepare #########
# inside the training loop: draw one batch of real faces from each source domain
# (next(it) replaces the Python-2-style it.next() used in the original repo)
src1_img_real, src1_label_real = next(src1_train_iter_real)
src1_img_real = src1_img_real.cuda()
src1_label_real = src1_label_real.cuda()
input1_real_shape = src1_img_real.shape[0]
src2_img_real, src2_label_real = next(src2_train_iter_real)
src2_img_real = src2_img_real.cuda()
src2_label_real = src2_label_real.cuda()
input2_real_shape = src2_img_real.shape[0]
src3_img_real, src3_label_real = next(src3_train_iter_real)
src3_img_real = src3_img_real.cuda()
src3_label_real = src3_label_real.cuda()
input3_real_shape = src3_img_real.shape[0]
# one batch of fake faces from each source domain
src1_img_fake, src1_label_fake = next(src1_train_iter_fake)
src1_img_fake = src1_img_fake.cuda()
src1_label_fake = src1_label_fake.cuda()
input1_fake_shape = src1_img_fake.shape[0]
src2_img_fake, src2_label_fake = next(src2_train_iter_fake)
src2_img_fake = src2_img_fake.cuda()
src2_label_fake = src2_label_fake.cuda()
input2_fake_shape = src2_img_fake.shape[0]
src3_img_fake, src3_label_fake = next(src3_train_iter_fake)
src3_img_fake = src3_img_fake.cuda()
src3_label_fake = src3_label_fake.cuda()
input3_fake_shape = src3_img_fake.shape[0]
# fixed batch layout: real1, fake1, real2, fake2, real3, fake3
input_data = torch.cat([src1_img_real, src1_img_fake, src2_img_real, src2_img_fake, src3_img_real, src3_img_fake], dim=0)
source_label = torch.cat([src1_label_real, src1_label_fake,
                          src2_label_real, src2_label_fake,
                          src3_label_real, src3_label_fake], dim=0)
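With batch_size = 10 per dataloader, input_data stacks six sub-batches into a single 60-sample batch whose layout is fixed: real1, fake1, real2, fake2, real3, fake3. The narrow-based slicing in the adversarial section further down depends on exactly this ordering. A quick sanity check (the 3×256×256 crop size is an assumption about the preprocessing, not stated in this excerpt):
# hypothetical smoke test of the concatenated batch
assert input_data.shape[0] == (input1_real_shape + input1_fake_shape +
                               input2_real_shape + input2_fake_shape +
                               input3_real_shape + input3_fake_shape)
print(input_data.shape)    # e.g. torch.Size([60, 3, 256, 256]) with batch_size=10
print(source_label.shape)  # torch.Size([60])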
Calling DG_model
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os
def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    # change your path
    model_path = r'D:\Projects\face_anti_spoofing\SSDG-CVPR2020-master\pretrained_model\resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(torch.load(model_path))
        print("loading model: ", model_path)
    # print(model)
    return model
class Feature_Generator_ResNet18(nn.Module):
    def __init__(self):
        super(Feature_Generator_ResNet18, self).__init__()
        # reuse the pretrained ResNet-18 stem and its first three stages as the feature generator
        model_resnet = resnet18(pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3

    def forward(self, input):
        feature = self.conv1(input)
        feature = self.bn1(feature)
        feature = self.relu(feature)
        feature = self.maxpool(feature)
        feature = self.layer1(feature)
        feature = self.layer2(feature)
        feature = self.layer3(feature)
        return feature
class Feature_Embedder_ResNet18(nn.Module):
    def __init__(self):
        super(Feature_Embedder_ResNet18, self).__init__()
        # the last ResNet stage plus a 512-d bottleneck form the embedder
        model_resnet = resnet18(pretrained=False)
        self.layer4 = model_resnet.layer4
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.bottleneck_layer_fc = nn.Linear(512, 512)
        self.bottleneck_layer_fc.weight.data.normal_(0, 0.005)
        self.bottleneck_layer_fc.bias.data.fill_(0.1)
        self.bottleneck_layer = nn.Sequential(
            self.bottleneck_layer_fc,
            nn.ReLU(),
            nn.Dropout(0.5)
        )

    def forward(self, input, norm_flag):
        feature = self.layer4(input)
        feature = self.avgpool(feature)
        feature = feature.view(feature.size(0), -1)
        feature = self.bottleneck_layer(feature)
        if norm_flag:
            feature_norm = torch.norm(feature, p=2, dim=1, keepdim=True).clamp(min=1e-12) ** 0.5 * (2) ** 0.5
            feature = torch.div(feature, feature_norm)
        return feature
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.classifier_layer = nn.Linear(512, 2)
        self.classifier_layer.weight.data.normal_(0, 0.01)
        self.classifier_layer.bias.data.fill_(0.0)

    def forward(self, input, norm_flag):
        # the original if/else branches were identical except for the weight
        # normalization, so they are merged here without changing behavior
        if norm_flag:
            self.classifier_layer.weight.data = l2_norm(self.classifier_layer.weight, axis=0)
        classifier_out = self.classifier_layer(input)
        return classifier_out
class DG_model(nn.Module):
    def __init__(self, model):
        super(DG_model, self).__init__()
        if model == 'resnet18':
            self.backbone = Feature_Generator_ResNet18()
            self.embedder = Feature_Embedder_ResNet18()
        elif model == 'maddg':
            # the MADDG variants are defined elsewhere in the repo
            self.backbone = Feature_Generator_MADDG()
            self.embedder = Feature_Embedder_MADDG()
        else:
            print('Wrong Name!')
        self.classifier = Classifier()

    def forward(self, input, norm_flag):
        feature = self.backbone(input)
        feature = self.embedder(feature, norm_flag)
        classifier_out = self.classifier(feature, norm_flag)
        # print(feature.shape)  # debug
        return classifier_out, feature
Instantiating the network (the original snippet assigned to model but then called net; the name is unified here, and the network is moved to the GPU to match the cuda inputs):
net = DG_model('resnet18').cuda()
classifier_label_out, feature = net(input_data, config.norm_flag)
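A quick way to sanity-check the two outputs is to push a dummy batch through the network (the 256×256 input size is an assumption; any spatial size the ResNet stem accepts will do):
# hypothetical smoke test for DG_model's two outputs
dummy = torch.randn(4, 3, 256, 256).cuda()
logits, feat = net(dummy, norm_flag=True)
print(logits.shape)  # torch.Size([4, 2])   - live/spoof logits
print(feat.shape)    # torch.Size([4, 512]) - embedder feature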
A gradient reversal layer (GRL) is inserted in the discriminator's backward pass; its purpose is to suppress the influence of noisy signals in the early stage of training. The GRL coefficient is very small at the start of training and gradually increases as the number of iterations grows.
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os
class GRL(nn.Module):
    # Gradient reversal layer. The original repo implements this as a legacy
    # torch.autograd.Function with instance state, which no longer runs on
    # current PyTorch; this is the equivalent modern @staticmethod form, with
    # np.float (removed from recent NumPy) replaced by the builtin float.
    class _ReverseGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, coeff):
            ctx.coeff = coeff
            return input.view_as(input)
        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.coeff * grad_output, None  # reversed, scaled gradient

    def __init__(self):
        super(GRL, self).__init__()
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = 4000  # keep equal to max_iter in config.py

    def forward(self, input):
        self.iter_num += 1
        # the coefficient ramps smoothly from low (0) toward high (1) over max_iter steps
        coeff = float(2.0 * (self.high - self.low)
                      / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter))
                      - (self.high - self.low) + self.low)
        return GRL._ReverseGrad.apply(input, coeff)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(512, 512)
        self.fc1.weight.data.normal_(0, 0.01)
        self.fc1.bias.data.fill_(0.0)
        self.fc2 = nn.Linear(512, 3)  # 3 outputs: one per source domain
        self.fc2.weight.data.normal_(0, 0.3)
        self.fc2.bias.data.fill_(0.0)
        self.ad_net = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            nn.Dropout(0.5),
            self.fc2
        )
        self.grl_layer = GRL()

    def forward(self, feature):
        adversarial_out = self.ad_net(self.grl_layer(feature))
        return adversarial_out
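To see the schedule in numbers, the coefficient can be evaluated standalone (same constants as in GRL: alpha=10, low=0, high=1, max_iter=4000):
import numpy as np
# GRL coefficient at a few iteration counts
for it in [0, 400, 1000, 2000, 4000]:
    coeff = 2.0 / (1.0 + np.exp(-10.0 * it / 4000)) - 1.0
    print(it, round(coeff, 3))
# prints: 0 0.0, 400 0.462, 1000 0.848, 2000 0.987, 4000 1.0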
Instantiating the discriminator:
ad_net_real = Discriminator().to(device)
Here feature is the second output of DG_model (DG_model produces two outputs: the classification result, and the feature extracted by the backbone and embedder networks).
######### single side adversarial learning #########
input1_shape = input1_real_shape + input1_fake_shape
input2_shape = input2_real_shape + input2_fake_shape
# torch.narrow(input, dim, start, length) returns a narrowed view of input:
# along dimension dim, the slice [start, start + length), much like array slicing
# pick out the real-face features of each domain from feature and concatenate them
feature_real_1 = feature.narrow(0, 0, input1_real_shape)
feature_real_2 = feature.narrow(0, input1_shape, input2_real_shape)
feature_real_3 = feature.narrow(0, input1_shape+input2_shape, input3_real_shape)
feature_real = torch.cat([feature_real_1, feature_real_2, feature_real_3], dim=0)
discriminator_out_real = ad_net_real(feature_real)
######### unbalanced triplet loss #########
real_domain_label_1 = torch.LongTensor(input1_real_shape, 1).fill_(0).cuda()
real_domain_label_2 = torch.LongTensor(input2_real_shape, 1).fill_(0).cuda()
real_domain_label_3 = torch.LongTensor(input3_real_shape, 1).fill_(0).cuda()
fake_domain_label_1 = torch.LongTensor(input1_fake_shape, 1).fill_(1).cuda()
fake_domain_label_2 = torch.LongTensor(input2_fake_shape, 1).fill_(2).cuda()
fake_domain_label_3 = torch.LongTensor(input3_fake_shape, 1).fill_(3).cuda()
source_domain_label = torch.cat([real_domain_label_1, fake_domain_label_1,
real_domain_label_2, fake_domain_label_2,
real_domain_label_3, fake_domain_label_3], dim=0).view(-1)
triplet = criterion["triplet"](feature, source_domain_label)
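criterion["triplet"] is the repo's hard-triplet loss. The key point is the asymmetric labeling above: all real faces share label 0 while each domain's fakes keep a distinct label (1/2/3), so the loss pulls real faces together across domains and pushes fakes away from them, but never forces fakes from different domains to merge, which is the single-side, unbalanced part of SSDG. A minimal batch-hard sketch of the idea (the margin value is an assumption, not the repo's setting):
import torch
import torch.nn.functional as F

def batch_hard_triplet(features, labels, margin=0.1):  # margin is an assumed value
    dist = torch.cdist(features, features, p=2)        # pairwise Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    hardest_pos = (dist * (same & ~eye).float()).max(dim=1).values   # farthest same-label sample
    hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values  # closest other-label sample
    return F.relu(hardest_pos - hardest_neg + margin).mean()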
######### cross-entropy loss #########
real_shape_list = []
real_shape_list.append(input1_real_shape)
real_shape_list.append(input2_real_shape)
real_shape_list.append(input3_real_shape)
real_adloss = Real_AdLoss(discriminator_out_real, criterion["softmax"], real_shape_list)
cls_loss = criterion["softmax"](classifier_label_out.narrow(0, 0, input_data.size(0)), source_label)
total_loss = cls_loss + config.lambda_triplet * triplet + config.lambda_adreal * real_adloss
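Real_AdLoss lives in the repo's utils; judging from the call site, it builds the domain labels (0/1/2 for the three source domains) for the real-face features and scores the discriminator output with the softmax cross-entropy passed in. A sketch reconstructed from that call site, not the repo file verbatim:
def Real_AdLoss(discriminator_out, criterion, real_shape_list):
    # domain label d for the real faces of source domain d (0, 1, 2)
    labels = torch.cat([torch.full((n,), d, dtype=torch.long)
                        for d, n in enumerate(real_shape_list)]).cuda()
    # the GRL inside the discriminator flips this gradient, so minimizing the
    # cross-entropy drives the generator toward domain-indistinguishable real features
    return criterion(discriminator_out, labels)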
The configuration file for the I_C_M_to_O protocol (train on the CASIA, Replay, and MSU source domains; test on the OULU target domain) is as follows:
class DefaultConfigs(object):
    seed = 666
    # SGD
    weight_decay = 5e-4
    momentum = 0.9
    # learning rate
    init_lr = 0.01
    lr_epoch_1 = 0
    lr_epoch_2 = 150
    # model
    pretrained = True
    model = 'resnet18'  # resnet18 or maddg
    # training parameters
    gpus = "3"
    batch_size = 10
    norm_flag = True
    max_iter = 4000
    lambda_triplet = 2
    lambda_adreal = 0.1
    # test model name
    tgt_best_model_name = 'model_best_0.08_29.pth.tar'
    # source data information
    src1_data = 'casia'
    src1_train_num_frames = 1
    src2_data = 'replay'
    src2_train_num_frames = 1
    src3_data = 'msu'
    src3_train_num_frames = 1
    # target data information
    tgt_data = 'oulu'
    tgt_test_num_frames = 2
    # paths information
    checkpoint_path = './' + tgt_data + '_checkpoint/' + model + '/DGFANet/'
    best_model_path = './' + tgt_data + '_checkpoint/' + model + '/best_model/'
    logs = './logs/'

config = DefaultConfigs()
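With these values the total loss above becomes cls_loss + 2·triplet + 0.1·real_adloss, optimized by SGD. A sketch of the optimizer setup the config implies (the repo may split parameter groups differently):
optimizer = torch.optim.SGD(
    list(net.parameters()) + list(ad_net_real.parameters()),
    lr=config.init_lr,                 # 0.01
    momentum=config.momentum,          # 0.9
    weight_decay=config.weight_decay,  # 5e-4
)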