This Python script evaluates image quality: it compares a set of high-definition (HD) reference images against their corresponding generated images and computes four quality metrics: PSNR, SSIM, LPIPS, and NIQE.
At the top of the script, the LPIPS library is used to initialize a pretrained VGG network, which is used later for the LPIPS perceptual metric.
loss_fn = lpips.LPIPS(net='vgg')
The functions calculate_psnr and calculate_ssim use OpenCV and skimage, respectively, to compute PSNR and SSIM. Both are full-reference metrics, i.e. they compare the generated image against the original.
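As a quick, self-contained illustration with synthetic arrays (not part of the original script; it assumes scikit-image 0.19 or newer, where channel_axis replaces the older multichannel flag), both metrics take the reference and the candidate image as input:
import cv2
import numpy as np
from skimage.metrics import structural_similarity as compare_ssim

# Two synthetic 8-bit color images standing in for the HD reference and a generated result
ref = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
gen = cv2.GaussianBlur(ref, (5, 5), 0)  # a degraded copy of the reference

psnr = cv2.PSNR(ref, gen)                       # higher is better, in dB
ssim = compare_ssim(ref, gen, channel_axis=-1)  # 1.0 means identical
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")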
calculate_lpips uses the initialized LPIPS model to measure the perceptual difference between the two images (the original and the generated one).
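One detail worth noting: the LPIPS network expects RGB tensors scaled to [-1, 1], while cv2.imread returns BGR uint8 arrays, so the images should be converted before the forward pass. A minimal conversion sketch (the helper name is ours, not from the script):
import torch

def bgr_to_lpips_tensor(img_bgr):
    # BGR uint8 (H, W, 3) -> RGB float tensor (1, 3, H, W) scaled to [-1, 1]
    rgb = img_bgr[:, :, ::-1].copy()
    t = torch.from_numpy(rgb).float().permute(2, 0, 1).unsqueeze(0)
    return t / 127.5 - 1.0

# distance = loss_fn(bgr_to_lpips_tensor(img1), bgr_to_lpips_tensor(img2)).item()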
calculate_niqe implements a simplified NIQE-style score. NIQE is a no-reference metric, so only a single image is needed to assess its quality; note, however, that the version below uses placeholder model parameters (zero mean, identity covariance) rather than the pretrained natural-image statistics of the original NIQE, so its scores are only a rough approximation.
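For context: the standard NIQE pipeline fits the statistics of locally mean-subtracted, contrast-normalized (MSCN) coefficients against a pretrained model of natural images. The snippet below is only a sketch of that MSCN normalization for comparison; it is not part of this script and not the author's method:
import numpy as np
from scipy.ndimage import gaussian_filter

def mscn_coefficients(gray, sigma=7.0 / 6.0, c=1.0):
    # MSCN coefficients of a grayscale image with values in [0, 255]
    gray = gray.astype(np.float64)
    mu = gaussian_filter(gray, sigma)
    var = gaussian_filter(gray * gray, sigma) - mu * mu
    return (gray - mu) / (np.sqrt(np.abs(var)) + c)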
process_image is the core of the script. For each HD image it loads the reference and every generated candidate with cv2.imread, computes the four metrics for each pair, writes the per-image results to a text file, and records the best value of each metric. Python's multiprocessing.Pool is then used to process all HD images in parallel; this is a typical "Map-Reduce" pattern, with process_image acting as the map step.
with Pool(4) as pool: # Initialize a pool with 4 processes
pool.starmap(process_image, [(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder) for i in range(1, 570)])
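A minimal, standalone sketch of the same pattern (the worker function and arguments here are illustrative, not from the script):
from multiprocessing import Pool

def work(index, out_dir):
    # Stand-in for process_image: each call handles one independent item
    return f"{out_dir}/{index}_output.txt"

if __name__ == "__main__":
    args = [(i, "results") for i in range(1, 6)]
    with Pool(4) as pool:                 # four worker processes
        paths = pool.starmap(work, args)  # the "map" step: one argument tuple per call
    print(paths)
The complete script follows.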
import os
from multiprocessing import Pool
import cv2
import lpips
import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from scipy.special import gammaln
from scipy.stats import genpareto
from skimage import img_as_float
from skimage.metrics import structural_similarity as compare_ssim
from tqdm import tqdm
# Initialize LPIPS
loss_fn = lpips.LPIPS(net='vgg')
def calculate_psnr(img1, img2):
return cv2.PSNR(img1, img2)
def calculate_ssim(img1, img2):
    return compare_ssim(img1, img2, channel_axis=-1)
def calculate_lpips(img1, img2):
    # LPIPS expects RGB tensors scaled to [-1, 1]; cv2.imread returns BGR uint8, so convert first
    img1 = torch.from_numpy(img1[:, :, ::-1].copy()).float().permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0
    img2 = torch.from_numpy(img2[:, :, ::-1].copy()).float().permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0
    with torch.no_grad():
        return loss_fn(img1, img2).item()
def calculate_niqe(image):
    # Simplified NIQE-style score: the model parameters further down are placeholders
    # (zero mean, identity covariance), not the pretrained NIQE natural-image model.
    image = img_as_float(image)
h, w = image.shape[:2]
block_size = 96
strides = 32
features = []
for i in range(0, h - block_size + 1, strides):
for j in range(0, w - block_size + 1, strides):
block = image[i:i + block_size, j:j + block_size]
mu = np.mean(block)
sigma = np.std(block)
            filtered_block = gaussian_filter(block, sigma)  # smooth the block, using its own std as the kernel width
shape, _, scale = genpareto.fit(filtered_block.ravel(), floc=0)
feature = [mu, sigma, shape, scale, gammaln(1 / shape)]
features.append(feature)
features = np.array(features)
model_mean = np.zeros(features.shape[1])
model_cov_inv = np.eye(features.shape[1])
quality_scores = []
for feature in features:
score = (feature - model_mean) @ model_cov_inv @ (feature - model_mean).T
quality_scores.append(score)
return np.mean(quality_scores)
def process_image(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder):
hd_img_name = f"{i}.png"
hd_img_path = os.path.join(hd_img_folder, hd_img_name)
hd_img = cv2.imread(hd_img_path)
corresponding_generated_folder = os.path.join(generated_img_root_folder, str(i))
if not os.path.exists(corresponding_generated_folder):
print(f"Folder for {hd_img_name} does not exist. Skipping.")
return
output_file_path = os.path.join(output_root_folder, f"{i}_output.txt")
generated_img_names = os.listdir(corresponding_generated_folder)
generated_img_names.sort(key=lambda x: int(x.split('.')[0]))
total_images = len(generated_img_names)
best_psnr = 0
best_ssim = 0
best_lpips = float('inf')
best_niqe = float('inf')
best_metrics_record = {}
with open(output_file_path, 'w') as f:
f.write(f"Results for HD Image: {hd_img_name}\n")
f.write("-------------------------------------\n")
for generated_img_name in tqdm(generated_img_names, total=total_images, desc=f"Processing for {hd_img_name}",
leave=False):
generated_img_path = os.path.join(corresponding_generated_folder, generated_img_name)
generated_img = cv2.imread(generated_img_path)
psnr = calculate_psnr(hd_img, generated_img)
ssim = calculate_ssim(hd_img, generated_img)
lpips_value = calculate_lpips(hd_img, generated_img)
niqe = calculate_niqe(generated_img)
result_line = f"{generated_img_name} PSNR: {psnr} SSIM: {ssim} LPIPS: {lpips_value} NIQE: {niqe}\n"
f.write(result_line)
if psnr > best_psnr:
best_psnr = psnr
best_metrics_record['Best PSNR'] = (generated_img_name, best_psnr)
if ssim > best_ssim:
best_ssim = ssim
best_metrics_record['Best SSIM'] = (generated_img_name, best_ssim)
if lpips_value < best_lpips:
best_lpips = lpips_value
best_metrics_record['Best LPIPS'] = (generated_img_name, best_lpips)
if niqe < best_niqe:
best_niqe = niqe
best_metrics_record['Best NIQE'] = (generated_img_name, best_niqe)
    # Note: multiple worker processes append to this shared file, so blocks from different images may interleave
    with open(main_output_file_path, 'a') as main_f:
main_f.write(f"Best Metrics for {hd_img_name}\n")
main_f.write("-------------------------------------\n")
for metric, (img_name, value) in best_metrics_record.items():
main_f.write(f"{metric}: {img_name}, Value: {value}\n")
main_f.write("\n")
print(f"Best Metrics for {hd_img_name} are saved in {main_output_file_path}")
if __name__ == "__main__":
hd_img_folder = 'xxxxxxx'
generated_img_root_folder = 'xxxxxxx'
output_root_folder = 'xxxxxxx'
main_output_file_path = os.path.join(output_root_folder, "all_results.txt")
with open(main_output_file_path, 'w') as main_f:
main_f.write("Detailed Results for Each High-Definition Image\n")
main_f.write("==============================================\n")
    with Pool(4) as pool:  # Initialize a pool with 4 processes
pool.starmap(process_image,
[(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder) for i in
range(1, 570)]) # Parallel processing