前段时间参加了一个关于图像视频检索的比赛,抽空总结一下思路,并在结尾附上参赛代码以及对应数据集。
链接:媒体融合创新创意大赛 比赛主链接 复赛名单链接
截至目前只公布了Top10,复赛参加后暂无更多的消息,可能主办方也鸽了(不是)。
结果出来了,被榜上的大佬们反复蹂躏,榜单在这。
提供数据集为:
包含3186个长度为10秒左右的随机视频,模拟视频数据库。
从db中截取的图片,并进行相应处理,处理方式包括:
原始视频中某一帧,用于图像查询视频,变换类型说明如下:
- black_pad: 基于原图片上下添加黑边
- bw: 颜色变换-黑白
- color: 颜色变换-彩色
- crop: 图片裁减
- cut: 子片段(对于图片无效)
- logo: 随机位置增加logo
- mohu: 变得模糊
- ori: 无变换,可以理解为原切图
- shuiyin: 增加全屏的水印
- zimu: 下方增加字幕
分为训练集(每种类1000个图片)和测试集(每种类200个图片)
基于db视频采取与image相同处理方式的视频,训练集与测试集相同。
这里不做比赛简介的复述了,采用大白话的方式进行说明:
设计一种算法,官方测试时指定新的db文件夹,以及随机混合10种类型的图片和视频文件夹,以路径作为参数传入指定脚本中,经过处理运算后得到一个匹配对照文件以及耗时文件。运行时提供必要的显卡和cuda环境作为支持。为避免环境问题,参赛方以docker镜像的方式提供算法模型。
代码按照要求应分为三个py文件,build.py、query_image.py、query_video.py
主要难度:
混合类型的图像或音频,算法泛用性难度大大增加;官方测试时仅一张T4显卡作为运算主力,耗时问题;还有一个CV和深度学习处理能力为0的本人(我是菜狗)。
既然提供了显卡和cuda环境,本能的想到了什么目标识别、聚类、相似度匹配巴拉巴拉的一大堆,感觉要狠狠的torch起来,但是查阅了很多论文和开源项目代码,发现我能找到的或者我能理解并复现的一些技术,并不适合当前数据集的要求,进而在选择技术方法方面花费了较长时间。
最终采取了图像相似度的方式,也就是基于图像hash值对比的方式进行图片检索,在本次数据集的数量级范围内理论可行。Imagehash官方的库中有各种hash值计算的函数,包括但不限于ahash、dhash、phash、whash等,其中我们根据实验数据进行测试,排除了ahash、dhash等简单数值运算的方式,选取了whash基于小波变换的方式。
whash算法核心思路:
小波变换中会进行多次迭代,对whash原理想要详细了解的同学可以看以下的源码,但是感觉一些底层函数在这部分源码中也没有讲清楚,这里更像是判断并计算,感兴趣的小伙伴们可以深入了解一下。
def whash(image, hash_size = 8, image_scale = None, mode = 'haar', remove_max_haar_ll = True):
"""
Wavelet Hash computation.
based on https://www.kaggle.com/c/avito-duplicate-ads-detection/
@image must be a PIL instance.
@hash_size must be a power of 2 and less than @image_scale.
@image_scale must be power of 2 and less than image size. By default is equal to max
power of 2 for an input image.
@mode (see modes in pywt library):
'haar' - Haar wavelets, by default
'db4' - Daubechies wavelets
@remove_max_haar_ll - remove the lowest low level (LL) frequency using Haar wavelet.
"""
import pywt
if image_scale is not None:
assert image_scale & (image_scale - 1) == 0, "image_scale is not power of 2"
else:
image_natural_scale = 2**int(numpy.log2(min(image.size)))
image_scale = max(image_natural_scale, hash_size)
ll_max_level = int(numpy.log2(image_scale))
level = int(numpy.log2(hash_size))
assert hash_size & (hash_size-1) == 0, "hash_size is not power of 2"
assert level <= ll_max_level, "hash_size in a wrong range"
dwt_level = ll_max_level - level
image = image.convert("L").resize((image_scale, image_scale), Image.ANTIALIAS)
pixels = numpy.asarray(image) / 255.
# Remove low level frequency LL(max_ll) if @remove_max_haar_ll using haar filter
if remove_max_haar_ll:
coeffs = pywt.wavedec2(pixels, 'haar', level = ll_max_level)
coeffs = list(coeffs)
coeffs[0] *= 0
pixels = pywt.waverec2(coeffs, 'haar')
# Use LL(K) as freq, where K is log2(@hash_size)
coeffs = pywt.wavedec2(pixels, mode, level = dwt_level)
dwt_low = coeffs[0]
# Substract median and compute hash
med = numpy.median(dwt_low)
diff = dwt_low > med
return ImageHash(diff)
我们现在手里的武器:
掌握了一种叫Whash的函数,把图片传入函数,返回一个可以描述一张图片的“特征码”,凭借对比“特征码”的异同之处来判断是不是同一张图片或者相似图片。
我们需要解决的问题:
给你一个图片或者视频,说出这是数据库的哪一个视频。
将db数据库的视频切帧,每隔10帧左右保存一张图片,然后根据图片生成对应的“特征码”,全部处理后,我们会获得3186个视频中出现过的事物的全部信息,也就是“特征码”,大概60000个“特征码”,采用十六进制存储为csv格式,大小为13M左右。这部分预处理交给build.py去做
判断图片来自哪个视频:
加载处理过的60000个特征码,把目标图片传入Whash函数,获得特征码,进行比较,选取差别(汉明距最小)的目标“特征值”,返回该“特征值”来自的db数据视频id。
判断视频是db中哪个视频:
比图片多一步预处理,将待查找视频切取某一帧,后续简化为图片查找。
1)不同种类的泛用性
测试时需要传入各种处理过的不同类型,尽管Whash的鲁棒性较强,抗干扰能力也还可以,但是某些类型,如“黑边”“裁切”图片准确率极低。
2)遍历的复杂度
查询每一张图片都需要进行遍历获得所有汉明距离,而“特征码”不同点的随机性,以及“非存在性”使得索引或者传统查找方式难以使用。
“非存在性”人话版:
假设待查找特征码为10001,我们当然想要寻找标号为10001的特征码,然而因为切帧并不一定就正好切到原图,就算切到原图,经过变换后也基本不可能是10001。
假设我们站在上帝视角,最相近的图片特征码为“10004”,汉名距离为3,事实便是:我们要查的数据不存在数据库中,我们只能找最相近的某一个值,而这个“最”怎么避免遍历查找便是问题,而且不同图片查找后汉明距离这一数据不能得到有效利用,只能废弃。
“特征码不同点的随机性”人话版:
此时我们想到能否让数据库数据排好队,我们从第一位赶着来,但是问题又来了,这个图片特征码为10001,而上帝视角的我们又知道了,最相似的图片是20101,距离为3。我们不知道哪一位是变的,因此从哪一位开始对比似乎都显得不合适。
经过小组讨论(一共就两个人),我们探讨了似乎可行的方案,但并未实施落地,原理如下
PS:1、采用size=16的whash值长度位256位0或1,但存储时采用16进制。
2、汉明距离大于20基本不可能是原图。
我们不关心0和1的具体位置,允许你是0我是1的情况存在,只不过这样我们的“汉明距离”便会增加1。基于以上理论,我们将256位数据分段,统计每段0和1的个数,比如说分了四段,1的个数分别为 20,38,54,19(0~64取值),现在我们拿到了一个带查找特征码,他的分段数据为:17,12,52,22。那我们便可以根据第二段差距过大的原因,直接否决他们是原图的可能性。
有同学可能要问了,这不还是比较吗,但请注意,分段数据是可以排序的,我们建立类似树状的结构,1,1,1,1->1,1,1,2---->>>>64,64,64,64
拿到一个图片,我们计算每段01个数,假如为30,30,30,30,那我们进行遍历查找的数据可能只需要是20,20,20,20到40,40,40,40这部分的数据。
但是在实际操作中还是存在一定的技术难度,比如构建数据结构,每段容错只给10是否合理,截至提交时依旧采用遍历方式(太菜了)。
3)该方式无法使用GPU加速
暂时没有研究过“汉明距离”遍历计算以及生成hash值如何使用gpu加速的问题。
1. 非敏感类型
WHash 算法对黑白、颜色反转、Logo、模糊、水印、字幕这几类不敏感,原因如下:黑白、颜色反转因为统一转换为灰度图的原因,颜色信息会被丢弃,因此来自颜色的干扰不会产生影响;
Logo、模糊、水印因为在获取低频信息以及迭代时会丢弃掉这些影响,放大了部分低频信息;字幕会对图片特征值获取产生部分影响,但是从实际结果来看,海明距仅增加 1 到 2,可以忽略。
2. 敏感类型
只需要处理 Crop,black-pad 两种类型的图片或视频,思路步骤如下:
1. 黑边类型
扫描黑边采用像素扫描的方式,从左上和左下的第一个像素点进行扫描,并向中心渐进。默认存在黑边,记录黑边的“宽度”,直到发现某一行出现大量非黑色像素点停止记录,观察“黑边”宽度(即已经扫描过的行数),若小于某临界值,则认为不存在黑边。若在临界值之上,则判定存在黑边。
黑边图片可以直接根据记录的“黑边宽度”,使用 OpenCV 的裁切函数删除黑边区域。
2. 裁切类型
这部分处理花费了很多心思,困难有二:
1)主要难以判断是否是裁切类图片或者视频;
2)裁切视频或图片该如何在3000个视频正确匹配,基于像素遍历扫描还是基于深度学习?
当时并没有很好的方案能够进行适配。但注意到所有的测试集与 DB 样本集中,除了竖屏视频,其余视频样本长宽比均维持在1280*720 类似的比例,裁切类型则会生成比例不定的任意类型图片,且大小不会过小,借助这一特点, 可以尝试判断图片比例,按照图片比例进行筛选,从而分离出裁切类型。
裁切图片需要进行补全操作 OpenCV 库中提供了这几个常见的补全方式:
根据扩充效果,首先排除 CONSTANT 方法,单一的色块堆叠(无论什么颜色)会对 Whash 算法造成巨 大误差(先前去除黑边就是为了避免该误差)。
接下来则是运行测试样例,统计不同扩充方式对图片匹配准确率提升的大小。测试时挑选了原本无法正确匹配的裁剪类图片100 张,分别使用 REPLICATE、REFLECT、REFLECT_101、WRAP 四种算法进行边界填充,REPLICATE 算法表现最好,再次成功识别了42张(裁切类型直接匹配准确率不足30%,这还是裁切部分不算太多的情况下,相当惨淡),其余算法均在35张左右。
因为详细代码过长,此处不进行全部粘贴,文末会附上代码和数据集的下载链接
def cut_video(file_path, save_path):
"""
指定视频文件目录 file_path
指定存放目录 save_path
功能:根据指定目录,获取视频并生成同名文件夹,文件夹下存放切帧图片
file_path/test_video.mp4 --> save_path/test_video/1.jpg--n.jpg
切图文件夹名会与视频文件夹名保持一致 切图名由数字 1 开始递增
"""
pass
def get_hash(main_path, save_path):
"""
指定数据库的切图路径,生成hash值,并在指定位置保存为csv文件
变量说明:
关于HASH
hash针对的是一张图片,在本函数中,三个hash变量所代表的内容如下:
all_hash:文件夹下所有视频的所有帧的hash
per_hash:某个视频的所有帧的 hash
pic_hash:某个视频的某一帧的 hash
即多个pic_hash组成per_hash,多个per_hash组成all_hash
关于 PATH
结构:main_path/video_1/1.jpg
main_path: 总路径,包含诸如 video_1、video_2等文件夹
per_path: 精确到 video_1文件夹
pic_names: 精确到video_1的 jpg们
"""
pass
build.py运行时需要传入db文件夹位置,并指定cache文件夹,用来存放切图和hash值文件
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--db', type=str, help='请输入文件夹名称')
parser.add_argument('--cache', type=str, help='请输入文件夹名称')
args = parser.parse_args()
video_cut_path = os.path.join(args.cache, 'video_cut')
hash_data_path = os.path.join(args.cache, 'hash_data')
cut_video(args.db, video_cut_path)
get_hash(video_cut_path, hash_data_path)
print('处理完毕,请运行query_image.py与query_video.py')
return
def auto_core(img_dir_name, rows):
"""
此为 auto_detection()的核心匹配函数,部分变量释义请参考该函数
为了便于计算时间且使代码较为清晰易读,我们进行了核心函数分离
"""
pass
def auto_detection(full_path, result_path, csv_file):
"""
思路:
根据build.py 中生成 csv文件中的 hash值,从full_path中读取图片,
匹配获得最相似的hash值,将全部结果输出在result_path目录中,会生成两个
文件,依次是匹配对照表result.csv以及time_cost.txt
target_list: 存储匹配到的文件名列表 (即我认为与 3 号文件匹配,这里就会存储 3)
dir_name_list: 指定匹配的图片名列表
all_dist: 每个匹配项的海明距离列表
因为在数据处理时,维持文件名不变,所以在csv中观察名称是否一一对应即可判断是否匹配成功
"""
auto_core(img_dir_name, rows)
pass
def deal_bp_and_shape(full_pic_path, write_path):
"""
因比赛方案中存在两种难以处理图片
crop裁切类图片:hash值会因为裁切图片位移错位而失去准确度,
black_padding附加黑边类图片:黑边会在hash值中生成过多代表黑色的数值导致汉明距离异常增大
共同特点:图片长宽异常,长宽像素点个数与db数据库中过的图片不一致
处理方式:
黑边类:
自上而下自左而右扫描图片,遇到非纯黑点停止,判断已扫描的行数,大于定值即视为黑边图
因为扫描时存储了已扫描行数,上下对称切除黑边即可
裁切类:
考虑到判断是否为裁切较为困难,出于节约时间的考虑,根据比赛数据,我们采用了直接判断图像像素大小的方式
但是该方式缺点不言而喻,希望后续可以改进
工作流程:
1、在judge_pic()函数中判断是否为长宽比异常图片,因为在本次数据集中,长宽比异常时必定属于以上两类类型
2、异常图片传入本函数deal_bp_and_shape(),先扫描黑边,如判定存在黑边,切除黑边进行下一张判断
判定无黑边,则使用cv2.copyMakeBorder()扩充方式补全四周,补全后大小为标准图片大小
实践证明:对于本次数据集,对称补全后的图片准确率上升20%以上,不失为一种处理方式,但不具备推广性。
"""
pass
def judge_pic(image_path, img_pre_path):
"""
处理方式以及原理参考deal_bp_and_shape()函数备注
"""
deal_bp_and_shape(full_pic_path, write_path)
pass
query_image.py执行时需要传入先前的cache参数,指定结果输出文件位置result,指定待查找图片位置img
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--cache', type=str, help='请输入文件夹名称')
parser.add_argument('--result', type=str, help='请输入文件夹名称')
parser.add_argument('--img', type=str, help='请输入文件夹名称')
args = parser.parse_args()
hash_data_path = os.path.join(args.cache, 'hash_data', 'data.csv')
img_pre_path = os.path.join(args.cache, 'img_pre_path')
img_res_path = os.path.join(args.result, 'img_res_path')
time_start = time.time() # 计时开始
print('计时开始,正在预处理请稍后')
judge_pic(args.img, img_pre_path)
auto_detection(img_pre_path, img_res_path, hash_data_path)
time_end = time.time() # 计时结束
# 写入txt文件
time_cost = time_end - time_start
with open(os.path.join(img_res_path, 'time_cost.txt'), 'w', encoding='gbk') as f:
f.write('所在文件夹所有图片匹配总耗时' + str(round(time_cost * 1000, 2)) + '毫秒')
print('已写入', os.path.join(img_res_path, 'time_cost.txt'), '文件,编码集为gbk')
print('匹配完毕,请前往', img_res_path, '文件夹查看')
return
与query_image.py极为相似,多出一步待测视频切帧的步骤:
def cut_video_for_once(file_path, save_path):
"""
视频检索核心与图片检索相同,每个视频切帧一张
即化简为图片检索思路
"""
pass
jupyter或linux请执行:
python build.py --db /data/db --cache /data/cache cache
文件夹会自动创建,事先手动创建也可,下同
jupyter或linux请执行:
python query_image.py --cache /data/cache --result /data/result --img /data/image
jupyter或linux请执行:
python query_video.py --cache /data/cache --result /data/result --img /data/video
此时正在切图
第一行的预处理可能会持续稍久一点,因为此时会判断是否存在黑边和裁切类型,并且尝试除去黑边和补全裁切。
此时正在视频切帧
此时正在匹配
数据集分为:
代码分为:
额外提供一个创建数据库及对应image、video的代码方式,首先随意手动创建一个db,然后填入3186个视频的总数据地址,会根据名称自动建库,代码不在赘述,相信应该较为浅显易懂。
额外提供一个检测准确率的代码方式,指定csv文件地址,调整好正确的列名,即可计算准确率。
代码及数据集链接:喜欢本文的请点赞收藏一波吧
提取码:3b51