结构相似形特征(SSIM)原理简介及python实现

结构相似形特征是图像全参考评价(FR-IQA)中经典的一个方法,由Zhou Wang等人在2004年发表的论文《Image Quality Assessment: From Error Visibility to Structural Similarity》中提出。作为一种全参考方法,需要同时利用原始图像(未失真)和失真图像。

  • 原理简介

SSIM方法从三个方面综合评价图像质量:亮度相似性、对比度相似性、结构相似形。如下图所示:
结构相似形特征(SSIM)原理简介及python实现_第1张图片

1.亮度相似性的定义如下:

在这里插入图片描述
其中:
在这里插入图片描述

C1是一个防止分母为0的常数,作者定义C1如下:
在这里插入图片描述
K1是一个常数,L是图像的灰度范围,比如255.
2. 对比度相似性定义如下:
在这里插入图片描述
其中:
在这里插入图片描述
C2的作用与C1相同,定义为C2=(K2 x L)^2
3.结构相似性定义如下:
在这里插入图片描述
其中:
在这里插入图片描述
而C3的作用与C1、C2相同,定义为C3=(K3 x L)^2

最后,SSIM的公式如下:
在这里插入图片描述
α、β、γ是调整三种相似性重要程度的系数。特别地,当α=β=γ=1时,且C3=C2/2时,SSIM可化为以下形式:
在这里插入图片描述
这也是常用的SSIM形式。需要指出,SSIM计算实在每个图像局部块上,最后需要pooling来得出整幅图像分数,论文采用的是简单的average pooling方式,即求平均值。

  • Python实现

作者实际上采用σ=1.5的高斯加权函数计算每个图像块(图像块大小11x11)的均值和标准差,目的是防止undesirable “blocking” artifacts:
结构相似形特征(SSIM)原理简介及python实现_第2张图片
生成高斯核的函数参考了RaymondMG的博客
由于使用的高斯函数圆对称,因此相关操作和卷积操作结果相同,这里只定义了相关操作,卷积在相关操作前将核旋转180°即可。

import cv2
import numpy as np
import time
from numba import jit,njit

#相关操作
#由于使用的高斯函数圆对称,因此相关操作和卷积操作结果相同
@njit
def correlation(img,kernal):
    kernal_heigh = kernal.shape[0]
    kernal_width = kernal.shape[1]
    cor_heigh = img.shape[0] - kernal_heigh + 1
    cor_width = img.shape[1] - kernal_width + 1
    result = np.zeros((cor_heigh, cor_width), dtype=np.float64)
    for i in range(cor_heigh):
        for j in range(cor_width):
            result[i][j] = (img[i:i + kernal_heigh, j:j + kernal_width] * kernal).sum()
    return result

#产生二维高斯核函数
#这个函数参考自:https://blog.csdn.net/qq_16013649/article/details/78784791
@jit
def gaussian_2d_kernel(kernel_size=11, sigma=1.5):
    kernel = np.zeros([kernel_size, kernel_size])
    center = kernel_size // 2

    if sigma == 0:
        sigma = ((kernel_size - 1) * 0.5 - 1) * 0.3 + 0.8

    s = 2 * (sigma ** 2)
    sum_val = 0
    for i in range(0, kernel_size):
        for j in range(0, kernel_size):
            x = i - center
            y = j - center
            kernel[i, j] = np.exp(-(x ** 2 + y ** 2) / s)
            sum_val += kernel[i, j]
    sum_val = 1 / sum_val
    return kernel * sum_val


#ssim模型
@jit
def ssim(distorted_image,original_image,window_size=11,gaussian_sigma=1.5,K1=0.01,K2=0.03,alfa=1,beta=1,gama=1):
    distorted_image=np.array(distorted_image,dtype=np.float64)
    original_image=np.array(original_image,dtype=np.float64)
    if not distorted_image.shape == original_image.shape:
        raise ValueError("Input Imagees must has the same size")
    if len(distorted_image.shape) > 2:
        raise ValueError("Please input the images with 1 channel")
    kernal=gaussian_2d_kernel(window_size,gaussian_sigma)

    #求ux uy ux*uy ux^2 uy^2 sigma_x^2 sigma_y^2 sigma_xy等中间变量
    ux=correlation(distorted_image,kernal)
    uy=correlation(original_image,kernal)
    distorted_image_sqr=distorted_image**2
    original_image_sqr=original_image**2
    dis_mult_ori=distorted_image*original_image
    uxx=correlation(distorted_image_sqr,kernal)
    uyy=correlation(original_image_sqr,kernal)
    uxy=correlation(dis_mult_ori,kernal)
    ux_sqr=ux**2
    uy_sqr=uy**2
    uxuy=ux*uy
    sx_sqr=uxx-ux_sqr
    sy_sqr=uyy-uy_sqr
    sxy=uxy-uxuy
    C1=(K1*255)**2
    C2=(K2*255)**2
    #常用情况的SSIM
    if(alfa==1 and beta==1 and gama==1):
        ssim=(2*uxuy+C1)*(2*sxy+C2)/(ux_sqr+uy_sqr+C1)/(sx_sqr+sy_sqr+C2)
        return np.mean(ssim)
    #计算亮度相似性
    l=(2*uxuy+C1)/(ux_sqr+uy_sqr+C1)
    l=l**alfa
    #计算对比度相似性
    sxsy=np.sqrt(sx_sqr)*np.sqrt(sy_sqr)
    c=(2*sxsy+C2)/(sx_sqr+sy_sqr+C2)
    c=c**beta
    #计算结构相似性
    C3=0.5*C2
    s=(sxy+C3)/(sxsy+C3)
    s=s**gama
    ssim=l*c*s
    return np.mean(ssim)

另外,代码中的@njit和@jit可以注释掉,我在这里是为了加速代码运行(实际加速效果不大)。
此外,Tensorflow也有计算SSIM的函数:

# Read images from file.
im1 = tf.decode_png('path/to/im1.png')
im2 = tf.decode_png('path/to/im2.png')
# Compute SSIM over tf.uint8 Tensors.
ssim1 = tf.image.ssim(im1, im2, max_val=255)

# Compute SSIM over tf.float32 Tensors.
im1 = tf.image.convert_image_dtype(im1, tf.float32)
im2 = tf.image.convert_image_dtype(im2, tf.float32)
ssim2 = tf.image.ssim(im1, im2, max_val=1.0)
# ssim1 and ssim2 both have type tf.float32 and are almost equal.

你可能感兴趣的:(Image/Video,Quality,Assessment)