pytorch官方代码:https://github.com/lcy0604/EraseNet
论文:2010.EraseNet: End-to-End Text Removal in the Wild 网盘提取码:0719
第一列原图带文字、第二列为去除后的标签,剩下的列都是不同的算法去除效果 (pix2pix, scennetextEraser ,EnsNet, 本文EraseNet)
合成的
数据集文字图片去除效果比较模型设计了一个两阶段的从粗到细的(h a two-stage ·coarse-to-refine generator network)生成器
网络和一个局部全局鉴别器
网络(a local-global discriminator network.)。(本文中作者改进了SN-GAN,并提出名为 local-global SN-Patch-GAN
的架构
一个额外的语义分割网络
头与整个算法一体的,用于感知(perceive)文字区域。
同时,借助外部预训练好的VGG-16
网络抽取特征,用来监督生成的去除文字的图片(fake samples)与标签图片(ground-truths)的高级语义的差异(discrepancies of high-level semantics.)
图8 判别器架构
单个NVIDIA 2080TI GPU
, batch size =4
用
SCUT-EnsText : 华南理工大学提出与搜集见抬头代码库
2016年提出的 Synthetic data for text localisation in natural images 用来合成数据集
自己数据集
实验结果# -*- coding: utf-8 -*-
# @Time : 2023/7/6 20:36
# @Author : XyZeng
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from data.dataloader import ErasingData,ImageTransform
from models.sa_gan import STRnet2
parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
help='path for saving models')
parser.add_argument('--logPath', type=str,
default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
help='image loading size')
parser.add_argument('--dataRoot', type=str,
default='./')
parser.add_argument('--pretrained',type=str, default='./model.pth', help='pretrained models for finetuning')
parser.add_argument('--savePath', type=str, default='./output')
args = parser.parse_args()
cuda = torch.cuda.is_available()
if cuda:
print('Cuda is available!')
cudnn.benchmark = True
def visual(image):
im =(image).transpose(1,2).transpose(2,3).detach().cpu().numpy()
Image.fromarray(im[0].astype(np.uint8)).show()
batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
dataRoot = args.dataRoot
savePath = args.savePath
import torch.nn.functional as F
os.makedirs(savePath,exist_ok=True)
netG = STRnet2(3)
netG.load_state_dict(torch.load(args.pretrained))
if cuda:
netG = netG.cuda()
for param in netG.parameters():
param.requires_grad = False
print('OK!')
import time
start = time.time()
netG.eval()
ImgTrans=ImageTransform(args.loadSize)
def get_img_tensor(path):
img = Image.open(path)
Image.Resampling.BICUBIC (3), Image.Resampling.BOX (4) o
img=img.convert('RGB').resize((args.loadSize,args.loadSize) ,2)
inputImage = ImgTrans(img).unsqueeze(0)
# mask = ImgTrans(mask.convert('RGB'))
# inputImage = F.interpolate(inputImage, size=(512,512), mode='bilinear') # Adjust size to 115
print('inputImage',inputImage.size())
return inputImage
if __name__ == '__main__':
inpur_dir=r'example\all_images' # 改为'./你需要转换的图片目录'
for name in os.listdir(inpur_dir):
path=os.path.join(inpur_dir,name)
imgs=get_img_tensor(path)
if cuda:
imgs = imgs.cuda()
# masks = masks.cuda()
'''
看论文喝源码能发现5个输出的对应
'''
out1, out2, out3, g_images,mm = netG(imgs)
g_image = g_images.data.cpu()
mm = mm.data.cpu()
# save_image(g_image_with_mask, result_with_mask+path[0])
dir,name=os.path.split(path)
out_path=os.path.join(savePath,name)
mask_path= os.path.join(savePath,name+'_mask.png')
save_image(g_image, out_path)
save_image(mm,mask_path)
print(out_path,mask_path)
# break