本篇基于 SIGGRAPH 2017（ACM ToG）论文 Globally and Locally Consistent Image Completion
（在 Context Encoder（CE）的基础上加入全局 + 局部两个判别器的改进）。
proj:http://hi.cs.waseda.ac.jp/~iizuka/projects/completion/
Github代码:
1) https://github.com/satoshiiizuka/siggraph2017_inpainting
2)https://github.com/shinseung428/GlobalLocalImageCompletion_TF
其中第一个是官方代码；第二个实现与原论文略有不同，但展示效果非常好。
因此，下面主要以 2) 中的代码为例进行解析。先看看 README：
Tensorflow implementation of Globally and Locally Consistent Image Completion on celebA dataset.
因此数据集采用的是 CelebA，自行下载即可。当然也可以自己准备数据集；后面时间充足的话，我准备用亚洲人脸数据重新训练此模型。
-data
  -img_align_celeba
    -img1.jpg
    -img2.jpg
    -...
$ python train.py
To continue training
$ python train.py --continue_training=True
Download pretrained weights
$ python download.py
$ python test.py --img_path=./data/test/test_img.jpg
用法很简单：数据集直接解压后放到指定目录，比如我放到了
/home/gavin/Dataset/，那么训练时加上相应参数即可：
python3 train.py --continue_training=True --data /home/gavin/Dataset/
由于原版是 Python 2 代码，有些写法需要修改；我这边做了少量修改，已改成 Python 3 版本。
训练截图如下:
最后是测试脚本（test.py），代码如下：
import cv2
import numpy as np
import tensorflow as tf

from config import *
from network import *
# Shared state and settings for the interactive mask painter in erase_img().
drawing = False  # True while the left mouse button is held down
ix = iy = -1  # cursor-position placeholders; not read by the painter
color = (255, 255, 255)  # fill colour of each painted square (all channels 255)
size = 10  # half-width, in pixels, of each painted square
def erase_img(args, img):
    """Interactively paint a mask over ``img`` and build batched model inputs.

    Opens an OpenCV window named 'image'. Pressing the left mouse button and
    dragging paints filled squares (half-width ``size``, colour ``color``)
    onto ``img`` *in place* and onto a mask of the same shape. Pressing Esc
    ends the session and closes all OpenCV windows.

    Args:
        args: Config object; ``input_height``, ``input_width`` and
            ``batch_size`` are read.
        img: Image array to paint on. The caller in this file passes an RGB
            image (assumed uint8 in [0, 255] -- confirm for other callers).

    Returns:
        ``(test_img, test_mask)``, each with a new leading batch axis tiled
        ``args.batch_size`` times. ``test_img`` is ``img`` resized and scaled
        to [-1, 1] with painted pixels forced to 1; ``test_mask`` is the
        resized mask scaled to [0, 1].
    """
    # Same shape as img (float64); painted squares are written with `color`.
    mask = np.zeros(img.shape)

    def _paint(x, y):
        # Draw the same filled square centred on (x, y) on both img and mask.
        top_left = (x - size, y - size)
        bottom_right = (x + size, y + size)
        cv2.rectangle(img, top_left, bottom_right, color, -1)
        cv2.rectangle(mask, top_left, bottom_right, color, -1)

    def erase_rect(event, x, y, flags, param):
        # OpenCV mouse callback: paint on press, while dragging, and on release.
        global drawing
        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True
            _paint(x, y)
        elif event == cv2.EVENT_MOUSEMOVE:
            if drawing:
                _paint(x, y)
        elif event == cv2.EVENT_LBUTTONUP:
            drawing = False
            _paint(x, y)

    cv2.namedWindow('image')
    cv2.setMouseCallback('image', erase_rect)
    while True:
        # img holds RGB data while imshow expects BGR, so swap channels to display.
        cv2.imshow('image', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        if (cv2.waitKey(1) & 0xFF) == 27:  # Esc finishes painting
            break

    # cv2.resize takes dsize as (width, height); this produces arrays of shape
    # (input_height, input_width, C), assumed to match the network's inputs.
    dsize = (args.input_width, args.input_height)
    test_img = cv2.resize(img, dsize) / 127.5 - 1
    test_mask = cv2.resize(mask, dsize) / 255.0
    # Force masked pixels to 1 (white on the [-1, 1] scale).
    test_img = test_img * (1 - test_mask) + test_mask
    cv2.destroyAllWindows()

    reps = [args.batch_size, 1, 1, 1]
    return (np.tile(test_img[np.newaxis, ...], reps),
            np.tile(test_mask[np.newaxis, ...], reps))
def test(args, sess, model):
    """Restore the latest checkpoint, let the user mask an image, and show the result.

    Displays one window with three panels side by side: the original image,
    the masked input and the network's reconstruction, each at half the
    source image's size.

    Args:
        args: Config object; ``checkpoints_path``, ``img_path``,
            ``input_height``, ``input_width`` and ``batch_size`` are read.
        sess: TF session in which ``model``'s graph was built.
        model: Network object exposing the ``single_orig``, ``single_test``
            and ``single_mask`` feeds and the ``test_res_imgs`` output.

    Raises:
        ValueError: No checkpoint was found under ``args.checkpoints_path``.
        FileNotFoundError: ``args.img_path`` could not be read as an image.
    """
    saver = tf.train.Saver()
    last_ckpt = tf.train.latest_checkpoint(args.checkpoints_path)
    if last_ckpt is None:
        raise ValueError("No checkpoint found in " + str(args.checkpoints_path))
    saver.restore(sess, last_ckpt)
    print("Loaded model file from " + str(last_ckpt))

    img = cv2.imread(args.img_path)
    if img is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError("Could not read image: " + str(args.img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # cv2.resize takes dsize as (width, height).
    dsize = (args.input_width, args.input_height)
    orig_test = cv2.resize(img, dsize) / 127.5 - 1  # scale to [-1, 1]
    orig_test = np.tile(orig_test[np.newaxis, ...], [args.batch_size, 1, 1, 1])
    orig_test = orig_test.astype(np.float32)

    # Panels are shown at half the source size; again (width, height) for cv2.
    rows, cols = img.shape[0], img.shape[1]
    show_size = (cols // 2, rows // 2)

    # erase_img paints on img in place, so orig_test must be built before it.
    test_img, mask = erase_img(args, img)
    test_img = test_img.astype(np.float32)

    print("Testing ...")
    res_img = sess.run(model.test_res_imgs, feed_dict={model.single_orig: orig_test,
                                                       model.single_test: test_img,
                                                       model.single_mask: mask})

    # Map each [-1, 1] image back to [0, 1] for display.
    orig = cv2.resize((orig_test[0] + 1) / 2, show_size)
    masked = cv2.resize((test_img[0] + 1) / 2, show_size)
    recon = cv2.resize((res_img[0] + 1) / 2, show_size)
    res = np.hstack([orig, masked, recon])
    # Panels are RGB; imshow expects BGR.
    res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)

    cv2.imshow("result", res)
    cv2.waitKey()
    cv2.destroyAllWindows()
    print("Done.")
def main(_):
    """Entry point: open a TF session, build the network and run ``test``.

    The positional argument is unused; configuration comes from the
    module-level ``args`` (presumably provided by ``from config import *`` --
    confirm).
    """
    run_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    run_config.gpu_options.allow_growth = True
    with tf.Session(config=run_config) as sess:
        # Build the graph first so the initializer actually covers the model's
        # variables; test() then restores the trained weights over them.
        model = network(args)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        print('Start Testing...')
        test(args, sess, model)
# Run the interactive test as soon as this module executes.
# `args` is not defined in this file; presumably it comes from
# `from config import *` -- confirm.
# NOTE(review): consider wrapping in `if __name__ == "__main__":` so that
# importing this module does not start the interactive session.
main(args)
效果展示：
运行后会显示原图，按住鼠标左键拖动即可涂抹 mask，最后按下 Esc 键结束涂抹，程序会自动进行修复。