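This article shows how to colorize a black-and-white image with Python's cv2 (OpenCV) module.
First, download the pretrained model.
1. Create a file named getModels.sh with the following content: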
mkdir -p models
wget https://github.com/richzhang/colorization/blob/master/colorization/resources/pts_in_hull.npy?raw=true -O ./pts_in_hull.npy
wget https://raw.githubusercontent.com/richzhang/colorization/master/colorization/models/colorization_deploy_v2.prototxt -O ./models/colorization_deploy_v2.prototxt
wget http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel -O ./models/colorization_release_v2.caffemodel
wget http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2_norebal.caffemodel -O ./models/colorization_release_v2_norebal.caffemodel
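2. Make the script executable (strictly speaking, this is only required if you run it as ./getModels.sh):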
chmod +x getModels.sh
3. Run the script:
sh getModels.sh
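When a download fails, `wget -O` still creates the output file, so a broken link can leave an empty file behind that only surfaces later as a confusing load error. The following is a minimal sketch that checks the three files the code below actually loads, assuming the layout created by getModels.sh (pts_in_hull.npy in the working directory, the Caffe files under ./models). The names `REQUIRED_FILES` and `missing_model_files` are hypothetical helpers, not part of the original tutorial.

```python
from pathlib import Path

# Files the colorization code loads, at the paths getModels.sh writes them to.
REQUIRED_FILES = (
    "pts_in_hull.npy",
    "models/colorization_deploy_v2.prototxt",
    "models/colorization_release_v2.caffemodel",
)


def missing_model_files(root="."):
    """Return the required files that are absent or empty under `root`.

    A zero-byte file is treated as missing, because `wget -O` creates
    the output file even when the download itself fails.
    """
    root = Path(root)
    missing = []
    for name in REQUIRED_FILES:
        path = root / name
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing


if __name__ == "__main__":
    problems = missing_model_files()
    if problems:
        print("Missing or empty:", ", ".join(problems))
    else:
        print("All model files present.")
```

Run it from the same directory as getModels.sh; an empty list means the files at least exist and are non-empty (it does not verify their contents).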
Next, run the following code (it reads ./greyscaleImage.png and writes colorized.png):
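import numpy as np
import cv2
# Read the input image (OpenCV loads it as a 3-channel BGR array, even if the PNG is grayscale)
frame = cv2.imread('./greyscaleImage.png')
if frame is None:
    raise FileNotFoundError('Could not read ./greyscaleImage.png')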
# Paths to the network definition and weights downloaded by getModels.sh
protoFile = './models/colorization_deploy_v2.prototxt'
weightFile = './models/colorization_release_v2.caffemodel'
# Load the ab cluster centers (a 313 x 2 array)
pts_in_hull = np.load('./pts_in_hull.npy')
# Load the Caffe network
net = cv2.dnn.readNetFromCaffe(protoFile, weightFile)
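# Install the cluster centers as the weights of the 1x1 convolution 'class8_ab'
# (shape: 2 output channels x 313 input channels x 1 x 1)
pts_in_hull = pts_in_hull.transpose().reshape(2, 313, 1, 1)
net.getLayer(net.getLayerId('class8_ab')).blobs = [pts_in_hull.astype(np.float32)]
# Set the constant factor of 'conv8_313_rh', which scales the 313 class scores before the softmax
net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, np.float32)]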
# Spatial size of the network input
W_in = 224
H_in = 224
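# Convert BGR to RGB and scale to float32 in [0, 1], the range cvtColor expects for float input
img_rgb = (frame[:, :, [2, 1, 0]] * 1.0 / 255).astype(np.float32)
img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2Lab)  # for float32 input, L lies in [0, 100]
img_l = img_lab[:, :, 0]  # Extract the L channel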
# Resize the L channel to the network's input size
img_l_rs = cv2.resize(img_l, (W_in, H_in))
img_l_rs -= 50  # Subtract 50 to mean-center L ([0, 100] becomes [-50, 50])
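# blobFromImage wraps the 224x224 L image into a 1x1x224x224 (NCHW) blob
net.setInput(cv2.dnn.blobFromImage(img_l_rs))
# The output is 1x2xHxW; take the first image and move channels last to get the predicted ab (HxWx2)
ab_dec = net.forward()[0, :, :, :].transpose((1, 2, 0))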
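(H_orig, W_orig) = img_rgb.shape[:2]  # Original image size
# Upsample the predicted ab channels to the original resolution
ab_dec_us = cv2.resize(ab_dec, (W_orig, H_orig))
# Combine the original full-resolution L channel with the predicted ab channels
img_lab_out = np.concatenate((img_l[:, :, np.newaxis], ab_dec_us), axis=2)
# Convert back to BGR and clip to the valid [0, 1] range
img_bgr_out = np.clip(cv2.cvtColor(img_lab_out, cv2.COLOR_Lab2BGR), 0, 1)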
outputFile = 'colorized.png'
# Scale back to 8 bits; rounding avoids the downward bias of truncating with astype alone
cv2.imwrite(outputFile, np.rint(img_bgr_out * 255).astype(np.uint8))
print('Done!')
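The two lines that overwrite the `class8_ab` and `conv8_313_rh` blobs are the least obvious part of the script. As I read the deploy prototxt, the network scores 313 quantized ab bins at each pixel; `conv8_313_rh` multiplies those scores by 2.606, a softmax turns them into probabilities, and `class8_ab` is a 1x1 convolution whose kernel is the bin centers from pts_in_hull.npy, so its output is the probability-weighted mean ab value. The NumPy sketch below reproduces that arithmetic outside OpenCV, to show why the (313, 2) array is transposed and reshaped to (2, 313, 1, 1). The function names and the random stand-in data are illustrative only.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def decode_ab_direct(logits, pts_in_hull, scale=2.606):
    """Probability-weighted mean ab value per pixel.

    logits:      (313, H, W) raw scores for the 313 ab bins
    pts_in_hull: (313, 2) ab coordinates of the bin centers
    returns:     (2, H, W)
    """
    probs = softmax(scale * logits, axis=0)  # scale layer followed by softmax
    return np.einsum("qhw,qc->chw", probs, pts_in_hull)


def decode_ab_as_1x1_conv(logits, pts_in_hull, scale=2.606):
    """Same result, written as the 1x1 convolution the script configures."""
    kernel = pts_in_hull.transpose().reshape(2, 313, 1, 1)  # (out_ch, in_ch, kH, kW)
    probs = softmax(scale * logits, axis=0)
    # A 1x1 convolution is a per-pixel matrix product over the input channels.
    return np.einsum("cq,qhw->chw", kernel[:, :, 0, 0], probs)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pts = rng.uniform(-110, 110, size=(313, 2))  # stand-in for pts_in_hull.npy
    logits = rng.normal(size=(313, 4, 4))
    a = decode_ab_direct(logits, pts)
    b = decode_ab_as_1x1_conv(logits, pts)
    print(a.shape, np.allclose(a, b))  # (2, 4, 4) True
```

The scale factor acts like an inverse softmax temperature: raising it pushes the result toward the single most likely bin, while a factor of 1 gives the plain mean. 2.606 is roughly 1/0.38, which matches the annealed-mean temperature T = 0.38 described in the Colorful Image Colorization paper; according to that paper, the plain mean tends to look desaturated.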
Test image:
Colorized result:
Finally, here is the GitHub link.