作者简介:秃头小苏,致力于用最通俗的语言描述问题
往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例 对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战
近期目标:写好专栏的每一篇文章
支持小苏:点赞、收藏⭐、留言
在上一篇,我为大家介绍了首次应用在缺陷检测中的GAN网络——ANoGAN。在文末总结了AnoGAN一个显而易见的劣势,即在测试阶段需要花费大量时间来搜索潜在变量z,这在很多应用场景中是难以接受的。本文针对上述所说缺点,介绍一种新的GAN网络——EGBAD,其在训练过程中通过一个巧妙的编码器实现对z的搜索,这样在测试过程中就可以节约大量时间。
阅读本文之前,建议先对AnoGAN有一定了解,可参考下文:
如果你准备就绪的话,就让我们一起来学学AnoGAN的改进版EGBAD吧!!!
一直在说EGBAD,大家肯定一脸懵,到底什么才是EGBAD了?我们先来看看它的英文全称,即EFFICIENT GAN-BASED ANOMALY DETECTION
,中文译为基于GAN的高效异常检测。通过说明EGBAD的字面含义,相信大家知道了EGBAD是用来干什么的了。没错,它也是用于缺陷检测的网络,是对AnoGAN的优化。至于具体是怎么优化的,且听下文分解。
我们先来回顾一下AnoGAN是怎么设计的?AnoGAN分为训练和测试两个阶段进行,训练阶段使用正常数据训练一个DCGAN网络,在测试阶段,固定训练阶段的网络权重,不断更新潜在变量z,使得由z生成的假图像尽可能接近真实图片。【如果你对这个过程不熟悉的话,建议看看[1]中内容喔】 在介绍EGBAD是怎么设计的前,我们先来看看EGBAD主要解决了AnoGAN什么问题?其实这点我在写在前面
已经提及,AnoGAN在测试阶段要不断搜索潜在变量z,这消耗了大量时间,EGBAD的提出就是为了解决AnoGAN时间消耗大的问题。接着我们来就来看看EGBAD具体是怎么做的呢?EGBAD也分为训练和测试两个阶段进行。在训练阶段,不仅要训练生成器和判别器,还会定义一个编码器(encoder)结构并对其训练,encoder主要用于将输入图像通过网络转变成一个潜在变量。在测试阶段,冻结训练阶段的所以权重,之后通过encoder将输入图像变为潜在变量,最后在将潜在变量送入生成器,生成假图像。可以发现EGBAD没有在测试阶段搜索潜在变量,而是直接通过一个encoder结构将输入图像转变成潜在变量,这大大节省了时间成本。
关于EGBAD训练过程模型示意图如下:【测试过程很简单啦,就不介绍了】
可以看出判别器的输入有两个,一个是生成器生成的假图像 x ′ {\rm{x'}} x′,另一个是编码器生成的 z ′ {\rm{z'}} z′。具体生成器、编码器和判别器的结构如何,将在下章代码实战中介绍。
同样,我将此部分的源码上传到Github上了,大家可以阅读README文件了解代码的使用,Github地址如下:
EGBAD-pytorch实现
我认为你阅读README文件后已经对这个项目的结构有所了解,我在下文也会帮大家分析分析源码,但更多的时间大家应该自己动手去亲自调试,这样你会有不一样的收获。
这部分和AnoGAN中完全一致,就不带大家一行行看调试结果了,不明白的可以阅读AnoGAN教程,这里直接上代码:
#导入相关包
import numpy as np
import pandas as pd
"""
mnist数据集读取
"""
## 读取训练集数据 (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 读取测试集数据 (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)
# 查询训练数据中标签为7、8的数据,并取前400个
train = train.query("label in [7.0, 8.0]").head(400)
# 查询训练数据中标签为7、8的数据,并取前400个
test = test.query("label in [2.0, 7.0, 8.0]").head(600)
# 取除标签后的784列数据
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')
# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)
这部分大家就潜心修行,慢慢调试代码吧,我也会给出每个模型的结构图辅助大家,就让我们一起来看看吧☘☘☘
"""定义生成器网络结构"""
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
seq = []
seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
if bn is True:
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(20, 64*8, stride=1, padding=0)]
seq += [CBA(64*8, 64*4)]
seq += [CBA(64*4, 64*2)]
seq += [CBA(64*2, 64)]
seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
self.generator_network = nn.Sequential(*seq)
def forward(self, z):
out = self.generator_network(z)
return out
生成模型的搭建其实很AnoGAN是完全一样的,我也给出生成网络的结构图,如下:
"""定义编码器结构"""
class encoder(nn.Module):
def __init__(self):
super(encoder, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
seq = []
seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(1, 64)]
seq += [CBA(64, 64*2)]
seq += [CBA(64*2, 64*4)]
seq += [CBA(64*4, 64*8)]
seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
self.feature_network = nn.Sequential(*seq)
self.embedding_network = nn.Linear(512, 20)
def forward(self, x):
feature = self.feature_network(x).view(-1, 512)
z = self.embedding_network(feature)
return z
这部分其实也很简单,就是一系列卷积的堆积,编码器的结构图如下:
"""定义判别器网络结构"""
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
seq = []
seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
seq += [nn.BatchNorm2d(out_channel)]
seq += [activation]
return nn.Sequential(*seq)
seq = []
seq += [CBA(1, 64)]
seq += [CBA(64, 64*2)]
seq += [CBA(64*2, 64*4)]
seq += [CBA(64*4, 64*8)]
seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
self.feature_network = nn.Sequential(*seq)
seq = []
seq += [nn.Linear(20, 512)]
seq += [nn.BatchNorm1d(512)]
seq += [nn.LeakyReLU(0.1, inplace=True)]
self.latent_network = nn.Sequential(*seq)
self.critic_network = nn.Linear(1024, 1)
def forward(self, x, z):
feature = self.feature_network(x)
feature = feature.view(feature.size(0), -1)
latent = self.latent_network(z)
out = self.critic_network(torch.cat([feature, latent], dim=1))
return out, feature
虽然判别器有两个输入,两个输出,但是结构也非常清晰,如下图所示:
在模型搭建部分我还想提一点我们需要注意的地方,一般我们设计好一个网络结构后,我们往往会先设计一个tensor来作为网络的输入,看看网络输出是否是是我们预期的,如果是,我们再进行下一步,否则我们需要调整我们的结构以适应我们的输入。通常情况下,tensor的batch维度设为1就行,但是这里设置成1就会报错,提示我们需要设置一个batch大于1的整数,当将batch设置为2时,程序正常,至于产生这种现象的原因我目前也不是很清楚,大家注意一下,知道的也烦请告知一下。关于调试网络结构是否正常的代码如下,仅供参考:
if __name__ == '__main__':
x = torch.ones((2, 1, 64, 64))
z = torch.ones((2, 20, 1, 1))
Generator = Generator()
Discriminator = Discriminator()
encoder = encoder()
output_G = Generator(z)
output_D1, output_D2= Discriminator(x, z.view(2, -1))
output_E = encoder(x)
print(output_G.shape)
print(output_D1.shape)
print(output_D2.shape)
print(output_E.shape)
这部分和AnoGAN一致,注意最终输入网络的图片尺寸都上采样成了64×64.
class image_data_set(Dataset):
def __init__(self, data):
self.images = data[:,:,:,None]
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
transforms.Normalize((0.1307,), (0.3081,))
])
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.transform(self.images[idx])
# 加载训练数据
train_set = image_data_set(train)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
这部分也基本和AnoGAN类似,只不过添加了encoder网络的定义和优化器定义部分,如下:
# 指定设备
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size默认128
batch_size = args.batch_size
# 加载模型
G = Generator().to(device)
D = Discriminator().to(device)
E = Encoder().to(device)
# 训练模式
G.train()
D.train()
E.train()
# 设置优化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerE = torch.optim.Adam(E.parameters(), lr=0.0004, betas=(0.0,0.9))
# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
"""
训练
"""
# 开始训练
for epoch in range(args.epochs):
# 定义初始损失
log_g_loss, log_d_loss, log_e_loss = 0.0, 0.0, 0.0
for images in train_loader:
images = images.to(device)
## 训练判别器 Discriminator
# 定义真标签(全1)和假标签(全0) 维度:(batch_size)
label_real = torch.full((images.size(0),), 1.0).to(device)
label_fake = torch.full((images.size(0),), 0.0).to(device)
# 定义潜在变量z 维度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
# 潜在变量喂入生成网络--->fake_images:(batch_size,1,64,64)
fake_images = G(z)
# 使用编码器将真实图像变成潜在变量 image:(batch_size, 1, 64, 64)-->z_real:(batch_size, 20)
z_real = E(images)
# 真图像和假图像送入判别网络,得到d_out_real、d_out_fake 维度:都为(batch_size,1)
d_out_real, _ = D(images, z_real)
d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
# 损失计算
d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake
# 误差反向传播,更新损失
optimizerD.zero_grad()
d_loss.backward()
optimizerD.step()
## 训练生成器 Generator
# 定义潜在变量z 维度:(batch_size,20,1,1)
z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
fake_images = G(z)
# 假图像喂入判别器,得到d_out_fake 维度:(batch_size,1)
d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
# 损失计算
g_loss = criterion(d_out_fake.view(-1), label_real)
# 误差反向传播,更新损失
optimizerG.zero_grad()
g_loss.backward()
optimizerG.step()
## 训练编码器Encode
# 使用编码器将真实图像变成潜在变量 image:(batch_size, 1, 64, 64)-->z_real:(batch_size, 20)
z_real = E(images)
# 真图像送入判别器,记录结果d_out_real:(128, 1)
d_out_real, _ = D(images, z_real)
# 损失计算
e_loss = criterion(d_out_real.view(-1), label_fake)
# 误差反向传播,更新损失
optimizerE.zero_grad()
e_loss.backward()
optimizerE.step()
## 累计一个epoch的损失,判别器损失、生成器损失、编码器损失分别存放到log_d_loss、log_g_loss、log_e_loss中
log_d_loss += d_loss.item()
log_g_loss += g_loss.item()
log_e_loss += e_loss.item()
## 打印损失
print(f'epoch {epoch}, D_Loss:{log_d_loss/128:.4f}, G_Loss:{log_g_loss/128:.4f}, E_Loss:{log_e_loss/128:.4f}')
这里总结一下上述训练的步骤,不断循环下列过程:
1、使用生成器从潜在变量z中创建假图像
2、使用编码器从真实图像中创建潜在变量
3、生成器和编码器结果送入判别器,进行训练
4、使用生成器从潜在变量z中创建假图像
5、训练生成器
6、使用编码器从真实图像中创建潜在变量
7、训练编码器
关于第3步,我也简单画了个图帮大家理解下,如下:
最后我们来展示一下生成图片的效果,如下图所示:
EGBAD缺陷检测非常简单,首先定义一个就算损失的函数,如下:
## 定义缺陷计算的得分
def anomaly_score(input_image, fake_image, z_real, D):
# Residual loss 计算
residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
# Discrimination loss 计算
_, real_feature = D(input_image, z_real)
_, fake_feature = D(fake_image, z_real)
discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))
# 结合Residual loss和Discrimination loss计算每张图像的损失
total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
return total_loss_by_image
接着我们只需要用Encoder网络生成潜在变量,在再用生成器即可得到假图像,最后计算假图像和真图像的损失即可,如下:
# 加载测试数据
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)
# 通过编码器获取潜在变量,并用生成器生成假图像
z_real = E(input_images)
fake_images = G(z_real.view(input_images.size(0), 20, 1, 1))
# 异常计算
anomality = anomaly_score(input_images, fake_images, z_real, D)
print(anomality.cpu().detach().numpy())
最后可以保存一下真实图像和假图像的结果,如下:
torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")
我们来看一下结果:
通过上图你发现了什么呢?是不是发现输入图像为7的图片的生成图像不是7而变成了8呢,究其原因,应该是生成器学到了更多关于数据8的特征,也就是说这个网络的生成效果并没有很好。
我做了很多实验,发现EGBAD虽然测试时间上比AnoGAN快很多,但是它的稳定性似乎并没有很理想,很容易出现模式崩溃的问题。其实啊,GAN网络普遍存在着训练不稳定的现象,这也是一些大牛不断探索的方向,后面的文章我也会给大家介绍一些增加GAN训练稳定性的文章,敬请期待吧!
我们一直说EGBAD的测试时间相较AnoGAN短,从原理上来说确实是这样,但是具体是不是这样我们还要以实验为准。测试代码也很简单,只需要在测试过程中使用time.time()
函数即可,具体可以参考我上传github中的源码,这里给出我测试两种网络在测试阶段所用时间(以秒为单位),如下图所示:
通过上图数据可以看出,EGBAD比AnoGAN快的不是一点点,EGBAD的速度将近是AnoGAN的10000倍,这个数字还是很恐怖的。
到此,EGBAD的全部内容就为大家介绍完了,如果你明白了AnoGAN的话,这篇文章对你来说应该是小菜一碟了。EGBAD大大的减少了测试所有时间,但是GAN网络普遍存在易模式崩溃、训练不稳定的现象,下一篇博文我将为大家介绍一些让GAN训练更稳定的技巧,敬请期待吧。
EFFICIENT GAN-BASED ANOMALY DETECTION
GAN 使用 Pytorch 进行异常检测的方法
如若文章对你有所帮助,那就