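For the theory behind style transfer, see the relevant papers; if I find the time, I will write a separate post on how it works.
Code source: https://github.com/lengstrom/fast-style-transfer
style.py is the entry point for training a new style. Its main job is to parse the command-line arguments and call the training function.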
from argparse import ArgumentParser

# TRAIN_PATH, NUM_EPOCHS, BATCH_SIZE, CHECKPOINT_ITERATIONS, VGG_PATH, CONTENT_WEIGHT,
# STYLE_WEIGHT, TV_WEIGHT and LEARNING_RATE are module-level defaults defined in style.py (not shown)
def build_parser():
    parser = ArgumentParser()
    parser.add_argument('--checkpoint-dir', type=str,
                        dest='checkpoint_dir', help='dir to save checkpoint in',
                        metavar='CHECKPOINT_DIR', required=True)
    parser.add_argument('--style', type=str,
                        dest='style', help='style image path',
                        metavar='STYLE', required=True)
    parser.add_argument('--train-path', type=str,
                        dest='train_path', help='path to training images folder',
                        metavar='TRAIN_PATH', default=TRAIN_PATH)
    parser.add_argument('--test', type=str,
                        dest='test', help='test image path',
                        metavar='TEST', default=False)
    parser.add_argument('--test-dir', type=str,
                        dest='test_dir', help='test image save dir',
                        metavar='TEST_DIR', default=False)
    parser.add_argument('--slow', dest='slow', action='store_true',
                        help='gatys\' approach (for debugging, not supported)',
                        default=False)
    parser.add_argument('--epochs', type=int,
                        dest='epochs', help='num epochs',
                        metavar='EPOCHS', default=NUM_EPOCHS)
    parser.add_argument('--batch-size', type=int,
                        dest='batch_size', help='batch size',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)
    parser.add_argument('--checkpoint-iterations', type=int,
                        dest='checkpoint_iterations', help='checkpoint frequency',
                        metavar='CHECKPOINT_ITERATIONS',
                        default=CHECKPOINT_ITERATIONS)
    parser.add_argument('--vgg-path', type=str,
                        dest='vgg_path',
                        help='path to VGG19 network (default %(default)s)',
                        metavar='VGG_PATH', default=VGG_PATH)
    parser.add_argument('--content-weight', type=float,
                        dest='content_weight',
                        help='content weight (default %(default)s)',
                        metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)
    parser.add_argument('--style-weight', type=float,
                        dest='style_weight',
                        help='style weight (default %(default)s)',
                        metavar='STYLE_WEIGHT', default=STYLE_WEIGHT)
    parser.add_argument('--tv-weight', type=float,
                        dest='tv_weight',
                        help='total variation regularization weight (default %(default)s)',
                        metavar='TV_WEIGHT', default=TV_WEIGHT)
    parser.add_argument('--learning-rate', type=float,
                        dest='learning_rate',
                        help='learning rate (default %(default)s)',
                        metavar='LEARNING_RATE', default=LEARNING_RATE)
    return parser
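To show how these options reach the rest of the script, here is a minimal sketch with a trimmed-down copy of the parser. Each `--flag-name` becomes the attribute named by `dest` on the parsed namespace, `type` converts the string, and omitted flags fall back to `default`. The constant values `BATCH_SIZE = 4` and `CONTENT_WEIGHT = 7.5` are placeholders chosen for this example, not necessarily the repo's defaults, and `build_demo_parser` is a hypothetical name.

```python
from argparse import ArgumentParser

# Placeholder defaults for illustration only (hypothetical values)
BATCH_SIZE = 4
CONTENT_WEIGHT = 7.5


def build_demo_parser():
    """Hypothetical trimmed-down parser mirroring a few options of build_parser."""
    parser = ArgumentParser()
    parser.add_argument('--checkpoint-dir', type=str, dest='checkpoint_dir',
                        metavar='CHECKPOINT_DIR', required=True)
    parser.add_argument('--style', type=str, dest='style',
                        metavar='STYLE', required=True)
    parser.add_argument('--batch-size', type=int, dest='batch_size',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)
    parser.add_argument('--content-weight', type=float, dest='content_weight',
                        metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)
    parser.add_argument('--slow', dest='slow', action='store_true', default=False)
    return parser


# Parse an explicit list instead of sys.argv so the example is reproducible
options = build_demo_parser().parse_args(
    ['--checkpoint-dir', 'ckpt', '--style', 'wave.jpg', '--batch-size', '20'])

print(options.checkpoint_dir)  # ckpt
print(options.batch_size)      # 20 (converted by type=int)
print(options.content_weight)  # 7.5 (default used)
print(options.slow)            # False (flag not given)
```

If `--checkpoint-dir` or `--style` were left out, `parse_args` would print a usage error and exit, because both are marked `required=True`.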
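Argument | Purpose |
---|---|
checkpoint-dir | Directory where the trained model checkpoints are saved |
style | Path to the style image |
train-path | Folder of training images (COCO 2017 here) |
test | Path to a test image |
test-dir | Folder where the stylized test images are saved |
slow | Use Gatys' approach (for debugging; not supported) |
epochs | Number of epochs |
batch-size | Batch size; around 20 on a single GPU |
checkpoint-iterations | Save a checkpoint every this many iterations |
vgg-path | Path to the VGG19 model file |
content-weight | Weight of the content loss |
style-weight | Weight of the style loss |
tv-weight | Weight of the total variation regularization term, which favors spatially smooth output |
learning-rate | Learning rate |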
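To make the three weight flags more concrete, here is a minimal NumPy sketch. It assumes the common form of a total variation penalty, which is the sum of squared differences between neighbouring pixels, and it assumes the three loss terms are combined as a weighted sum. The repository's TensorFlow implementation normalizes the terms differently, so treat this as an illustration of what each term measures rather than the repo's exact formula. `total_variation` and `weighted_loss` are hypothetical names.

```python
import numpy as np


def total_variation(img):
    """Sum of squared differences between vertically and horizontally adjacent pixels.

    img: array of shape (height, width, channels).
    """
    img = np.asarray(img, dtype=np.float64)
    dy = img[1:, :, :] - img[:-1, :, :]  # differences between vertical neighbours
    dx = img[:, 1:, :] - img[:, :-1, :]  # differences between horizontal neighbours
    return float(np.sum(dy ** 2) + np.sum(dx ** 2))


def weighted_loss(content_loss, style_loss, tv_loss,
                  content_weight, style_weight, tv_weight):
    """Assumed combination: each term scaled by its command-line weight, then summed."""
    return (content_weight * content_loss
            + style_weight * style_loss
            + tv_weight * tv_loss)


flat = np.full((4, 4, 3), 128.0)  # constant image: no variation at all
step = np.zeros((2, 2, 1))
step[:, 1, 0] = 1.0               # a single vertical edge, two pixels tall

print(total_variation(flat))  # 0.0
print(total_variation(step))  # 2.0
```

A larger `tv-weight` penalizes sharp pixel-to-pixel jumps more heavily, which suppresses high-frequency noise in the output at the cost of some fine detail.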
After the arguments are parsed, the style image is loaded:
style_target = get_img(options.style)
`get_img` is implemented as follows:
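import numpy as np
import scipy.misc  # imread/imresize are deprecated since SciPy 1.0 and removed in later releases

def get_img(src, img_size=False):
    # Read the image as an RGB array of shape (H, W, 3)
    img = scipy.misc.imread(src, mode='RGB')
    # Fallback: stack a single-channel (grayscale) image into three identical channels
    if not (len(img.shape) == 3 and img.shape[2] == 3):
        img = np.dstack((img, img, img))
    # Optionally resize, e.g. img_size=(256, 256, 3)
    if img_size != False:
        img = scipy.misc.imresize(img, img_size)
    return img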
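`scipy.misc.imread` and `scipy.misc.imresize` are not available in current SciPy, so the function above needs an old SciPy to run. One way to replace it is the following sketch, which uses Pillow and NumPy instead. It is not the repo's code. It assumes `img_size` is either `False` or a tuple whose first two entries are (height, width), as in `(256, 256, 3)`. The integer (percentage) and float (fraction) size modes that `imresize` also accepted are not reproduced.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def get_img(src, img_size=False):
    """Load an image as a writable uint8 RGB array of shape (H, W, 3)."""
    # convert('RGB') turns grayscale, palette and RGBA images into 3 channels,
    # which covers the np.dstack fallback of the original
    img = np.array(Image.open(src).convert('RGB'))
    if img_size is not False:
        height, width = img_size[0], img_size[1]
        # PIL takes (width, height); bilinear matches imresize's default interp
        resized = Image.fromarray(img).resize((width, height), Image.BILINEAR)
        img = np.array(resized)
    return img


# Usage: save a 10x20 grayscale image, then load it as RGB
path = os.path.join(tempfile.mkdtemp(), 'gray.png')
Image.fromarray(np.zeros((10, 20), dtype=np.uint8)).save(path)
print(get_img(path).shape)             # (10, 20, 3)
print(get_img(path, (4, 6, 3)).shape)  # (4, 6, 3)
```

`np.array` is used instead of `np.asarray` so the returned array is a writable copy, like the array `imread` used to return.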
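The list of training (content) images is built with:
content_targets = _get_files(options.train_path)
`_get_files` and its helper `list_files` are implemented as follows: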
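import os

def _get_files(img_dir):
    # Join each file name with the directory to get full paths
    files = list_files(img_dir)
    return [os.path.join(img_dir, x) for x in files]

def list_files(in_path):
    files = []
    # os.walk yields the top-level directory first; `break` stops it from
    # descending into subdirectories, so only top-level file names are kept
    for (dirpath, dirnames, filenames) in os.walk(in_path):
        files.extend(filenames)
        break
    return files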
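The `os.walk(...)` plus `break` pattern only reads the top-level directory. The sketch below does the same thing more directly with `os.scandir`. It is an alternative, not the repo's code, and `get_files` is a hypothetical stand-in for `_get_files`. One deliberate difference is that it sorts the names, so the order no longer depends on the file system. Like the original, it does not filter by extension, so any non-image file in the training folder would be included.

```python
import os
import tempfile


def list_files(in_path):
    """Names of the non-directory entries directly inside in_path (no recursion)."""
    with os.scandir(in_path) as entries:
        # is_dir() follows symlinks by default, the same test os.walk uses
        return sorted(entry.name for entry in entries if not entry.is_dir())


def get_files(img_dir):
    """Full paths of the files in img_dir (hypothetical stand-in for _get_files)."""
    return [os.path.join(img_dir, name) for name in list_files(img_dir)]


# Usage: two top-level files plus one file inside a subdirectory
root = tempfile.mkdtemp()
for name in ('b.jpg', 'a.jpg'):
    open(os.path.join(root, name), 'w').close()
os.mkdir(os.path.join(root, 'sub'))
open(os.path.join(root, 'sub', 'c.jpg'), 'w').close()

print(list_files(root))  # ['a.jpg', 'b.jpg']  (sub/ and sub/c.jpg are skipped)
```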