GAN的量化评估方法——IS和FID,及其pytorch代码

GAN的量化评估方法

  • IS
    • IS简介
    • IS代码
  • FID
    • FID简介
    • FID代码

IS

IS基于谷歌的Inception Net-V3,输入是图像,输出是1000维的向量,输出响亮的每个维度,代表着对应的属于某一类的概率。
IS用来衡量GAN网络的两个指标:

  1. 生成图片的质量
  2. 多样性

IS简介

定义:
定义
推导出上式的意义:
GAN的量化评估方法——IS和FID,及其pytorch代码_第1张图片

  1. 对于单一的生成图像,Inceptoin输出的概率分布应该尽量小,越小说明生成图像越可能属于某个类别,图像的质量越高。
  2. 对于生成器生成一批图像而言,Inception输出的平均概率分布熵值应该尽量大,代表着生成器生成的多样性。

IS代码

参考代码:https://github.com/xml94/open/blob/master/compute_IS_for_GAN
本着能不动手就不动手的原则,试了试上面的代码。但是这个需要自己写dataloader函数,还要与代码中匹配,我试了半天也没有成功,所以就自己参考这个写了一个。
只需要把要测试的图片的路径放入path即可:
于此对应的datakoader函数见下面。

from datasets import *

import torch.nn as nn
import torch.nn.functional as F
import torch

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data
from scipy.stats import entropy
from torchvision.models.inception import inception_v3


path = '/'
count = 0
for root,dirs,files in os.walk(path):    #遍历统计
      for each in files:
             count += 1   #统计文件夹下文件个数
print(count)
batch_size = 64
transforms_ = [
    transforms.Resize((256, 256), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]

val_dataloader = DataLoader(
    ISImageDataset(path, transforms_=transforms_),
    batch_size = batch_size,
)

cuda = True if torch.cuda.is_available() else False
print('cuda: ',cuda)
tensor = torch.cuda.FloatTensor

inception_model = inception_v3(pretrained=True, transform_input=False).cuda()
inception_model.eval()
up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).cuda()

def get_pred(x):
    if True:
        x = up(x)
    x = inception_model(x)
    return F.softmax(x, dim=1).data.cpu().numpy()

print('Computing predictions using inception v3 model')
preds = np.zeros((count, 1000))

for i, data in enumerate(val_dataloader):
    data = data.type(tensor)
    batch_size_i = data.size()[0]
    preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(data)

print('Computing KL Divergence')
split_scores = []
splits=10
N = count
for k in range(splits):
    part = preds[k * (N // splits): (k + 1) * (N // splits), :] # split the whole data into several parts
    py = np.mean(part, axis=0)  # marginal probability
    scores = []
    for i in range(part.shape[0]):
        pyx = part[i, :]  # conditional probability
        scores.append(entropy(pyx, py))  # compute divergence
    split_scores.append(np.exp(np.mean(scores)))


mean, std  = np.mean(split_scores), np.std(split_scores)
print('IS is %.4f' % mean)
print('The std is %.4f' % std)

dataloader结构体:

import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ISImageDataset(Dataset):
    def __init__(self, root, transforms_=None):
        self.transform = transforms.Compose(transforms_)

        self.files = sorted(glob.glob(os.path.join(root) + "/*.jpg"))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')      
        item_image = self.transform(img)
        return item_image

    def __len__(self):
        return len(self.files)

将结构体放在为:dataset.py中,IS代码中将其import进来。
最终会输出两个值,一个是IS,一个是std。

但是听说:

由于 Inception V3 是在 ImageNet 上训练的,用 Inception V3 时,应该保证生成模型也在 ImageNet上训练并生成 ImageNet 相似的图片,而不是把什么生成模型生成的图片(卧室,花,人脸)都往 Inception V3中套,那种做法没有任何意义。
不能在一个数据集上训练分类模型,用来评估另一个数据集上训练的生成模型

FID

FID分数是在IS基础上改进的,同样基于Inception Net-V3,它删除了模型原本的输出层,于是输出层变成了最后一层池化层,输出是2048维向量,因此每个图像都被预测为2048个特征。

FID简介

Frechet Inception 距离得分(Frechet Inception Distance score,FID)是计算真实图像和生成图像的特征向量之间距离的一种度量。
假如一个随机变量服从高斯分布,这个分布可以用一个均值和方差来确定。那么两个分布只要均值和方差相同,则两个分布相同。我们就利用这个均值和方差来计算这两个单变量高斯分布之间的距离。但我们这里是多维的分布,我们知道协方差矩阵可以用来衡量两个维度之间的相关性。所以,我们使用均值和协方差矩阵来计算两个分布之间的距离
在这里插入图片描述

FID越小代表着生成分布和真实图片之间越接近。

FID代码

可以通过pip之间安装:

pip install pytorch-fid

配置要求如下:

python3
pytorch
torchvision
pillow
numpy
scipy

使用非常的简单:

python -m pytorch_fid path/to/dataset1 path/to/dataset2

把生成图片的路径和真实图片的路径放进去即可,和顺序无关。
也可以选择与–dims N标志一起使用的特征维数,其中N是特征的维数。

64: first max pooling features
192: second max pooling featurs
768: pre-aux classifier features
2048: final average pooling features (this is the default)

比如:

python -m pytorch_fid path/to/dataset1 path/to/dataset2 --dims 2048

一般都是使用默认的2048
FID参考链接:官方github

推荐博客:推荐
英文的:英文

一般的评价图像质量的指标还有SSIM和PSNR,可以参看SSIM和PSNR

你可能感兴趣的:(python,python,机器学习,IS,FID)