基于pytorch计算ssim和ms-ssim

使用pytorch计算两组图片的ssim和ms-ssim

首先是SSIM和MS-SSIM类(ssim.py)
import torch
import torch.nn.functional as F

def _fspecial_gauss_1d(size, sigma):
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size//2
    g = torch.exp(-(coords**2) / (2*sigma**2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)
    
def gaussian_filter(input, win):
    """Filter `input` with `win`, then with `win` transposed over dims 2 and 3.

    Args:
        input: 4-d tensor; its second dimension is used as the group count
            of both convolutions.
        win: weight tensor handed to ``F.conv2d``.

    Returns:
        The output of the two stride-1, unpadded grouped convolutions.
    """
    # Unpacking enforces a 4-d input; only the channel count is needed.
    _, channels, _, _ = input.shape
    conv_opts = dict(stride=1, padding=0, groups=channels)
    first_pass = F.conv2d(input, win, **conv_opts)
    return F.conv2d(first_pass, win.transpose(2, 3), **conv_opts)


def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    """Compute SSIM and its contrast-structure term from windowed statistics.

    Args:
        X, Y: image batches to compare.
        win: filter window; moved to X's device and dtype before use.
        data_range: value range of the images, used to scale the two
            stabilizing constants.
        size_average: if True reduce to a single scalar, otherwise average
            over the last three dimensions only.
        full: if True also return the contrast-structure value.

    Returns:
        ``ssim_val``, or ``(ssim_val, cs)`` when ``full`` is True.
    """
    K1, K2 = 0.01, 0.03
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    # Local means.
    mu_x = gaussian_filter(X, win)
    mu_y = gaussian_filter(Y, win)
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)
    mu_xy = mu_x * mu_y

    # Local (co)variances: filtered second moment minus product of means.
    var_x = gaussian_filter(X * X, win) - mu_x_sq
    var_y = gaussian_filter(Y * Y, win) - mu_y_sq
    cov_xy = gaussian_filter(X * Y, win) - mu_xy

    cs_map = (2 * cov_xy + C2) / (var_x + var_y + C2)
    luminance_map = (2 * mu_xy + C1) / (mu_x_sq + mu_y_sq + C1)
    ssim_map = luminance_map * cs_map

    if size_average:
        ssim_val, cs = ssim_map.mean(), cs_map.mean()
    else:
        # Reduce the last three dimensions one at a time.
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    return (ssim_val, cs) if full else ssim_val


def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False):
    """Validate the inputs and compute SSIM between two image batches.

    Args:
        X, Y: 4-d tensors with identical type and shape.
        win_size: length of the Gaussian window built when ``win`` is None;
            must be odd.
        win_sigma: standard deviation of that Gaussian window.
        win: optional pre-built window; used as given when supplied.
        data_range: value range of the images.
        size_average: if True average the per-image values into a scalar.
        full: if True also return the contrast-structure value.

    Returns:
        ``ssim_val``, or ``(ssim_val, cs)`` when ``full`` is True.

    Raises:
        ValueError: if the inputs are not 4-d, differ in type or shape, or
            ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')
    if X.type() != Y.type():
        raise ValueError('Input images must have the same dtype.')
    if X.shape != Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if win_size % 2 != 1:
        raise ValueError('Window size must be odd.')

    if win is None:
        # One copy of the 1-d kernel per input channel.
        kernel_1d = _fspecial_gauss_1d(win_size, win_sigma)
        win = kernel_1d.repeat(X.shape[1], 1, 1, 1)

    # Always request per-image values; averaging is applied below.
    per_image_ssim, per_image_cs = _ssim(
        X, Y, win=win, data_range=data_range, size_average=False, full=True)

    if size_average:
        per_image_ssim = per_image_ssim.mean()
        per_image_cs = per_image_cs.mean()

    return (per_image_ssim, per_image_cs) if full else per_image_ssim


def ms_ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False, weights=None):
    """Compute multi-scale SSIM (MS-SSIM) between two image batches.

    The images are compared at ``len(weights)`` scales, each scale being a
    2x average-pooled version of the previous one. The result is the product
    of the contrast-structure terms of all but the last scale and the SSIM
    value of the last scale, each raised to its weight.

    Args:
        X, Y: 4-d tensors with identical type and shape.
        win_size: length of the Gaussian window built when ``win`` is None;
            must be odd.
        win_sigma: standard deviation of that Gaussian window.
        win: optional pre-built window; used as given when supplied.
        data_range: value range of the images.
        size_average: if True average the per-image values into a scalar.
        full: accepted for signature parity with ``ssim``; not used here.
        weights: per-scale exponents as a tensor or a sequence of floats;
            None selects the five default weights below.

    Returns:
        A scalar tensor if ``size_average`` is True, otherwise one value per
        image in the batch.

    Raises:
        ValueError: if the inputs are not 4-d, differ in type or shape,
            ``win_size`` is even, or ``weights`` is empty.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    if not isinstance(weights, torch.Tensor):
        # Accept plain lists/tuples, e.g. when forwarded from MS_SSIM(weights=[...]).
        weights = torch.FloatTensor(weights)
    weights = weights.to(X.device, dtype=X.dtype)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    levels = weights.shape[0]
    if levels < 1:
        raise ValueError('weights must contain at least one entry.')

    mcs = []
    for level in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        mcs.append(cs)

        # Downsample only when another scale follows; the pooled result of
        # the last scale would never be read.
        if level < levels - 1:
            padding = (X.shape[2] % 2, X.shape[3] % 2)
            X = F.avg_pool2d(X, kernel_size=2, padding=padding)
            Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    mcs = torch.stack(mcs, dim=0)  # (level, batch)
    # The last-scale SSIM term must be applied exactly once, so it is
    # multiplied in after the product over the finer scales rather than
    # inside it (where it would be raised to the power levels - 1).
    cs_product = torch.prod(mcs[:-1] ** weights[:-1].unsqueeze(1), dim=0)  # (batch,)
    msssim_val = cs_product * (ssim_val ** weights[-1])

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val


# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim`` that builds its window once.

    Args:
        win_size: length of the Gaussian window.
        win_sigma: standard deviation of the Gaussian window.
        data_range: value range of the images passed to ``forward``.
        size_average: if True ``forward`` returns a scalar average.
        channel: number of image channels the window is replicated for.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3):
        super().__init__()
        kernel_1d = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the 1-d kernel per channel.
        self.win = kernel_1d.repeat(channel, 1, 1, 1)
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, X, Y):
        """Return the SSIM of X and Y using the stored window and settings."""
        return ssim(
            X,
            Y,
            win=self.win,
            data_range=self.data_range,
            size_average=self.size_average,
        )


class MS_SSIM(torch.nn.Module):
    """Module wrapper around ``ms_ssim`` that builds its window once.

    Args:
        win_size: length of the Gaussian window.
        win_sigma: standard deviation of the Gaussian window.
        data_range: value range of the images passed to ``forward``.
        size_average: if True ``forward`` returns a scalar average.
        channel: number of image channels the window is replicated for.
        weights: per-scale weights forwarded to ``ms_ssim``; None lets that
            function pick its defaults.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3, weights=None):
        super().__init__()
        kernel_1d = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the 1-d kernel per channel.
        self.win = kernel_1d.repeat(channel, 1, 1, 1)
        self.weights = weights
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, X, Y):
        """Return the MS-SSIM of X and Y using the stored window and settings."""
        return ms_ssim(
            X,
            Y,
            win=self.win,
            size_average=self.size_average,
            data_range=self.data_range,
            weights=self.weights,
        )

上面的工具类我在pytorch中当做损失函数使用

使用

这里还用到几个方法,我在下面给出

import argparse
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from util import (map_range, cv2torch, random_tone_map,
                  DirectoryDataset, str2bool)
#这里我把上面的ssim.py放到了一个文件夹中,所以需要这样导入
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``batch_size``, ``data_root_path_label``,
        ``data_root_path_pre`` and ``num_workers``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--batch_size', type=int, default=1, help='Batch size.')
    cli.add_argument(
        '-d', '--data_root_path_label',
        default='D:/project_hdr/hdr-expandnet/test_data_hdr',
        help='Path to hdr data.',
    )
    cli.add_argument(
        '-v', '--data_root_path_pre',
        default='D:/project_hdr/myGan_same_size_wights_noexpand _minD/test_data_ldr/981',
        help='Path to hdr data.',
    )
    cli.add_argument(
        '--num_workers', type=int, default=1,
        help='Number of data loading workers.',
    )
    return cli.parse_args()



def transformh(hdr):
    """Preprocess one loaded image: apply ``map_range``, then ``cv2torch``."""
    return cv2torch(map_range(hdr))

def train(opt):
    """Print SSIM and MS-SSIM for batches paired from two image directories.

    Args:
        opt: options object providing ``data_root_path_label``,
            ``data_root_path_pre``, ``batch_size`` and ``num_workers``.
    """

    def make_loader(root):
        # Every directory is read the same way: transformh as preprocessing,
        # batch size and worker count taken from the options.
        dataset = DirectoryDataset(data_root_path=root, preprocess=transformh)
        return DataLoader(
            dataset,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
        )

    # Reference images (label directory).
    label_loader = make_loader(opt.data_root_path_label)
    # Images to be scored against the reference (prediction directory).
    pre_loader = make_loader(opt.data_root_path_pre)

    # Batches are paired purely by iteration order of the two loaders.
    for pre, real_B in zip(pre_loader, label_loader):
        if torch.cuda.is_available():
            pre = pre.cuda()
            real_B = real_B.cuda()

        # size_average=True: each call yields one averaged value per batch.
        ssim_val = ssim(real_B, pre, data_range=1, size_average=True)
        ms_ssim_val = ms_ssim(real_B, pre, data_range=1, size_average=True)

        tqdm.write(f'ssim_val: {ssim_val},'
                   f'ms_ssim_val: {ms_ssim_val},')


# Script entry point: parse the command-line options and run the evaluation.
if __name__ == '__main__':
    train(parse_args())

下面是使用到的几个方法

def map_range(x, low=0, high=1):
    """Linearly rescale array ``x`` so its minimum maps to ``low`` and its maximum to ``high``.

    Args:
        x: NumPy array to rescale.
        low: output value assigned to ``x.min()``.
        high: output value assigned to ``x.max()``.

    Returns:
        Array with the same dtype as ``x``.
    """
    return np.interp(x, [x.min(), x.max()], [low, high]).astype(x.dtype)
def cv2torch(np_img):
    """Convert an image array into a torch tensor with channels first.

    The last axis is reordered with indices (2, 1, 0), then the axes are
    rearranged from (0, 1, 2) to (2, 0, 1).

    Args:
        np_img: 3-d NumPy array whose last axis has at least three entries.

    Returns:
        Tensor created with ``torch.from_numpy`` from the rearranged array.
    """
    reordered = np_img[:, :, [2, 1, 0]]
    channels_first = reordered.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)

下面这个类需要opencv

class DirectoryDataset(Dataset):
    """Dataset over all image files found recursively under a directory.

    Files are collected once at construction time, in sorted directory and
    file-name order, so a given index always refers to the same file. This
    matters when two datasets are iterated in lockstep and paired by index.

    Args:
        data_root_path: directory that is walked recursively.
        data_extensions: file-name suffixes to keep; compared against the
            lower-cased file name.
        load_fn: accepted for interface compatibility; not used.
        preprocess: optional callable applied to each loaded image.

    Raises:
        RuntimeError: if no file with a matching extension is found.
    """

    def __init__(self,
                 data_root_path='hdr_data',
                 data_extensions=('.hdr', '.exr'),
                 load_fn=None,
                 preprocess=None):
        super(DirectoryDataset, self).__init__()
        data_root_path = process_path(data_root_path)
        extensions = tuple(data_extensions)
        self.file_list = []
        for root, _, fnames in sorted(os.walk(data_root_path)):
            # os.walk does not order the names within a directory; sort them
            # so the index -> file mapping is deterministic.
            for fname in sorted(fnames):
                if fname.lower().endswith(extensions):
                    self.file_list.append(os.path.join(root, fname))
        if not self.file_list:
            msg = 'Could not find any files with extensions:\n[{0}]\nin\n{1}'
            raise RuntimeError(
                msg.format(', '.join(extensions), data_root_path))

        self.preprocess = preprocess

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``preprocess`` if one was given.

        Raises:
            RuntimeError: if OpenCV could not read the file.
        """
        path = self.file_list[index]
        dpoint = cv2.imread(
            path,
            flags=cv2.IMREAD_ANYDEPTH + cv2.IMREAD_COLOR)
        if dpoint is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise RuntimeError('Could not read image file: {0}'.format(path))
        if self.preprocess is not None:
            dpoint = self.preprocess(dpoint)
        return dpoint

    def __len__(self):
        """Return the number of files collected at construction time."""
        return len(self.file_list)

def process_path(directory, create=False):
    """Return ``directory`` as an expanded, normalized, absolute path.

    Args:
        directory: path that may contain ``~`` or relative components.
        create: if True, try to create the directory (and its parents).
            Creation is best-effort: OS-level failures, including the
            directory already existing, are ignored.

    Returns:
        The absolute, normalized path string.
    """
    directory = os.path.expanduser(directory)
    directory = os.path.normpath(directory)
    directory = os.path.abspath(directory)
    if create:
        try:
            os.makedirs(directory)
        except OSError:
            # Best-effort only; unlike a bare except this still lets
            # KeyboardInterrupt and SystemExit propagate.
            pass
    return directory

大概就是这些,msssim代码来源于github : https://github.com/VainF/pytorch-msssim
当然,计算ssim还有一些其他更加简便的方法,比如

# Newer scikit-image releases expose SSIM as skimage.metrics.structural_similarity
# and no longer ship skimage.measure.compare_ssim; fall back to the old name
# only on installations that lack the new module.
try:
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:
    from skimage.measure import compare_ssim

# X and Y are the two images to compare; with full=True the call returns the
# score together with a second array, which is then scaled by 255.
(score, diff) = compare_ssim(X, Y, full=True)
diff = (diff * 255).astype("float32")

但是我没有找到同样简便的、可以直接用Python计算MS-SSIM的方法,所以使用了上面这个单独的实现来计算。
基于pytorch计算ssim和ms-ssim_第1张图片

你可能感兴趣的:(计算机视觉,pytorch,ssim,ms-ssim)