作者 | 李秋键
头图 | 下载于视觉中国
出品 | AI 科技大本营(ID:rgznai100)
引言:
视频和图像的隐身术是指在视频或者图像中中,在没有任何输入遮罩的情况下,通过框选目标体,使得程序实现自动去除视频中的文本叠加和修复被遮挡部分的问题。并且最近的基于深度学习的修复方法只处理单个图像,并且大多假设损坏像素的位置是已知的,故我们的目标是在没有蒙皮信息的视频序列中自动去除文本。
今天,我们通过搭建一个简单而有效的快速视频解码器框架去实现视频中物体的去除。流程是构建一个编码器-解码器模型,其中编码器采用多个源帧,可以提供从场景动态显示的可见像素。这些提示被聚合并输入到解码器中。然后通过应用循环反馈进一步改进加强模型。循环反馈不仅加强了时间相干性,而且提供了强大的线索。
实现效果如下可见:
模型建立
1.1 环境要求
本次环境使用的是python3.6.5+windows平台
主要用的库有:
argparse模块是python自带的命令行参数解析包,可以用来方便地读取命令行参数;
subprocess是Python 2.4中新增的一个模块,它允许你生成新的进程,连接到它们的 input/output/error 管道,并获取它们的返回(状态)码。这个模块的目的在于替换几个旧的模块和方法
numpy模块用来矩阵和数据的运算处理,其中也包括和深度学习框架之间的交互等。
torch模块是一个python优先的深度学习框架,是一个和tensorflow,Caffe,MXnet一样,非常底层的框架在这里我们用来搭建网络层和直接读取数据集操作,简单方便。
Matplotlib模块用来可视化训练效果等数据图的制作。
1.2 程序的启动
程序的运行方式如下:
1、直接运行demo.py文件对图片进行处理
2、对视频进行处理python demo.py --data data/bag.avi。
import argparse
from mask import mask
from inpaint import inpaint
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument('--resume', default='cp/SiamMask_DAVIS.pth', type=str,
metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--data', default='data/Human6', help='videos or image files')
parser.add_argument('--mask-dilation', default=32, type=int, help='mask dilation when inpainting')
args = parser.parse_args()
mask(args)
inpaint(args)
1.3 算法概述
视频中物体的移除目的是从有字幕、有噪声的视频帧中预测原始帧。恢复的区域应该和原始的相同大小,或者无缝地融合到周围的像素中。基本的算法思想是从多个相邻帧(源帧)中收集提示,然后恢复目标帧。这是为了利用视频中的场景动态,在视频中,随着物体的移动或字幕的变化,被遮挡的部分通常会在滞后或引导帧中显示。同时还可以使用循环反馈连接作为额外的源流。直接估计一帧中的所有像素可能会不必要地接触到未损坏的像素。为了解决像素指标缺失的问题,采用残差学习算法对模型进行训练。具体来说,最终输出是通过按像素顺序将输入中心帧和预测残差图像相加得到的。这使得我们的网络明确地只关注损坏的像素,也防止全局色调失真。
1.4模型的搭建
模型算法核心设计是一个混合的编码器-解码器模型,其中编码器包括两个子网络:3D CNN和2D CNN。解码器遵循一个正常的2D CNN设计,该网络被设计成完全卷积的,可以处理任意大小的输入。最后的输出视频是通过自回归的方式应用函数得到的,我们的策略是从多个源帧中收集潜在的线索,这些帧可以提供从场景动态中显示的可见像素。此外,我们强制目标帧的生成与前一代保持一致。通过构造一个双流混合编码器,其中每个源流都经过训练以实现我们的目标。第一个编码器流由3D卷积组成,它可以直接从相邻帧捕获时空特征,第二个流是一个2D CNN,它将先前生成的尺寸为H×W×1×C的帧作为输入。
其中模型生成如下:
try:
assert(opt.model == 'vinet_final')
model = vinet.VINet_final(opt=opt)
except:
print('Model name should be: vinet_final')
assert(opt.no_cuda is False)
model = model.cuda()
model = nn.DataParallel(model)
loaded, empty = 0,0
if opt.pretrain_path:
print('Loading pretrained model {}'.format(opt.pretrain_path))
pretrain = torch.load(opt.pretrain_path)
child_dict = model.state_dict()
parent_list = pretrain['state_dict'].keys()
parent_dict = {}
for chi,_ in child_dict.items():
if chi in parent_list:
parent_dict[chi] = pretrain['state_dict'][chi]
#print('Loaded: ',chi)
loaded += 1
else:
#print('Empty:',chi)
empty += 1
print('Loaded: %d/%d params'%(loaded, loaded+empty))
child_dict.update(parent_dict)
model.load_state_dict(child_dict)
视频处理
2.1 预定义
我们的任务将视频去除目标后尽可能的还原成背景场景。如果场景移动或者字幕在相邻帧中消失,被遮挡的部分就会被显示出来,这就为潜在的内容提供了关键的线索。为了使增益参数的最大化,需要为我们的模型找到最佳的帧采样间隔。当最小间隔为1时,输入帧将包含不重要的动态。另一方面,如果我们以较大的步伐跳跃,不相干的新场景就会被包括进来。最终通过测试,设定的参数如下:
opt = Object()
opt.crop_size = 512
opt.double_size = True if opt.crop_size == 512 else False
########## DAVIS
DAVIS_ROOT =os.path.join('results', args.data)
DTset = DAVIS(DAVIS_ROOT, mask_dilation=args.mask_dilation, size=(opt.crop_size, opt.crop_size))
DTloader = data.DataLoader(DTset, batch_size=1, shuffle=False, num_workers=1)
opt.search_range = 4 # fixed as 4: search range for flow subnetworks
opt.pretrain_path = 'cp/save_agg_rec_512.pth'
opt.result_path = 'results/inpainting'
opt.model = 'vinet_final'
opt.batch_norm = False
opt.no_cuda = False # use GPU
opt.no_train = True
opt.test = True
opt.t_stride = 3
opt.loss_on_raw = False
opt.prev_warp = True
opt.save_image = False
opt.save_video = True
2.2 视频处理
我们的模型不仅从当前帧中收集线索,还从未来和过去相邻帧中收集线索。另外,为了保持时间一致性,有条件地生成每一帧到前一帧的输出帧。
with torch.no_grad():
for seq, (inputs, masks, info) in enumerate(DTloader):
idx = torch.LongTensor([i for i in range(pre - 1, -1, -1)])
pre_inputs = inputs[:, :, :pre].index_select(2, idx)
pre_masks = masks[:, :, :pre].index_select(2, idx)
inputs = torch.cat((pre_inputs, inputs), 2)
masks = torch.cat((pre_masks, masks), 2)
bs = inputs.size(0)
num_frames = inputs.size(2)
seq_name = info['name'][0]
save_path = os.path.join(opt.result_path, seq_name)
if not os.path.exists(save_path) and opt.save_image:
os.makedirs(save_path)
inputs = 2. * inputs - 1
inverse_masks = 1 - masks
masked_inputs = inputs.clone() * inverse_masks
masks = to_var(masks)
masked_inputs = to_var(masked_inputs)
inputs = to_var(inputs)
total_time = 0.
in_frames = []
out_frames = []
lstm_state = None
for t in range(num_frames):
masked_inputs_ = []
masks_ = []
if t < 2 * ts:
masked_inputs_.append(masked_inputs[0, :, abs(t - 2 * ts)])
masked_inputs_.append(masked_inputs[0, :, abs(t - 1 * ts)])
masked_inputs_.append(masked_inputs[0, :, t])
masked_inputs_.append(masked_inputs[0, :, t + 1 * ts])
masked_inputs_.append(masked_inputs[0, :, t + 2 * ts])
masks_.append(masks[0, :, abs(t - 2 * ts)])
masks_.append(masks[0, :, abs(t - 1 * ts)])
masks_.append(masks[0, :, t])
masks_.append(masks[0, :, t + 1 * ts])
masks_.append(masks[0, :, t + 2 * ts])
elif t > num_frames - 2 * ts - 1:
masked_inputs_.append(masked_inputs[0, :, t - 2 * ts])
masked_inputs_.append(masked_inputs[0, :, t - 1 * ts])
masked_inputs_.append(masked_inputs[0, :, t])
masked_inputs_.append(masked_inputs[0, :, -1 - abs(num_frames - 1 - t - 1 * ts)])
masked_inputs_.append(masked_inputs[0, :, -1 - abs(num_frames - 1 - t - 2 * ts)])
masks_.append(masks[0, :, t - 2 * ts])
masks_.append(masks[0, :, t - 1 * ts])
masks_.append(masks[0, :, t])
masks_.append(masks[0, :, -1 - abs(num_frames - 1 - t - 1 * ts)])
masks_.append(masks[0, :, -1 - abs(num_frames - 1 - t - 2 * ts)])
else:
masked_inputs_.append(masked_inputs[0, :, t - 2 * ts])
masked_inputs_.append(masked_inputs[0, :, t - 1 * ts])
masked_inputs_.append(masked_inputs[0, :, t])
masked_inputs_.append(masked_inputs[0, :, t + 1 * ts])
masked_inputs_.append(masked_inputs[0, :, t + 2 * ts])
masks_.append(masks[0, :, t - 2 * ts])
masks_.append(masks[0, :, t - 1 * ts])
masks_.append(masks[0, :, t])
masks_.append(masks[0, :, t + 1 * ts])
masks_.append(masks[0, :, t + 2 * ts])
masked_inputs_ = torch.stack(masked_inputs_).permute(1, 0, 2, 3).unsqueeze(0)
masks_ = torch.stack(masks_).permute(1, 0, 2, 3).unsqueeze(0)
start = time.time()
最终完成效果如下:
完整代码链接:
https://pan.baidu.com/s/1tCB0MTBbvfSokeU1AAKBQQ
提取码:nfhk
作者简介:李秋键,CSDN博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap竞赛获奖等。
你还知道哪些 Python 的新奇用法?
欢迎来评论区唠唠~
AI科技大本营将选出三名优质留言
携手【北京大学出版社】送出
《Python入门到人工智能实战》一本
截至4月18日14:00点
更多精彩推荐
无人机、IoT 设备都有漏洞?专访以色列老牌安全企业Check Point
听完姚期智的一句“嘟囔”,他开始第二次创业
AI 3D 传感器市场竞争白热化,中国掌握自主可控核心技术时不我待!
小心!你家的 IoT 设备可能已成为僵尸网络“肉鸡”
点分享点收藏点点赞点在看