The link below collects all of my personal notes on DG-Net (person re-identification, ReID). If you spot any mistakes, please point them out and I will correct them right away. If you are interested, feel free to add me on WeChat (a944284742) to discuss. And if this helped you in any way, remember to leave a like, because that is the biggest encouragement for me.
Person ReID 0-00: DG-GAN (ReID), table of contents (the most complete and up-to-date): https://blog.csdn.net/weixin_43013761/article/details/102364512
First, following on from the previous post, let's look at this code in networks.py:
def calc_dis_loss(self, model, input_fake, input_real):
    """
    This loss trains D, i.e. the discriminator itself.
    :param model: the MsImageDis instance itself
    :param input_fake: fake input image, i.e. a synthesized image
    :param input_real: real input image, taken from the training set
    :return:
    """
    # calculate the loss to train D
    input_real.requires_grad_()
    # A list of 3 tensors, of sizes [batch_size,1,64,32], [batch_size,1,32,16], [batch_size,1,16,8]
    outs0 = model.forward(input_fake)
    # Likewise a list of 3 tensors: [batch_size,1,64,32], [batch_size,1,32,16], [batch_size,1,16,8]
    outs1 = model.forward(input_real)
    loss = 0
    reg = 0
    Drift = 0.001
    LAMBDA = self.LAMBDA
    # By default gan_type = 'lsgan', so this branch is not executed
    if self.gan_type == 'wgan':
        loss += torch.mean(outs0) - torch.mean(outs1)
        # progressive gan
        loss += Drift*( torch.sum(outs0**2) + torch.sum(outs1**2))
        #alpha = torch.FloatTensor(input_fake.shape).uniform_(0., 1.)
        #alpha = alpha.cuda()
        #differences = input_fake - input_real
        #interpolates = Variable(input_real + (alpha*differences), requires_grad=True)
        #dis_interpolates = self.forward(interpolates)
        #gradient_penalty = self.compute_grad2(dis_interpolates, interpolates).mean()
        #reg += LAMBDA*gradient_penalty
        reg += LAMBDA* self.compute_grad2(outs1, input_real).mean() # I suggest Lambda=0.1 for wgan
        loss = loss + reg
        return loss, reg

    for it, (out0, out1) in enumerate(zip(outs0, outs1)):
        # By default gan_type == 'lsgan': a least-squares loss, which mainly
        # addresses instability in the generated images
        if self.gan_type == 'lsgan':
            loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
            # regularization
            reg += LAMBDA* self.compute_grad2(out1, input_real).mean()
        elif self.gan_type == 'nsgan':
            all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
            all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
            loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
                               F.binary_cross_entropy(F.sigmoid(out1), all1))
            reg += LAMBDA* self.compute_grad2(F.sigmoid(out1), input_real).mean()
        else:
            assert 0, "Unsupported GAN type: {}".format(self.gan_type)
    loss = loss+reg
    return loss, reg
def calc_gen_loss(self, model, input_fake):
    """
    :param model: the MsImageDis instance itself
    :param input_fake: fake input image
    :return:
    """
    # calculate the loss to train G
    # Note: the discriminator output again consists of 3 scales
    outs0 = model.forward(input_fake)
    loss = 0
    Drift = 0.001
    # Not executed here, since gan_type == 'lsgan'
    if self.gan_type == 'wgan':
        loss += -torch.mean(outs0)
        # progressive gan
        loss += Drift*torch.sum(outs0**2)
        return loss

    # Again, we use gan_type == 'lsgan'
    for it, (out0) in enumerate(outs0):
        if self.gan_type == 'lsgan':
            loss += torch.mean((out0 - 1)**2) * 2  # LSGAN
        elif self.gan_type == 'nsgan':
            all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
            loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
        else:
            assert 0, "Unsupported GAN type: {}".format(self.gan_type)
    return loss
# Compute the squared gradient norm of the discriminator output w.r.t. its
# input; this is the gradient-penalty regularization term added to reg above
def compute_grad2(self, d_out, x_in):
    batch_size = x_in.size(0)
    # Auto-differentiation: take the derivative of the scalar
    # outputs=d_out.sum() with respect to inputs=x_in
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    # Per-sample sum of the squared gradient entries
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg
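As a quick sanity check (my own toy example, not from the repo), here is what compute_grad2 effectively returns. For a stand-in discriminator D(x) = sum(3*x), the gradient w.r.t. x is 3 everywhere, so the per-sample squared gradient norm is 9 times the number of elements:

import torch

x = torch.randn(2, 3, 4, 4, requires_grad=True)
d_out = (3 * x).sum(dim=[1, 2, 3])            # stand-in "discriminator output"
grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x,
                           create_graph=True, retain_graph=True,
                           only_inputs=True)[0]
reg = grad.pow(2).view(2, -1).sum(1)          # same math as compute_grad2
print(reg)                                    # tensor([432., 432.]) == 9 * 3*4*4

So reg is the squared gradient norm of D's output with respect to its input, one value per sample, which then gets scaled by LAMBDA in the loss above.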
These three functions are the focus of this post. Note that what we are covering here are the loss functions of the image discriminator (MsImageDis), the discriminator that judges whether an image is real or fake, not the one that judges identity. In MsImageDis, "Ms" should stand for multi-scale and "Dis" for discriminator; "Image" you can work out yourself. A minimal sketch of the multi-scale idea follows.
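To make the multi-scale part concrete, here is a minimal sketch of how such a discriminator can produce three score maps. This is an illustration under my own assumptions: the real MsImageDis in networks.py stacks more stride-2 convolutions per scale, which is how it reaches the [64,32], [32,16], [16,8] maps mentioned in the comments above.

import torch
import torch.nn as nn

class TinyMsDis(nn.Module):
    """Sketch: one small conv discriminator per scale, input halved between scales."""
    def __init__(self, num_scales=3):
        super().__init__()
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.cnns = nn.ModuleList([
            nn.Sequential(nn.Conv2d(3, 32, 4, stride=2, padding=1),
                          nn.LeakyReLU(0.2),
                          nn.Conv2d(32, 1, 1))     # 1-channel real/fake score map
            for _ in range(num_scales)])

    def forward(self, x):
        outs = []
        for cnn in self.cnns:
            outs.append(cnn(x))      # score the image at the current scale
            x = self.downsample(x)   # halve the resolution for the next scale
        return outs

imgs = torch.randn(2, 3, 256, 128)   # DG-Net works on 256x128 person crops
for o in TinyMsDis()(imgs):
    print(o.shape)  # [2,1,128,64], [2,1,64,32], [2,1,32,16]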
First, a basic concept. A GAN generally has two modules: a generator and a discriminator. When computing the generator's loss, no real image is needed. Put plainly, we just hand the generated image to the discriminator, which tells the generator whether it looks real or fake; if it looks fake, the generator keeps optimizing. That is why the generator loss function def calc_gen_loss(self, model, input_fake) takes only one image: a single fake image is enough to compute the loss.
The discriminator's loss is computed differently. In def calc_dis_loss(self, model, input_fake, input_real) you can see there are two image arguments, one fake and one real. Why? Because the discriminator must learn not only to catch fake images but also to recognize real ones, so it has to be shown both (I won't bring this concept up again). See the sketch right after this paragraph.
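To make this concrete, here is a runnable toy sketch of the alternating updates, with a stand-in linear G and D rather than DG-Net's real modules. Note how the D step consumes both images while the G step only needs the fake one:

import torch
import torch.nn as nn

G = nn.Linear(8, 16)                          # toy "generator"
D = nn.Linear(16, 1)                          # toy "discriminator"
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(D.parameters(), lr=1e-4)

real = torch.randn(4, 16)                     # "training set" images
fake = G(torch.randn(4, 8))                   # "synthesized" images

# D step (cf. calc_dis_loss), LSGAN targets: fake -> 0, real -> 1
d_opt.zero_grad()
loss_d = ((D(fake.detach()) - 0)**2).mean() + ((D(real) - 1)**2).mean()
loss_d.backward()
d_opt.step()

# G step (cf. calc_gen_loss): only the fake image, target 1 ("fool D")
g_opt.zero_grad()
loss_g = ((D(fake) - 1)**2).mean()
loss_g.backward()
g_opt.step()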
LSGAN was proposed in a separate GAN paper. I could go read that paper first and then explain it to you in detail, but there is no need: I am sure you don't want to listen to me ramble any more than I want to read the paper, since time is precious. A brief introduction will do:
First, the loss used to optimize the generator G (the formula from the paper):
$$L(G)=\mathbb{E}_{x\sim p_x}\big(D(G(x))-c\big)^2$$
A tip for reading formulas like this one: don't overthink them. Let's analyze it the way we read the source code; the relevant lines are:
for it, (out0) in enumerate(outs0):
    if self.gan_type == 'lsgan':
        # The loss is smallest when out0 = 1; out0 is the output of
        # model.forward(input_fake), i.e. the discriminator's verdict on the fake image
        loss += torch.mean((out0 - 1)**2) * 2  # LSGAN
You can see this matches the formula one-to-one: the constant c in the formula equals 1 in the source code. So the loss is minimized when (out0 - 1) = 0, i.e. out0 = 1, and we know from the code that out0 is the discriminator's output for a generated image. An output of 1 means the discriminator was fooled: it judged the synthesized image to be real. At that point the generator model is at its best.
So by computing this loss, the generator gets optimized; see the quick check below.
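A quick numeric check with toy scores (my own example, not from the repo) confirms this: the generator term is largest when the discriminator confidently calls the fake image fake (out0 = 0) and vanishes when it is fully fooled (out0 = 1):

import torch

for out0 in (torch.tensor(0.0), torch.tensor(0.5), torch.tensor(1.0)):
    print(out0.item(), (torch.mean((out0 - 1)**2) * 2).item())
# 0.0 -> 2.0,  0.5 -> 0.5,  1.0 -> 0.0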
With the generator G covered, let's see how the discriminator is optimized. First, the formula from the paper:
$$L(D)=\frac{1}{2}\,\mathbb{E}_{z\sim p_z}\big(D(G(z))-b\big)^2+\frac{1}{2}\,\mathbb{E}_{x\sim p_x}\big(D(x)-a\big)^2$$
Then locate the corresponding code:
for it, (out0, out1) in enumerate(zip(outs0, outs1)):
    # By default gan_type == 'lsgan': a least-squares loss, which mainly
    # addresses instability in the generated images
    if self.gan_type == 'lsgan':
        loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
Again the loss matches the formula term by term: out0 is the discriminator's output for the fake image, i.e. $out0 = D(G(z))$, and out1 is its output for the real image, i.e. $out1 = D(x)$. For the loss to be minimal we need out0 = 0 (the fake image, so b = 0) and out1 = 1 (the real image, so a = 1).
Computing this loss drives the network updates that optimize the image discriminator. The discriminator performs best when it outputs 1 for real images and 0 for synthesized ones; the toy check below confirms it.
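The same kind of toy check for the discriminator term shows the loss is zero exactly in that case and grows with every misjudged image:

import torch

def d_term(out0, out1):  # out0: score of the fake image, out1: score of the real image
    return (torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)).item()

print(d_term(torch.tensor(0.), torch.tensor(1.)))  # 0.0  (perfect D)
print(d_term(torch.tensor(1.), torch.tensor(1.)))  # 1.0  (fooled on the fake)
print(d_term(torch.tensor(1.), torch.tensor(0.)))  # 2.0  (wrong on both)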
That wraps up the real-versus-fake image discriminator. Now let's go back to trainer.py and find the part that deals with the teacher network:
# load teachers
# teacher: the name of the teacher model. For DukeMTMC you can set 'best-duke'
if hyperparameters['teacher'] != "":
    teacher_name = hyperparameters['teacher']
    print(teacher_name)
    # The split suggests that several teacher models can be loaded, comma-separated
    teacher_names = teacher_name.split(',')
    # Build the teacher models
    teacher_model = nn.ModuleList()
    teacher_count = 0
    # With the default single teacher name, the model loaded comes from
    # models/best/opts.yaml under the project root
    for teacher_name in teacher_names:
        config_tmp = load_config(teacher_name)
        # stride comes from the teacher's config (the default config uses stride=1);
        # it is the stride used for down-sampling in the last stage
        if 'stride' in config_tmp:
            stride = config_tmp['stride']
        else:
            stride = 2
        # Build the network
        model_tmp = ft_net(ID_class, stride = stride)
        teacher_model_tmp = load_network(model_tmp, teacher_name)
        # Remove the original fully connected layer
        teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
        # Move the model to the GPU
        teacher_model_tmp = teacher_model_tmp.cuda()
        #summary(teacher_model_tmp, (3, 224, 224))
        # Use mixed (half) precision if enabled
        if self.fp16:
            teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
        teacher_model.append(teacher_model_tmp.cuda().eval())
        teacher_count += 1
    self.teacher_model = teacher_model
    # Optionally keep the BN layers in training mode
    if hyperparameters['train_bn']:
        self.teacher_model = self.teacher_model.apply(train_bn)
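For reference, here is a plausible definition of train_bn (my assumption; check utils.py in the repo for the exact version). The idea is to put the BatchNorm layers back into training mode so their running statistics keep adapting, even though the teacher as a whole was set to eval():

import torch.nn as nn

def train_bn(m):
    # applied recursively to every sub-module via module.apply(train_bn)
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.train()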
Below is the printed network structure:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 4,096
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 16,384
BatchNorm2d-12 [-1, 256, 56, 56] 512
Conv2d-13 [-1, 256, 56, 56] 16,384
BatchNorm2d-14 [-1, 256, 56, 56] 512
ReLU-15 [-1, 256, 56, 56] 0
Bottleneck-16 [-1, 256, 56, 56] 0
Conv2d-17 [-1, 64, 56, 56] 16,384
BatchNorm2d-18 [-1, 64, 56, 56] 128
ReLU-19 [-1, 64, 56, 56] 0
Conv2d-20 [-1, 64, 56, 56] 36,864
BatchNorm2d-21 [-1, 64, 56, 56] 128
ReLU-22 [-1, 64, 56, 56] 0
Conv2d-23 [-1, 256, 56, 56] 16,384
BatchNorm2d-24 [-1, 256, 56, 56] 512
ReLU-25 [-1, 256, 56, 56] 0
Bottleneck-26 [-1, 256, 56, 56] 0
Conv2d-27 [-1, 64, 56, 56] 16,384
BatchNorm2d-28 [-1, 64, 56, 56] 128
ReLU-29 [-1, 64, 56, 56] 0
Conv2d-30 [-1, 64, 56, 56] 36,864
BatchNorm2d-31 [-1, 64, 56, 56] 128
ReLU-32 [-1, 64, 56, 56] 0
Conv2d-33 [-1, 256, 56, 56] 16,384
BatchNorm2d-34 [-1, 256, 56, 56] 512
ReLU-35 [-1, 256, 56, 56] 0
Bottleneck-36 [-1, 256, 56, 56] 0
Conv2d-37 [-1, 128, 56, 56] 32,768
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
Conv2d-40 [-1, 128, 28, 28] 147,456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
Conv2d-43 [-1, 512, 28, 28] 65,536
BatchNorm2d-44 [-1, 512, 28, 28] 1,024
Conv2d-45 [-1, 512, 28, 28] 131,072
BatchNorm2d-46 [-1, 512, 28, 28] 1,024
ReLU-47 [-1, 512, 28, 28] 0
Bottleneck-48 [-1, 512, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 65,536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 512, 28, 28] 65,536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
Bottleneck-58 [-1, 512, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 65,536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 147,456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 512, 28, 28] 65,536
BatchNorm2d-66 [-1, 512, 28, 28] 1,024
ReLU-67 [-1, 512, 28, 28] 0
Bottleneck-68 [-1, 512, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 65,536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 128, 28, 28] 147,456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
Conv2d-75 [-1, 512, 28, 28] 65,536
BatchNorm2d-76 [-1, 512, 28, 28] 1,024
ReLU-77 [-1, 512, 28, 28] 0
Bottleneck-78 [-1, 512, 28, 28] 0
Conv2d-79 [-1, 256, 28, 28] 131,072
BatchNorm2d-80 [-1, 256, 28, 28] 512
ReLU-81 [-1, 256, 28, 28] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
Conv2d-85 [-1, 1024, 14, 14] 262,144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
Bottleneck-90 [-1, 1024, 14, 14] 0
Conv2d-91 [-1, 256, 14, 14] 262,144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
Conv2d-94 [-1, 256, 14, 14] 589,824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
Bottleneck-100 [-1, 1024, 14, 14] 0
Conv2d-101 [-1, 256, 14, 14] 262,144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
Conv2d-104 [-1, 256, 14, 14] 589,824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
Bottleneck-110 [-1, 1024, 14, 14] 0
Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
Bottleneck-120 [-1, 1024, 14, 14] 0
Conv2d-121 [-1, 256, 14, 14] 262,144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
Conv2d-124 [-1, 256, 14, 14] 589,824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
Bottleneck-130 [-1, 1024, 14, 14] 0
Conv2d-131 [-1, 256, 14, 14] 262,144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
Bottleneck-140 [-1, 1024, 14, 14] 0
Conv2d-141 [-1, 512, 14, 14] 524,288
BatchNorm2d-142 [-1, 512, 14, 14] 1,024
ReLU-143 [-1, 512, 14, 14] 0
Conv2d-144 [-1, 512, 14, 14] 2,359,296
BatchNorm2d-145 [-1, 512, 14, 14] 1,024
ReLU-146 [-1, 512, 14, 14] 0
Conv2d-147 [-1, 2048, 14, 14] 1,048,576
BatchNorm2d-148 [-1, 2048, 14, 14] 4,096
Conv2d-149 [-1, 2048, 14, 14] 2,097,152
BatchNorm2d-150 [-1, 2048, 14, 14] 4,096
ReLU-151 [-1, 2048, 14, 14] 0
Bottleneck-152 [-1, 2048, 14, 14] 0
Conv2d-153 [-1, 512, 14, 14] 1,048,576
BatchNorm2d-154 [-1, 512, 14, 14] 1,024
ReLU-155 [-1, 512, 14, 14] 0
Conv2d-156 [-1, 512, 14, 14] 2,359,296
BatchNorm2d-157 [-1, 512, 14, 14] 1,024
ReLU-158 [-1, 512, 14, 14] 0
Conv2d-159 [-1, 2048, 14, 14] 1,048,576
BatchNorm2d-160 [-1, 2048, 14, 14] 4,096
ReLU-161 [-1, 2048, 14, 14] 0
Bottleneck-162 [-1, 2048, 14, 14] 0
Conv2d-163 [-1, 512, 14, 14] 1,048,576
BatchNorm2d-164 [-1, 512, 14, 14] 1,024
ReLU-165 [-1, 512, 14, 14] 0
Conv2d-166 [-1, 512, 14, 14] 2,359,296
BatchNorm2d-167 [-1, 512, 14, 14] 1,024
ReLU-168 [-1, 512, 14, 14] 0
Conv2d-169 [-1, 2048, 14, 14] 1,048,576
BatchNorm2d-170 [-1, 2048, 14, 14] 4,096
ReLU-171 [-1, 2048, 14, 14] 0
Bottleneck-172 [-1, 2048, 14, 14] 0
AdaptiveAvgPool2d-173 [-1, 2048, 4, 1] 0
AdaptiveAvgPool2d-174 [-1, 2048, 1, 1] 0
Linear-175 [-1, 512] 1,049,088
BatchNorm1d-176 [-1, 512] 1,024
Dropout-177 [-1, 512] 0
Linear-178 [-1, 751] 385,263
ClassBlock-179 [-1, 751] 0
================================================================
A bit awkward, I admit: it is both big and long. For now we only need to know that this network takes an image as input and outputs the class, that is, the ID number, the image belongs to. I suspect we cannot escape the fate of walking through training this teacher model later, because the project genuinely cannot do without it.
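As a hypothetical usage example (not code from the repo; the shapes follow the summary above, assuming a Market-1501 teacher with 751 identities):

import torch

imgs = torch.randn(8, 3, 224, 224).cuda()    # a batch of person crops
with torch.no_grad():
    logits = teacher_model[0](imgs)           # [8, 751] ID scores
pred_id = logits.argmax(dim=1)                # predicted identity per image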
This post is admittedly a bit thin; there wasn't much to cover and now it is over, so I am almost too embarrassed to ask you for a like. Don't blame me, blame the formulas and the network for being too simple!