个人看过觉得比较合适的代码部分记录于此,以后一些部分的代码抄就完事了。随缘更新
Code:https://github.com/swz30/MPRNet
2022/6/5:图像增强方面的论文,输入数据都是图像格式。代码简洁明了,针对其他任务则添加关于模型路径的参数,修改模型读取处的代码;有需要则添加计算运行时间函数,输出运行平均时间。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import os
from runpy import run_path
from skimage import img_as_ubyte
from collections import OrderedDict
from natsort import natsorted
from glob import glob
import cv2
import argparse
# Command-line options: input directory, output directory, and the task to run.
parser = argparse.ArgumentParser(description='Demo MPRNet')
parser.add_argument('--input_dir', default='./samples/input/', type=str, help='Input images')
parser.add_argument('--result_dir', default='./samples/output/', type=str, help='Directory for results')
# --task is mandatory and restricted to the three listed choices.
parser.add_argument('--task', required=True, type=str, help='Task to run', choices=['Deblurring', 'Denoising', 'Deraining'])
args = parser.parse_args()
def save_img(filepath, img):
    """Write an RGB image array to ``filepath`` with OpenCV.

    Args:
        filepath: Destination path of the image file.
        img: Image array in RGB channel order; it is converted to BGR
            before being handed to ``cv2.imwrite``.

    Raises:
        IOError: If ``cv2.imwrite`` reports that nothing was written.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports most failures through its boolean return value
    # instead of raising, so an unchecked call can fail silently.
    if not cv2.imwrite(filepath, bgr):
        raise IOError(f"Could not write image to {filepath}")
def load_checkpoint(model, weights):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    The state dict is first loaded as-is. If that fails with a
    ``RuntimeError``, a leading ``module.`` prefix is removed from every key
    that has one and the load is retried.

    Args:
        model: Module whose parameters are overwritten in place.
        weights: Path to a checkpoint containing a ``"state_dict"`` entry.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the keys still do not match after prefix removal.
    """
    # Map to CPU so checkpoints saved from a GPU also open on CPU-only hosts;
    # load_state_dict copies the values onto whatever device the model uses.
    checkpoint = torch.load(weights, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        # Only strip the prefix where it is actually present, so keys without
        # it are left intact instead of losing their first seven characters.
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped)
task = args.task
inp_dir = args.input_dir
out_dir = args.result_dir

# exist_ok=True: an already existing output directory is not an error.
os.makedirs(out_dir, exist_ok=True)

# Collect JPEG/PNG inputs in natural sort order. The set removes duplicates
# that appear when the filesystem matches glob patterns case-insensitively
# (e.g. '*.jpg' and '*.JPG' returning the same file).
patterns = ('*.jpg', '*.JPG', '*.png', '*.PNG')
files = natsorted({
    path
    for pattern in patterns
    for path in glob(os.path.join(inp_dir, pattern))
})

# Abort early when there is nothing to process.
if not files:
    raise Exception(f"No files found at {inp_dir}")

# Load corresponding model architecture and weights.
# run_path executes <task>/MPRNet.py and returns its globals, from which the
# MPRNet class is taken and instantiated.
load_file = run_path(os.path.join(task, "MPRNet.py"))
model = load_file['MPRNet']()
model.cuda()

# Read the task-specific pretrained weights and switch to inference mode.
weights = os.path.join(task, "pretrained_models", "model_"+task.lower()+".pth")
load_checkpoint(model, weights)
model.eval()

# Inputs are padded so height and width are multiples of this value
# (original author's note: required by the U-Net style architecture).
img_multiple_of = 8

for file_ in files:
    img = Image.open(file_).convert('RGB')
    input_ = TF.to_tensor(img).unsqueeze(0).cuda()

    # Pad the bottom/right edges up to the next multiple of 8 (no padding
    # when the size already is a multiple).
    h, w = input_.shape[2], input_.shape[3]
    padh = -h % img_multiple_of
    padw = -w % img_multiple_of
    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

    with torch.no_grad():
        restored = model(input_)
    # Original author's note: the model returns one output per stage and
    # index 0 is the final result.
    restored = restored[0]
    restored = torch.clamp(restored, 0, 1)

    # Unpad the output back to the original spatial size.
    restored = restored[:, :, :h, :w]

    # Channels-last layout, then 8-bit unsigned integers for saving.
    restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
    restored = img_as_ubyte(restored[0])

    # "A/B/c.jpg" --> "c.jpg" --> "c"
    f = os.path.splitext(os.path.split(file_)[-1])[0]
    save_img((os.path.join(out_dir, f+'.png')), restored)

print(f"Files saved at {out_dir}")
Code:https://github.com/zhilin007/FFA-Net
2022/6/6:FFA-Net是一篇图像去雾方向的论文,这篇论文我曾经复现并进行过改进。它虽然在RESIDE数据集上有着非常好的指标,但在真实世界场景中几乎不起作用,或者说只对近景的薄雾有一点点作用。代码的命名略显粗糙,tensorShow值得Copy一下。
import os,argparse
import numpy as np
from PIL import Image
from models import *
import torch
import torch.nn as nn
import torchvision.transforms as tfs
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# Absolute path of the current working directory, with a trailing slash.
# NOTE(review): this name shadows the builtin abs() for the rest of the module.
abs=os.getcwd()+'/'
def tensorShow(tensors, titles=('haze',)):
    """Show image tensors as titled subplots of a single matplotlib figure.

    Args:
        tensors: Sequence of image tensors; each one is passed through
            ``make_grid`` and drawn in its own subplot.
        titles: Subplot titles, paired with ``tensors`` by position. Pairing
            stops at the shorter of the two sequences.
    """
    pairs = list(zip(tensors, titles))
    # Two columns and at least two rows (the original 2x2 layout); further
    # rows are added when more than four images are given.
    cols = 2
    rows = max(2, -(-len(pairs) // cols))
    fig = plt.figure()
    for position, (tensor, title) in enumerate(pairs, start=1):
        grid = make_grid(tensor)
        # Detach and move to CPU so GPU or grad-tracking tensors convert too.
        npimg = grid.detach().cpu().numpy()
        ax = fig.add_subplot(rows, cols, position)
        # Move the first axis last, which is the layout imshow expects.
        ax.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.set_title(title)
    plt.show()
# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='its', help='its or ots')
parser.add_argument('--test_imgs', type=str, default='test_imgs', help='Test imgs folder')
opt = parser.parse_args()
dataset = opt.task

# Model hyper-parameters; they also appear in the checkpoint file name.
gps = 3
blocks = 19

# Test image directory and prediction directory, both under the working dir.
img_dir = abs+opt.test_imgs+'/'
output_dir = abs+f'pred_FFA_{dataset}/'
print("pred_dir:", output_dir)

# Create the output directory; an existing one is fine.
os.makedirs(output_dir, exist_ok=True)

# Path of the trained weights.
model_dir = abs+f'trained_models/{dataset}_train_ffa_{gps}_{blocks}.pk'

# Build the model, load the weights, move it to the device, switch to eval.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckp = torch.load(model_dir, map_location=device)
net = FFA(gps=gps, blocks=blocks)
net = nn.DataParallel(net)
net.load_state_dict(ckp['model'])
# The weights are mapped to `device`, so the model must live there as well.
net = net.to(device)
net.eval()

# Build the transforms once instead of once per image.
to_tensor = tfs.ToTensor()
normalize = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
])

# Run every file found in the test image directory through the network.
for im in os.listdir(img_dir):
    print(f'\r {im}', end='', flush=True)
    # Force three channels: Normalize above is defined for exactly three.
    haze = Image.open(img_dir+im).convert('RGB')
    haze1 = normalize(haze)[None, ::].to(device)
    haze_no = to_tensor(haze)[None, ::]
    with torch.no_grad():
        pred = net(haze1)
    clamped = pred.clamp(0, 1).cpu()
    ts = torch.squeeze(clamped)
    # Every input/output pair is displayed before the result is saved.
    tensorShow([haze_no, clamped], ['haze', 'pred'])
    # "c.v2.jpg" --> "c.v2" --> "c.v2_FFA.png"; splitext removes only the
    # final extension, so dots inside the file name are preserved.
    vutils.save_image(ts, output_dir+os.path.splitext(im)[0]+'_FFA.png')
Code:https://github.com/EmilienDupont/coin
# Based on https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/utils/plot/__main__.py
import imageio
import json5 as json
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from matplotlib import cm
from pathlib import Path
# Name under which this project's own results are plotted.
ours = 'COIN'

# Ensure consistent coloring across plots: every method name is pinned to
# one fixed Tableau color.
name_to_color = {
    ours: mcolors.TABLEAU_COLORS['tab:blue'],
    'BMS': mcolors.TABLEAU_COLORS['tab:orange'],
    'MBT': mcolors.TABLEAU_COLORS['tab:green'],
    'CST': mcolors.TABLEAU_COLORS['tab:red'],
    'JPEG': mcolors.TABLEAU_COLORS['tab:purple'],
    'JPEG2000': mcolors.TABLEAU_COLORS['tab:brown'],
    'BPG': mcolors.TABLEAU_COLORS['tab:pink'],
    'VTM': mcolors.TABLEAU_COLORS['tab:gray'],
}

# Setup colormap for residuals plot (viridis resampled to 100 entries).
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
# accepts the same (name, lut) arguments.
viridis = plt.get_cmap('viridis', 100)
def parse_json_file(filepath, metric='psnr'):
    """Parse a json result file into the series needed for one plot line.

    Args:
        filepath (string): Path to results json file.
        metric (string): Metric to use for plot. ``'ms-ssim'`` values are
            converted to dB as ``-10 * log10(1 - value)``.

    Returns:
        dict: ``'name'`` (the file's ``name`` entry, or the file name up to
        its first dot), ``'xs'`` (the ``bpp`` values) and ``'ys'`` (the
        metric values).

    Raises:
        ValueError: If the file cannot be parsed, lacks ``results``/``bpp``,
            or does not contain ``metric``.
    """
    filepath = Path(filepath)
    name = filepath.name.split('.')[0]

    with filepath.open('r') as f:
        try:
            data = json.load(f)
        except ValueError:
            # Parse failures are reported as ValueError (the stdlib
            # JSONDecodeError is a subclass), so catching it does not rely on
            # a ``decoder`` submodule that the json5 module lacks.
            print(f'Error reading file {filepath}')
            raise

    # Make sure the file has the expected layout.
    if 'results' not in data or 'bpp' not in data['results']:
        raise ValueError(f'Invalid file {filepath}')

    results = data['results']
    if metric not in results:
        raise ValueError(
            f'Error: metric {metric} not available.'
            f' Available metrics: {", ".join(results.keys())}'
        )

    if metric == 'ms-ssim':
        # Convert to db
        values = np.array(results[metric])
        results[metric] = -10 * np.log10(1 - values)

    return {
        'name': data.get('name', name),
        'xs': results['bpp'],
        'ys': results[metric],
    }
def rate_distortion(scatters, title=None, ylabel='PSNR [dB]', output_file=None,
                    limits=None, show=False, figsize=None):
    """Creates a rate distortion plot based on scatters.

    Args:
        scatters (list of dicts): List of data to plot for each model. Each
            dict provides ``'name'``, ``'xs'`` and ``'ys'``.
        title (string): Optional plot title.
        ylabel (string): Label of the y axis.
        output_file (string): If not None, save plot at output_file.
        limits (tuple of ints): If not None, passed to ``ax.axis``.
        show (bool): If True shows plot.
        figsize (tuple of ints): Figure size; defaults to ``(7, 4)``.
    """
    if figsize is None:
        figsize = (7, 4)
    learned = (ours, 'BMS', 'MBT', 'CST')

    fig, ax = plt.subplots(figsize=figsize)
    for sc in scatters:
        name = sc['name']
        highlight = name == ours
        # Our own method is drawn thicker and with larger markers.
        style = {
            'label': name,
            'linewidth': 2.5 if highlight else 1,
            'markersize': 10 if highlight else 6,
        }
        # Names missing from the color table get matplotlib's default color
        # instead of raising KeyError.
        color = name_to_color.get(name)
        if color is not None:
            style['c'] = color
        # Solid lines for learned algorithms, dashed for non learned ones.
        pattern = '.-' if name in learned else '.--'
        ax.plot(sc['xs'], sc['ys'], pattern, **style)

    ax.set_xlabel('Bit-rate [bpp]')
    ax.set_ylabel(ylabel)
    ax.grid()
    if limits is not None:
        ax.axis(limits)
    ax.legend(loc='lower right')
    if title:
        ax.title.set_text(title)

    # Save before showing so the file is written while the figure is intact.
    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    # Always release the figure so repeated calls do not accumulate figures.
    plt.close(fig)
def plot_rate_distortion(filepaths=('results.json',
                                    'baselines/compressai-bmshj2018-hyperprior.json',
                                    'baselines/compressai-mbt2018.json',
                                    'baselines/compressai-cheng2020-anchor.json',
                                    'baselines/jpeg.json', 'baselines/jpeg2000.json',
                                    'baselines/bpg_444_x265_ycbcr.json',
                                    'baselines/vtm.json'),
                         output_file=None, limits=None):
    """Creates rate distortion plot based on all results json files.

    Every file is parsed with the ``'psnr'`` metric and drawn as one series.

    Args:
        filepaths (sequence of string): Paths to result json files. The
            default is an immutable tuple so it cannot be modified between
            calls.
        output_file (string): Path to save image.
        limits (tuple of float): Limits of plot.
    """
    # Read data
    scatters = [parse_json_file(path, 'psnr') for path in filepaths]
    # Create plot
    rate_distortion(scatters, output_file=output_file, limits=limits)
def plot_model_size(output_file=None, show=False):
    """Plots a log-scale bar chart of model sizes.

    Args:
        output_file (string): If not None, save plot at output_file.
        show (bool): If True shows plot.

    Notes:
        Data for all baselines was computed using the compressAI library
        https://github.com/InterDigitalInc/CompressAI
    """
    model_names = ['COIN', 'BMS', 'MBT', 'CST']
    model_sizes = [14.7455, 10135.868, 24764.604, 31834.464]  # in kB

    plt.grid(zorder=0, which="both", axis="y")  # Ensure grid is at the back
    barplot = plt.bar(model_names, model_sizes, log=True, zorder=10)
    # Color each bar with the color assigned to its model name.
    for bar, model_name in zip(barplot, model_names):
        bar.set_color(name_to_color[model_name])
    plt.ylabel("Model size [kB]")

    fig = plt.gcf()
    fig.set_size_inches(3, 4)

    # Save through the figure handle and before showing: after plt.show()
    # the current figure may be gone, and saving would write a blank image.
    if output_file:
        fig.savefig(output_file, format='png', dpi=400, bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)
def plot_residuals(path_original='kodak-dataset/kodim15.png',
                   path_coin='imgs/kodim15_coin_bpp_03.png',
                   path_jpeg='imgs/kodim15_jpeg_bpp_03.jpg',
                   output_file=None, show=False, max_residual=0.3,
                   title_fontsize=6):
    """Creates a plot comparing compression with COIN and JPEG both in terms of
    the compressed image and the residual between the compressed and original
    image.

    Args:
        path_original (string): Path to original image.
        path_coin (string): Path to image compressed with COIN.
        path_jpeg (string): Path to image compressed with JPEG.
        output_file (string): If not None, save plot at output_file.
        show (bool): If True shows plot.
        max_residual (float): Value between 0 and 1 to use for maximum residual
            on color scale. Usually set to a low value so residuals are clearer
            on plot.
        title_fontsize (int): Font size of the panel titles.
    """
    # Load images, scaled from 8-bit values to [0, 1].
    img_original = imageio.imread(path_original) / 255.
    img_coin = imageio.imread(path_coin) / 255.
    img_jpeg = imageio.imread(path_jpeg) / 255.

    def _residual(img):
        # Mean absolute error over the last axis, scaled by max_residual and
        # mapped through the colormap; the alpha channel is dropped.
        error = np.abs(img - img_original).mean(axis=-1)
        return viridis(error / max_residual)[:, :, :3]

    # (slot in the 2x3 grid, image, title). Slot 4 is skipped, so the JPEG
    # panels sit directly below the COIN panels.
    panels = [
        (1, img_original, 'Original'),
        (2, img_coin, 'COIN'),
        (3, _residual(img_coin), 'COIN Residual'),
        (5, img_jpeg, 'JPEG'),
        (6, _residual(img_jpeg), 'JPEG Residual'),
    ]
    for slot, image, panel_title in panels:
        plt.subplot(2, 3, slot)
        plt.imshow(image)
        plt.axis('off')
        plt.gca().set_title(panel_title, fontsize=title_fontsize)
    plt.subplots_adjust(wspace=0.1, hspace=0)

    fig = plt.gcf()
    # Save through the figure handle and before showing: after plt.show()
    # the current figure may be gone, and saving would write a blank image.
    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)
if __name__ == '__main__':
    # Rate-distortion curves with fixed axis limits.
    plot_rate_distortion(limits=(0, 1, 22, 38),
                         output_file='rate_distortion.png')

    # Bar chart of model sizes.
    plot_model_size(output_file='model_sizes.png')

    # Residual comparison using the function's default image paths.
    plot_residuals(output_file='residuals_kodim15_bpp_03.png')

    # Residual comparison for the bpp_015 image pair.
    plot_residuals(path_coin='imgs/kodim15_coin_bpp_015.png',
                   path_jpeg='imgs/kodim15_jpeg_bpp_015.jpg',
                   output_file='residuals_kodim15_bpp_015.png')