之前做过超分辨率,刚好有这个比赛,拿来记录一下,截止目前初赛,score=40.22,排名46。
更新,已经复赛B轮了,目前排名24。
官方给的txt,用于获取:
y4m 格式介绍:https://wiki.multimedia.cx/index.php/YUV4MPEG2
y4m 与 yuv(yuv420 8bit planar) 互转命令:
y4mtoyuv: ffmpeg -i xx.y4m -vsync 0 xx.yuv -y
yuvtoy4m: ffmpeg -s 1920x1080 -i xx.yuv -vsync 0 xx.y4m -y
y4m 与 bmp 互转命令:
y4mtobmp: ffmpeg -i xx.y4m -vsync 0 xx%3d.bmp -y
bmptoy4m: ffmpeg -i xx%3d.bmp -pix_fmt yuv420p -vsync 0 xx.y4m -y
y4m 每25帧抽样命令:
ffmpeg -i xxx.y4m -vf select='not(mod(n\,25))' -vsync 0 -y xxx_sub25.y4m
## 初赛训练数据下载链接
round1_train_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00000_00049_l.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00050_00099_l.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00100_00149_l.zip
round1_train_label:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00000_00049_h_GT.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00050_00099_h_GT.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00100_00149_h_GT.zip
## 初赛验证数据下载链接
round1_val_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00150_00199_l.zip
round1_val_label:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00150_00199_h_GT.zip
## 初赛测试数据下载链接
round1_test_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/test/input/youku_00200_00249_l.zip
自己写了几个python脚本批量处理数据:
1.每个y4m抽取100张图片:
import os
import time

# Dump every frame of each .y4m under ./y4m_ as numbered BMP images in ./image_bmp.
for clip in os.listdir("./y4m_"):
    if not clip.endswith('.y4m'):
        continue
    print(clip)
    command = ("ffmpeg -i " + "./y4m_/" + clip
               + " -q:v 2 -vsync 0 ./image_bmp/" + clip[:-4] + "%3d.bmp -y")
    os.system(command)
    # brief pause between clips so ffmpeg processes don't pile up
    time.sleep(5)
2.每个y4m每25帧抽取一张图片:
import os
import time

# ffmpeg select filter: keep one frame out of every 25 (trailing space kept
# so the assembled command matches the original byte-for-byte).
SELECT_FILTER = "\"select=not(mod(n\\,25))\" "

# Sub-sample each .y4m in ./y4m into numbered BMPs under ./image_bmp.
for clip in os.listdir("./y4m"):
    if clip.endswith('.y4m'):
        command = ("ffmpeg -i " + "./y4m/" + clip + " -vf " + SELECT_FILTER
                   + " -vsync 0 ./image_bmp/" + clip[:-4] + "%3d.bmp -y")
        os.system(command)
        time.sleep(3)
3.将图片转为所需的.y4m,图片放在image_x4下,下一级目录为要转为的.y4m名称,对应放着路径
import os
import time

# Re-encode each per-clip BMP sequence ./image_x4/<clip>/%3d.bmp back into
# a yuv420p ./result/<clip>.y4m video.
for clip_dir in os.listdir("./image_x4"):
    command = ("ffmpeg -i " + "./image_x4/" + clip_dir + "/%3d.bmp "
               + "-pix_fmt yuv420p -vsync 0 " + "./result/" + clip_dir + ".y4m -y")
    os.system(command)
    time.sleep(5)
4.批量改名
import os

# Append the submission suffix expected by the judge ("_h_Sub25_Res")
# to every .y4m produced in ./result.
for fname in os.listdir("./result"):
    old_path = "./result/" + fname
    new_path = "./result/" + fname[:-4] + "_h_Sub25_Res.y4m"
    os.rename(old_path, new_path)
    print(new_path)
5.先用插值算法测试:
import cv2
import os

# Baseline submission: upscale every frame in ./image by 4x per side with
# Lanczos interpolation and store it under ./image_x4/<clip_name>/.
# FIX: the original duplicated the whole read/resize/write body in both
# branches of an if/else that differed only by a makedirs call;
# makedirs(exist_ok=True) collapses it to a single code path.
for name in os.listdir("image"):
    # name[:11] is the clip id prefix, name[-7:] the frame-number suffix
    os.makedirs("./image_x4/" + name[:11], exist_ok=True)
    image = cv2.imread("./image/" + name)
    res = cv2.resize(image, (image.shape[1] * 4, image.shape[0] * 4),
                     interpolation=cv2.INTER_LANCZOS4)
    cv2.imwrite("./image_x4/" + name[:11] + "/" + name[-7:], res)
    print(name)
比赛发现的比较晚,先用插值试一下提交模型的步骤,成绩:
6.测试代码,在ESRGAN算法进行改进,损失只用MSE:
import sys
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch

# Usage: python test.py <model_path>
# Runs a 26-block RRDB net (ESRGAN backbone, trained with MSE only) on every
# BMP in ./image_bmp and writes the 4x outputs under ./image_x4/<clip_name>/.
model_path = sys.argv[1]
device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu

model = arch.RRDB_Net(3, 3, 64, 26, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
    mode='CNA', upsample_mode='upconv')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
# freeze all parameters: inference only
for k, v in model.named_parameters():
    v.requires_grad = False
model = model.to(device)
print('Model path {:s}. \nTesting...'.format(model_path))

for name in os.listdir("image_bmp"):
    # FIX: the original duplicated the entire inference body in both branches
    # of an if/else that differed only by the makedirs call.
    os.makedirs("./image_x4/" + name[:11], exist_ok=True)
    img = cv2.imread("./image_bmp/" + name)
    img = img * 1.0 / 255
    # BGR -> RGB, HWC -> CHW, add batch dimension
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0)
    img_LR = img_LR.to(device)
    output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    # RGB -> BGR, CHW -> HWC, back to [0, 255]
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round()
    cv2.imwrite("./image_x4/" + name[:11] + "/" + name[-7:], output)
    print(name)
90、180、270度旋转
import os
import glob

import numpy as np
from scipy import ndimage


def get_image_paths(folder):
    """Return the paths of all .png files directly under *folder*."""
    return glob.glob(os.path.join(folder, '*.png'))


def create_read_img(filename):
    """Write 90/180/270-degree rotated copies of *filename* next to it."""
    # FIX: scipy.misc.imread/imsave were removed in SciPy 1.2; use PIL
    # (already used elsewhere in this file) for the image I/O instead.
    from PIL import Image
    im = np.asarray(Image.open(filename))
    for angle in (90, 180, 270):
        rotated = ndimage.rotate(im, angle)
        Image.fromarray(rotated).save(filename[:-4] + '_%d.png' % angle)
    print(filename)


img_path = '/media/wxy/000F8E4B0002F751/test/'
imgs = get_image_paths(img_path)
#print (imgs)
for i in imgs:
    create_read_img(i)
镜像翻转
根据原始图像名称进行翻转
import cv2
import os

# Mirror-augment only the original frames: their file names are exactly
# 23 characters, so previously generated "_flip_*" copies are skipped.
for name in os.listdir("./HR_image/"):
    if len(name) != 23:
        continue
    image = cv2.imread("./HR_image/" + name)
    h_flip = cv2.flip(image, 1)  # horizontal (left/right) mirror
    cv2.imwrite("./HR_image/" + name[:-4] + "_flip_h.png", h_flip)
    w_flip = cv2.flip(image, 0)  # vertical (top/bottom) mirror
    cv2.imwrite("./HR_image/" + name[:-4] + "_flip_w.png", w_flip)
    print(name)
同时增强
from PIL import Image
import os
import glob

# The five augmentations and the filename suffix each one receives,
# in the exact order the outputs are written.
_AUGMENTATIONS = [
    (Image.FLIP_LEFT_RIGHT, '_h'),
    (Image.FLIP_TOP_BOTTOM, '_w'),
    (Image.ROTATE_90, '_90'),
    (Image.ROTATE_180, '_180'),
    (Image.ROTATE_270, '_270'),
]


def get_image_paths(folder):
    """Return the paths of all .png files directly under *folder*."""
    return glob.glob(os.path.join(folder, '*.png'))


def create_read_img(filename):
    """Save flipped and rotated variants of *filename* alongside it."""
    im = Image.open(filename)
    for op, suffix in _AUGMENTATIONS:
        im.transpose(op).save(filename[:-4] + suffix + '.png')
    print(filename)


img_path = '/media/wxy/000F8E4B0002F751/test/'
for path in get_image_paths(img_path):
    create_read_img(path)
多线程图像增强
import time
import os
from concurrent.futures import ThreadPoolExecutor
from PIL import Image


def create_read_img(filename):
    """Save flipped and rotated copies of *filename* next to it."""
    im = Image.open(filename)
    out_h = im.transpose(Image.FLIP_LEFT_RIGHT)
    out_w = im.transpose(Image.FLIP_TOP_BOTTOM)
    out_90 = im.transpose(Image.ROTATE_90)
    out_180 = im.transpose(Image.ROTATE_180)
    out_270 = im.transpose(Image.ROTATE_270)
    out_h.save(filename[:-4] + '_h.png')
    out_w.save(filename[:-4] + '_w.png')
    out_90.save(filename[:-4] + '_90.png')
    out_180.save(filename[:-4] + '_180.png')
    out_270.save(filename[:-4] + '_270.png')
    print(filename)


if __name__ == '__main__':
    # FIX: the third-party 'threadpool' package is unmaintained; the stdlib
    # ThreadPoolExecutor gives the same 5-worker fan-out.  Building the file
    # list inside the main guard also removes the os.listdir side effect
    # that previously ran on import.
    name = ["/media/wxy/000F8E4B0002F751/test/" + name_ for name_ in os.listdir("./test")]
    start_time = time.time()
    with ThreadPoolExecutor(max_workers=5) as pool:
        # consume the iterator so all tasks finish (and exceptions surface)
        list(pool.map(create_read_img, name))
    print('%d second' % (time.time() - start_time))
import cv2
import numpy as np
import math
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr (ITU-R BT.601 coefficients).
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array with the same dtype/scale convention as the input.
    '''
    in_img_type = img.dtype
    # FIX: astype() returns a new array; the original discarded the result,
    # so the in-place `img *= 255.` below mutated the caller's float array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (coefficients are ordered for BGR channel layout)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def calculate_psnr(img1, img2):
# img1 and img2 have range [0, 255]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
def calculate_ssim(img1, img2):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    Raises ValueError if the shapes differ or ndim is unsupported.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            # FIX: the original computed ssim(img1, img2) on the full
            # 3-channel arrays three times; score each channel separately
            # and average, as intended.
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """Mean single-channel SSIM with the standard 11x11 Gaussian window (sigma=1.5)."""
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # local means over the "valid" region (5-pixel border dropped)
    mu1 = cv2.filter2D(a, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    # local (co)variances via E[x*y] - E[x]E[y]
    sigma1_sq = cv2.filter2D(a ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(b ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
def main():
    """Compute Y-channel PSNR/SSIM between one GT frame and one SR output frame."""
    gt_img = cv2.imread("Youku_00150_h_GT001.png")
    sr_img = cv2.imread("Youku_00150_l001_120000.png")
    # FIX: removed two unused cv2.cvtColor(...BGR2GRAY) calls whose results
    # were never read.
    gt_img = gt_img / 255
    sr_img = sr_img / 255
    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
    # crop a 4-pixel border before scoring (conventional for x4 SR evaluation)
    cropped_sr_img_y = sr_img_y[4:-4, 4:-4]
    cropped_gt_img_y = gt_img_y[4:-4, 4:-4]
    psnr_y = calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
    ssim_y = calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
    print(psnr_y, ssim_y)


if __name__ == '__main__':
    main()