点击上方“AI搞事情”关注我们
❝论文:U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection
❞
GIT:https://github.com/NathanUA/U-2-Net
U2Net用于显著目标检测(Salient Object Detection, SOD) ,目的是分割出图像中最具吸引力的目标。不同于图像识别,SOD更注重局部细节信息和全局对比信息,而不是深层语义信息,因此,主要的研究方向在于多层次与多尺度特征提取上。
U2Net网络结构如下图,整体是一个编码-解码(Encoder-Decoder)结构的U-Net,其中,每个stage由新提出的RSU模块(residual U-block) 组成,即一个两层嵌套的U结构网络。
「其优势在于:」
1.RSU模块,融合了不同尺度感受野的特征,能够捕获更多不同尺度的上下文信息(contextual information)。
2.RSU模块的池化(pooling)操作,可以在不显著增加计算成本的情况下,加深网络结构的深度。
RSU,ReSidual Ublock, 用于捕获intra-stage的多尺度特征. 其结构如图(e)所示:
(a)-(c)显示了具有最小感受野的现有卷积块,但是1x1或者3x3的卷积核的感受野太小而无法捕捉全局信息,(d)通过利用空洞卷积增大感受野来获取全局信息,然而在前期大分辨率的输入特征图计算需要耗费大量的计算和内存资源。
残差模块与RSU模块的对比:主要设计区别在于,RSU用U-Net代替了普通的单流卷积,并用一个权重层构成的局部特征代替了原始特征:
U2Net训练损失函数定义:
其中,M=6, 为U2Net 的 Sup1, Sup2, ..., Sup6 stage, 为对应输出的显著图(saliency map) 的损失函数; 为最终融合输出的显著图 的损失函数, 为每个损失函数的权重。
对于每一项 ,使用标准二进制交叉熵来计算损失:
其中,(r,c)为像素坐标;(H, W) 为图像尺寸,height 和 width。 和 分别表示 GT 像素值和预测的显著概率图(saliency probability map)。
作者开源了代码,最近还公开了一些有趣的基于U2Net的应用,比如人像转素描,抠图、背景去除等。
我们可以根据说明进行一把尝试:
下载源码git clone https://github.com/NathanUA/U-2-Net.git
下载转素描模型:u2net_portrait.pth
放入到./saved_models/u2net_portrait/
下面。
执行脚本python u2net_portrait_test.py
程序会读取U-2-Net/test_data/test_portrait_images/portrait_im
路径下的照片进行转换,并把结果输出在U-2-Net/test_data/test_portrait_images/portrait_results
路径下。
若在CPU环境运行会提示torch.load使用参数map_location='cpu'
即:net.load_state_dict(torch.load(model_dir, map_location='cpu'))
项目也提供了任意人脸图像转换的demo,区别在于增加了opencv的人脸检测,以及裁剪到输入的512x512大小,可以通过python u2net_portrait_demo.py
执行,
图片放入路径./test_data/test_portrait_images/your_portrait_im/
结果在路径:./test_data/test_portrait_images/your_portrait_results/
通过U2Net,可以得到精细的前景alpha图像,通过简单的mask操作就可以将前景目标扣取出来。
# encoding=utf-8
import os
import cv2
import numpy as np
im1_path = '1/test.png' # 原图
im2_path = '2/test_alpha.png' # alpha图
img1 = cv2.imread(im1_path)
img2 = cv2.imread(im2_path, cv2.IMREAD_GRAYSCALE)
h, w, c = img1.shape
img3 = np.zeros((h, w, 4))
img3[:, :, 0:3] = img1
img3[:, :, 3] = img2
cv2.imwrite('res.png', img3)
有大佬将其做成了一个工具:www.remove.bg
(50次免费试用),以及还有一个python库
参考:
1. Github 项目 - U2Net 网络及实现
2. U2Net论文解读及代码测试
原文链接下载模型:0rl5
往期推荐
行千里,看山城轻轨穿楼越林;致广大,望重庆交通桥上桥下
CRNN:端到端不定长文字识别算法
DeepHSV:号称可以商用的计算机笔迹鉴别算法
python获取原图GPS位置信息,轻松得到你的活动轨迹
机器学习中有哪些距离度量方式
长按二维码关注我们
有趣的灵魂在等你