import torch
import torch.nn.functional as F
def _fspecial_gauss_1d(size, sigma):
    """Build a normalized 1-D Gaussian kernel.

    Args:
        size: number of taps in the kernel.
        sigma: standard deviation of the Gaussian, in taps.

    Returns:
        A float tensor of shape (1, 1, size) whose entries sum to 1.
    """
    # Tap positions measured relative to the centre tap (index size // 2).
    offsets = torch.arange(size).to(dtype=torch.float) - size // 2
    weights = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel = weights / weights.sum()
    # Add two leading singleton dimensions.
    return kernel[None, None, :]
def gaussian_filter(input, win):
    """Blur a batch of images with a separable 1-D kernel.

    The kernel is applied as a per-channel (grouped) convolution twice:
    once with ``win`` as given and once with its last two dimensions
    swapped. No padding is used, so the output is smaller than the input.

    Args:
        input: 4-D tensor, unpacked as (N, C, H, W).
        win: convolution weight handed to ``F.conv2d`` with ``groups=C``;
            the callers in this file pass a (C, 1, 1, size) tensor.

    Returns:
        The filtered tensor.
    """
    # The unpack doubles as a check that the input has exactly four dims.
    _, channels, _, _ = input.shape
    depthwise = dict(stride=1, padding=0, groups=channels)
    blurred = F.conv2d(input, win, **depthwise)
    return F.conv2d(blurred, win.transpose(2, 3), **depthwise)
def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    """Compute SSIM between two image batches using a given window.

    Args:
        X: first batch of images, a 4-D tensor.
        Y: second batch of images, same shape as ``X``.
        win: blur window passed to ``gaussian_filter``; it is moved to
            ``X``'s device and dtype first.
        data_range: value range of the images; scales the two
            stabilising constants C1 and C2.
        size_average: if True reduce the maps to a single scalar,
            otherwise reduce over the last three dimensions only, leaving
            one value per batch entry.
        full: if True also return the contrast-structure term.

    Returns:
        ``ssim_val`` alone, or ``(ssim_val, cs)`` when ``full`` is True.
    """
    K1 = 0.01
    K2 = 0.03
    # Unpacking raises unless X has exactly four dimensions.
    _batch, _channel, _height, _width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    kernel = win.to(X.device, dtype=X.dtype)

    def blur(tensor):
        return gaussian_filter(tensor, kernel)

    # Local means and their products.
    mean_x = blur(X)
    mean_y = blur(Y)
    mean_x_sq = mean_x.pow(2)
    mean_y_sq = mean_y.pow(2)
    mean_xy = mean_x * mean_y

    # Local variances and covariance: E[ab] - E[a]E[b].
    var_x = compensation * (blur(X * X) - mean_x_sq)
    var_y = compensation * (blur(Y * Y) - mean_y_sq)
    cov_xy = compensation * (blur(X * Y) - mean_xy)

    cs_map = (2 * cov_xy + C2) / (var_x + var_y + C2)
    luminance_map = (2 * mean_xy + C1) / (mean_x_sq + mean_y_sq + C1)
    ssim_map = luminance_map * cs_map

    if size_average:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()
    else:
        # Collapse W, then H, then C, keeping the batch dimension.
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    return (ssim_val, cs) if full else ssim_val
def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False):
    """Compute the structural similarity (SSIM) between two image batches.

    Args:
        X: first batch of images, a 4-D tensor.
        Y: second batch of images; must match ``X`` in type and shape.
        win_size: number of taps of the Gaussian window; must be odd.
            Only used to build the window when ``win`` is None.
        win_sigma: standard deviation of the Gaussian window. Only used
            when ``win`` is None.
        win: optional pre-built window. When omitted, a Gaussian window is
            created and repeated once per channel of ``X``.
        data_range: value range of the input images.
        size_average: if True return the mean over the batch, otherwise
            one value per batch entry.
        full: if True also return the contrast-structure term.

    Returns:
        ``ssim_val`` alone, or ``(ssim_val, cs)`` when ``full`` is True.

    Raises:
        ValueError: if ``X`` is not 4-D, if the inputs differ in type or
            shape, or if ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    # Always request per-sample values plus cs; averaging happens below so
    # that both outputs are reduced the same way.
    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    return ssim_val
def ms_ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False, weights=None):
    """Compute the multi-scale structural similarity (MS-SSIM).

    SSIM is evaluated at several scales, halving the resolution with
    average pooling between scales. The result is the product of the
    contrast-structure terms of all but the last scale and the SSIM of the
    last scale, each raised to its weight.

    Args:
        X: first batch of images, a 4-D tensor.
        Y: second batch of images; must match ``X`` in type and shape.
        win_size: number of taps of the Gaussian window; must be odd.
            Only used to build the window when ``win`` is None.
        win_sigma: standard deviation of the Gaussian window. Only used
            when ``win`` is None.
        win: optional pre-built window. When omitted, a Gaussian window is
            created and repeated once per channel of ``X``.
        data_range: value range of the input images.
        size_average: if True return the mean over the batch, otherwise
            one value per batch entry.
        full: accepted but not used by this function.
        weights: 1-D tensor with one exponent per scale; its length sets
            the number of scales. Defaults to five scales.

    Returns:
        The MS-SSIM value: a scalar when ``size_average`` is True,
        otherwise a tensor with one entry per batch element.

    Raises:
        ValueError: if ``X`` is not 4-D, if the inputs differ in type or
            shape, or if ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = torch.FloatTensor(
            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(X.device, dtype=X.dtype)
    else:
        # Caller-supplied weights must sit on the same device as the data.
        weights = weights.to(X.device)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    levels = weights.shape[0]
    mcs = []
    for level in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        mcs.append(cs)

        # Downsample for the next scale; nothing consumes the images after
        # the last scale, so skip the pooling there.
        if level < levels - 1:
            padding = (X.shape[2] % 2, X.shape[3] % 2)
            X = F.avg_pool2d(X, kernel_size=2, padding=padding)
            Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    mcs = torch.stack(mcs, dim=0)  # (level, batch)

    # A negative base raised to a fractional exponent is NaN, so clamp the
    # per-scale terms at zero before exponentiation.
    mcs = F.relu(mcs)
    ssim_val = F.relu(ssim_val)

    # The last-scale SSIM term is applied exactly once, outside the product
    # over the coarser-to-finer cs terms. Multiplying it in before the
    # product would raise it to (levels - 1) * weights[-1] instead.
    cs_product = torch.prod(mcs[:-1] ** weights[:-1].unsqueeze(1), dim=0)
    msssim_val = cs_product * (ssim_val ** weights[-1])  # (batch,)

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val
# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim`` that builds its window once.

    Args:
        win_size: number of taps of the Gaussian window.
        win_sigma: standard deviation of the Gaussian window.
        data_range: value range of the input images.
        size_average: forwarded to ``ssim`` on every call.
        channel: number of image channels the window is repeated for.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3):
        super().__init__()
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the 1-D kernel per channel.
        self.win = base_window.repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, X, Y):
        """Return ``ssim(X, Y)`` computed with the stored window and settings."""
        return ssim(
            X,
            Y,
            win=self.win,
            data_range=self.data_range,
            size_average=self.size_average,
        )
class MS_SSIM(torch.nn.Module):
    """Module wrapper around ``ms_ssim`` that builds its window once.

    Args:
        win_size: number of taps of the Gaussian window.
        win_sigma: standard deviation of the Gaussian window.
        data_range: value range of the input images.
        size_average: forwarded to ``ms_ssim`` on every call.
        channel: number of image channels the window is repeated for.
        weights: per-scale weights forwarded to ``ms_ssim``; None selects
            that function's defaults.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3, weights=None):
        super().__init__()
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the 1-D kernel per channel.
        self.win = base_window.repeat(channel, 1, 1, 1)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` computed with the stored window and settings."""
        return ms_ssim(
            X,
            Y,
            win=self.win,
            size_average=self.size_average,
            data_range=self.data_range,
            weights=self.weights,
        )
上面的工具类我在 PyTorch 中当作损失函数使用。
这里还用到了几个方法,我在下面给出。
import argparse
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from util import (map_range, cv2torch, random_tone_map,
DirectoryDataset, str2bool)
#这里我把上面的ssim.py放到了一个文件夹中,所以需要这样导入
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse.Namespace`` with ``batch_size``,
        ``data_root_path_label``, ``data_root_path_pre`` and
        ``num_workers``.
    """
    cli = argparse.ArgumentParser()
    # (flags, keyword settings) for each option, in registration order.
    option_specs = (
        (('--batch_size',),
         dict(type=int, default=1, help='Batch size.')),
        (('-d', '--data_root_path_label'),
         dict(default='D:/project_hdr/hdr-expandnet/test_data_hdr',
              help='Path to hdr data.')),
        (('-v', '--data_root_path_pre'),
         dict(default='D:/project_hdr/myGan_same_size_wights_noexpand _minD/test_data_ldr/981',
              help='Path to hdr data.')),
        (('--num_workers',),
         dict(type=int, default=1, help='Number of data loading workers.')),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def transformh(hdr):
    """Preprocess one loaded image: rescale with ``map_range``, then convert with ``cv2torch``."""
    return cv2torch(map_range(hdr))
def train(opt):
    """Print SSIM and MS-SSIM for each pair of batches from two image folders.

    Despite its name this function performs no training: it loads the
    images under ``opt.data_root_path_label`` and
    ``opt.data_root_path_pre``, walks both loaders in lockstep and writes
    the two scores of every batch pair through ``tqdm.write``.

    Args:
        opt: parsed options providing ``data_root_path_label``,
            ``data_root_path_pre``, ``batch_size`` and ``num_workers``.
    """

    def build_loader(root):
        images = DirectoryDataset(data_root_path=root, preprocess=transformh)
        return DataLoader(
            images,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
        )

    # Reference images.
    loader1 = build_loader(opt.data_root_path_label)
    # Images to be scored against the references.
    loader2 = build_loader(opt.data_root_path_pre)

    # NOTE(review): pairing relies on both loaders yielding files in the
    # same order.
    for pre, real_B in zip(loader2, loader1):
        if torch.cuda.is_available():
            pre = pre.cuda()
            real_B = real_B.cuda()

        # size_average=True, so both scores are averaged over the batch.
        ssim_val = ssim(real_B, pre, data_range=1, size_average=True)
        ms_ssim_val = ms_ssim(real_B, pre, data_range=1, size_average=True)

        tqdm.write(f'ssim_val: {ssim_val},'
                   f'ms_ssim_val: {ms_ssim_val},')
# Script entry point: parse the command line and run the comparison.
if __name__ == '__main__':
    opt = parse_args()
    train(opt)
下面是用到的几个方法:
def map_range(x, low=0, high=1):
    """Linearly rescale an array so its minimum maps to ``low`` and its maximum to ``high``.

    Args:
        x: NumPy array to rescale.
        low: value the minimum of ``x`` is mapped to.
        high: value the maximum of ``x`` is mapped to.

    Returns:
        The rescaled array, cast back to the dtype of ``x``.
    """
    source_span = [x.min(), x.max()]
    target_span = [low, high]
    rescaled = np.interp(x, source_span, target_span)
    return rescaled.astype(x.dtype)
var foo = 'bar';
def cv2torch(np_img):
    """Convert an image array to a channel-first torch tensor.

    The last axis is reordered with indices (2, 1, 0) and the array is
    then rearranged from (H, W, C) to (C, H, W).

    Args:
        np_img: NumPy image array with the channel as its third axis.

    Returns:
        A tensor created with ``torch.from_numpy`` from the rearranged array.
    """
    # Reverse the channel order (e.g. BGR -> RGB).
    reordered = np_img[:, :, [2, 1, 0]]
    # (H, W, C) -> (H, C, W) -> (C, H, W)
    height_channel_width = reordered.swapaxes(1, 2)
    channel_first = height_channel_width.swapaxes(0, 1)
    return torch.from_numpy(channel_first)
下面这个类需要 OpenCV(cv2):
class DirectoryDataset(Dataset):
    """Dataset of all image files with a given extension under a directory tree.

    Files are collected once at construction time and read with OpenCV on
    access.

    Args:
        data_root_path: directory that is walked recursively.
        data_extensions: file-name suffixes to accept; matched against the
            lower-cased file name.
        load_fn: accepted but not used by this class.
        preprocess: optional callable applied to every loaded image.

    Raises:
        RuntimeError: if no matching file is found, or (on access) if a
            file cannot be read.
    """

    def __init__(self,
                 data_root_path='hdr_data',
                 data_extensions=('.hdr', '.exr'),
                 load_fn=None,
                 preprocess=None):
        super(DirectoryDataset, self).__init__()

        data_root_path = process_path(data_root_path)
        suffixes = tuple(data_extensions)
        self.file_list = []
        for root, _, fnames in sorted(os.walk(data_root_path)):
            # os.walk lists files in arbitrary order; sort them so the
            # index -> file mapping is reproducible across runs and
            # machines (needed when two datasets are paired by position).
            for fname in sorted(fnames):
                if fname.lower().endswith(suffixes):
                    self.file_list.append(os.path.join(root, fname))
        if not self.file_list:
            msg = 'Could not find any files with extensions:\n[{0}]\nin\n{1}'
            raise RuntimeError(
                msg.format(', '.join(data_extensions), data_root_path))

        self.preprocess = preprocess

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``preprocess`` if one was given."""
        path = self.file_list[index]
        dpoint = cv2.imread(
            path,
            flags=cv2.IMREAD_ANYDEPTH + cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if dpoint is None:
            raise RuntimeError('Could not read image file: {0}'.format(path))
        if self.preprocess is not None:
            dpoint = self.preprocess(dpoint)
        return dpoint

    def __len__(self):
        """Return the number of files found."""
        return len(self.file_list)
def process_path(directory, create=False):
    """Return ``directory`` as a normalized absolute path, optionally creating it.

    Args:
        directory: path to process; a leading ``~`` is expanded.
        create: if True, try to create the directory (and its parents).

    Returns:
        The expanded, normalized, absolute path as a string.
    """
    directory = os.path.expanduser(directory)
    directory = os.path.normpath(directory)
    directory = os.path.abspath(directory)
    if create:
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError:
            # Creation stays best-effort: file-system errors are ignored,
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            pass
    return directory
大概就是这些。MS-SSIM 的代码来源于 GitHub:https://github.com/VainF/pytorch-msssim
当然,计算 SSIM 还有一些其他更加简便的方法,比如:
# ``compare_ssim`` was deprecated in scikit-image 0.16 and removed from
# ``skimage.measure`` in 0.18; the same function is available as
# ``skimage.metrics.structural_similarity``. Fall back to the old import
# only on releases that predate the new module.
try:
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:
    from skimage.measure import compare_ssim

# With full=True the call returns the mean score and the per-pixel SSIM image.
(score, diff) = compare_ssim(X, Y, full=True)
diff = (diff * 255).astype("float32")