生成专题2 | 图像生成评价指标FID

  • 文章转自微信公众号:机器学习炼丹术
  • 作者:陈亦新(欢迎交流共同进步)
  • 联系方式:微信cyx645016617

文章目录

    • 2.1 感性理解
    • 2.2 代码实现

2.1 感性理解

FID是Fréchet Inception Distance。

FID依然是表示生成图像的多样性和质量,为什么FID越小,则图像多样性越好,质量越好。

FID的计算器中,我们也是用了inception network网络。inception netowrk其实就是特征提取的网络,最后一层输出图像的类别。不过我们会去除最后的全连接或者池化层,使得我们得到一个2048维度的特征

对于我们已经拥有的真实图片,所有真实图片的提取的向量是服从一个分布的;对于用GAN生成的图片对应的高位向量特征也是服从一个分布的。如果两个分布相同,那么意味着GAN生成图片的真实程度很高。

现在,我们如何计算两个分布的距离呢?因为这两个分布是多变量的,包含2048维度的特征,所以我们是计算两个多维变量分布之间的距离。可以使用Wasserstein距离或者Frechet距离。

假如一个随机变量服从高斯分布,那么这个分布可以用一个均值和方差来确定。那么两个分布只要均值和方差相同,那么两个分布则相同。我们可以利用均值和方差来计算两个单变量高斯分布之间的距离。这里是多维度的分布,我们可以使用协方差矩阵来衡量多个维度之间的相关性,所以使用均值和协方差矩阵来计算两个高维分布之间的距离。

我们下面公式计算FID:

F I D ( x , g ) = ∣ ∣ μ x − μ g ∣ ∣ 2 2 + T r ( Σ x + Σ g − 2 ( Σ x Σ g ) 0.5 ) FID(x,g)=||\mu_x-\mu_g||^2_2+Tr(\Sigma_x+\Sigma_g-2(\Sigma_x\Sigma_g)^{0.5}) FID(x,g)=μxμg22+Tr(Σx+Σg2(ΣxΣg)0.5)

公式中, T r Tr Tr表示矩阵对角线上元素的综合,矩阵论中成为矩阵的迹。x和g表示真实的图片和生成的图片, μ \mu μ表示均值, σ \sigma σ是协方差矩阵。

较低的FID表示两个分布更为接近。

下面是使用Numpy实现FID的计算过程:

2.2 代码实现

# calculate frechet inception distance
def calculate_fid(act1, act2):
	# calculate mean and covariance statistics
	mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
	mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
	# calculate sum squared difference between means
	ssdiff = numpy.sum((mu1 - mu2)**2.0)
	# calculate sqrt of product between cov
	covmean = sqrtm(sigma1.dot(sigma2))
	# check and correct imaginary numbers from sqrt
	if iscomplexobj(covmean):
		covmean = covmean.real
	# calculate score
	fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
	return fid
  
# define two collections of activations
act1 = random(10*2048)
act1 = act1.reshape((10,2048))
act2 = random(10*2048)
act2 = act2.reshape((10,2048))
# fid between act1 and act1
fid = calculate_fid(act1, act1)
print('FID (same): %.3f' % fid)
# fid between act1 and act2
fid = calculate_fid(act1, act2)
print('FID (different): %.3f' % fid)

你可能感兴趣的:(笔记,python,数据挖掘,神经网络)