还记得 2019 年 07 月 07 日,当时在参加一个比赛需要用到云服务器,由于第一次接触所以在网上找了很多云服务器的注册和配置教程,但是发现很多分享都写的不够全面,于是萌生了把整个配置过程撰写成一篇博客分享出来的想法。自此我的第 1 篇技术博客:《滴滴云服务器的注册与快速配置教程》诞生了,而也正是从那开始,我开始了我的分享之旅。
在创作的过程我认为最大的收获便是培养了自己对知识的梳理和归纳能力,同时将博客作为的平时的随记,也方便了后期的查看。
经过这几年的分享,我截止目前已经收获了6800多名粉丝的关注,总访问量也是突破了35w,更是认识了许多志同道合的朋友。
目前创作已是学习中的一部分了,很多时候我会将想要分享的内容随记在初稿里,等到有空时会整理再一起发,所以现在更新频率更多是不定时的。
博客《【对比学习】CUT模型论文解读与NCE loss代码解析》中的代码是我感觉写的最好的代码,因为花了不少时间去理解论文和代码,才写出该总结博客。
# 一些基础参数赋值
batch_size=2
image_size=512
netF='mlp_sample' # 对应特征提取的Hl模块
netF_nc=256 # mlp层输出的维度大小
nce_T=0.07 # NCE loss的温度系数
num_patches=512 # 计算NCE loss时每一层采样点的数量
nce_layers='0,4,8,12,16' #计算NCE loss的层序号
nce_includes_all_negatives_from_minibatch=False # 该参数为True代表在计算负样本时,负样本字典应包含batch里的其他图片,在执行当单图片转换是才会赋值True。对于CUT和FastCUT任务默认为False
# 生成loss的定义
def compute_G_loss(self):
"""Calculate GAN and NCE loss for the generator"""
fake = self.fake_B
# First, G(A) should fake the discriminator
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD(fake)
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
if self.opt.lambda_NCE > 0.0:
self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
else:
self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
else:
loss_NCE_both = self.loss_NCE
self.loss_G = self.loss_G_GAN + loss_NCE_both
return self.loss_G
# 计算NCE loss
def calculate_NCE_loss(self, src, tgt):
n_layers = len(self.nce_layers) # n_layers=5
# 提取编码器中对应的5层特征,输出的feat_q的形式为:list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]。
# list里面每个元素的维度为[batches,channels,heights,weights],输入图像默认大小是512*512,而第一个元素为518是因为在数据处理的时候做了padding
feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
# 只有在FastCUT模式才会做此强制翻转作为额外的正则化
if self.opt.flip_equivariance and self.flipped_for_equivariance:
feat_q = [torch.flip(fq, [3]) for fq in feat_q]
# 同样feat_k的形式为list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]
feat_k = self.netG(src, self.nce_layers, encode_only=True)
# 通过MLP层提取特征和选取采样点,首先在k中随机采样num_patches=512个样本点,并返回采样点对应的ids
feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
# q也是经过MLP层提取特征,并选取和k对应ids的采样点
feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
# 计算NCE loss
total_nce_loss = 0.0
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
loss = crit(f_q, f_k) * self.opt.lambda_NCE
total_nce_loss += loss.mean()
return total_nce_loss / n_layers
# 采样前经过的MLP层提取特征,此处netF选用PatchSampleF,以下是PatchSampleF的定义
class PatchSampleF(nn.Module):
def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
# potential issues: currently, we use the same patch_ids for multiple images in the batch
super(PatchSampleF, self).__init__()
self.l2norm = Normalize(2)
self.use_mlp = use_mlp
self.nc = nc # hard-coded
self.mlp_init = False
self.init_type = init_type
self.init_gain = init_gain
self.gpu_ids = gpu_ids
def create_mlp(self, feats): # 创建MLP层结构
for mlp_id, feat in enumerate(feats):
input_nc = feat.shape[1]
mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
if len(self.gpu_ids) > 0:
mlp.cuda()
setattr(self, 'mlp_%d' % mlp_id, mlp)
init_net(self, self.init_type, self.init_gain, self.gpu_ids)
self.mlp_init = True
def forward(self, feats, num_patches=512, patch_ids=None):
return_ids = []
return_feats = []
if self.use_mlp and not self.mlp_init:
self.create_mlp(feats)
for feat_id, feat in enumerate(feats): # 此处feats的形式为list[[2,1,518,518], [2,64,512,512], [2,128,256,256], [2,128,128,128], [2,128,128,128]]
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3] # B=2,H和W为不同层特征图的大小
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2) # [B,C,H,W]——>[B,H,W,C]——>[B,HW,C]
if num_patches > 0:
if patch_ids is not None: # 对于feat_q_pool,因为此时传入了feat_k_pool采样到采样点对应ids
patch_id = patch_ids[feat_id]
else: # 一开始feat_k_pool是没有采样点id传入的,所以需要先随机选取采样点;后面feat_q_pool根据feat_k_pool得到的采样点进行对应查询采样
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device) # 打乱feat_reshape中HW维度的顺序
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # 选择打乱后feat_reshape的前num_patches个点,作为采样点的ids
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # 获取对应采样点的特征,此刻x_sample的维度为[B,num_patches,C]——>[BXnum_patches, C]=[2x512, 256]
# 【注意:此处采样点的数量是num_patches=512,这512个采样点的id是不连续的,也就是说随机在特征图里采样512个点,对这随机采样的512个点做对比学习,这和论文中画出patch作为采样块有些不同】
else:
x_sample = feat_reshape
patch_id = []
if self.use_mlp:
mlp = getattr(self, 'mlp_%d' % feat_id)
x_sample = mlp(x_sample)
return_ids.append(patch_id)
x_sample = self.l2norm(x_sample)
if num_patches == 0:
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
return_feats.append(x_sample)
# 此时返回的return_feats的形式为list[[1024, 256],[1024, 256],[1024, 256],[1024, 256],[1024, 256]]
# return_ids返回每一层对应的采样点id号,[[512],[512],[512],[512],[512]]
return return_feats, return_ids
# NCE loss的定义
class PatchNCELoss(nn.Module):
def __init__(self, opt):
super().__init__()
self.opt = opt
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
def forward(self, feat_q, feat_k):
batchSize = feat_q.shape[0] # batchSize=1024,batch上所有的采样点
dim = feat_q.shape[1] # dim=256,每个采样点的特征维度大小
feat_k = feat_k.detach()
# pos logit
# 变换后feat_q的维度变为[1024,1,256], feat_k的维度变为[1024,256,1].进行矩阵乘法后得到l_pos的维度为[1024,1,1]
# 此操作可以理解为feat_q与feat_k一 一对应的位置是相同的类别也就是正样本,因此feat_q与feat_k对应位置的矩阵乘法相当于求q与k+之间的相关性,也就是正样本之间相关性系数。
# 而batchSize=2x512是因为这是对应位置的矩阵相乘,因此可以将不同patch的采样点合并计算
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
l_pos = l_pos.view(batchSize, 1) # [1024,1,1]——>[1024,1]
# neg logit
# Should the negatives from the other samples of a minibatch be utilized?
# In CUT and FastCUT, we found that it's best to only include negatives from the same image. Therefore, we set ‘--nce_includes_all_negatives_from_minibatch’ as False
# However, for single-image translation, the minibatch consists of crops from the "same" high-resolution image.
# Therefore, we will include the negatives from the entire minibatch.
if self.opt.nce_includes_all_negatives_from_minibatch:
# reshape features as if they are all negatives of minibatch of size 1.
batch_dim_for_bmm = 1
else:
batch_dim_for_bmm = self.opt.batch_size # batch_dim_for_bmm=2
# reshape features to batch size
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) # [1024, 256]——>[2,512,256]
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) # [1024, 256]——>[2,512,256]
npatches = feat_q.size(1) # npatches=512
# feat_q的维度为[2,512,256],feat_k变换后的维度为[2,256,512].对feat_q和feat_k进行矩阵相乘得到的是q中每个采样点与k中每个采样点的相关性矩阵(类似混淆矩阵)大小是512x512,结果l_neg_curbatch的维度为[2,512,512]
# 【注意:为什么此处不将不同batch的样本合并来获得更大的负样本?】
# (1)作者提到在FastCUT和CUT模式中,仅使用同一张图像的采样点作为负样本点结果比使用不同图像的结果更好。
# 至于计算l_pos可以这么做的原因是l_pos计算的是对应位置的采样点,是一一对应的,所以l_pos合并不同图像计算与不合并没有差别。
# (2)除此之外,我觉得另外一个原因是:合并不同图像的负样本点的计算量开销和内存消耗远远大于不合并。
# 合并后将变成[1,1024,256]@[1,256,1024](计算量:1024x256x1024,保存的矩阵大小1024x1024)
# 而不合并是[2,512,256]@[2,256,512](计算量:2x512x256x512,保存的矩阵大小为2x512x512)
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) # 输出维度[2,512,512]
# diagonal entries are similarity between same features, and hence meaningless.
# just fill the diagonal with very small number, which is exp(-10) and almost zero
# 由于对角线计算的是相同采样位置的相似性,也就是计算的q@k+。所以计算负样本的时候要把对角线的值变成0,这样得到的矩阵才是真正意义上的q@k-。
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :] #生成对角线为1,其他元素均是0,就是大小为512x512的对角矩阵
l_neg_curbatch.masked_fill_(diagonal, -10.0) # 因为对角线是相同采样位置之间的相关性,所以-10操作相当于将其置0,得到负样本的相关性矩阵
l_neg = l_neg_curbatch.view(-1, npatches) # 输出维度为[1024,512]
# 合并正负样本的logits,输出维度为[1024,513],1024可以理解成数据量大小,513可以理解成有513个类别,其中正样本的类别序号是0.
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
# 此时NCE loss变成了513类别的分类问题,只有0类别是正类,所以变成可以设置成全零。进行交叉熵损失计算后的结果就是NCE loss的结果
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
device=feat_q.device))
return loss
希望未来能够继续保持创作的热情与动力。