test_SCRN.py
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from utils.data import test_dataset
from model.ResNet_models import SCRN
model = SCRN()
model.load_state_dict(torch.load('./model/model.pth'))
model.cuda()
model.eval()
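# eval() freezes BatchNorm running statistics; torch.no_grad() below disables
# gradient tracking, so inference runs without storing intermediate activations.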
data_path = './dataset/'
# valset = ['ECSSD', 'HKUIS', 'PASCAL', 'DUT-OMRON', 'THUR15K', 'DUTS-TEST']
valset = ['DUT-OMRON']
for dataset in valset:
    save_path = './saliency_maps/' + dataset + '/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    image_root = data_path + dataset + '/images/'
    gt_root = data_path + dataset + '/gts/'
    test_loader = test_dataset(image_root, gt_root, testsize=352)
    with torch.no_grad():
        for i in range(test_loader.size):
            if i % 10 == 0:
                print('running:', i)
            image, gt, name = test_loader.load_data()
            gt = np.array(gt).astype('float')
            gt = gt / (gt.max() + 1e-8)  # normalize GT to [0, 1]
            image = image.cuda()
            res, edge = model(image)
            # resize the logits to the original GT resolution, then squash to [0, 1]
            res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=True)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (255 * res).astype(np.uint8)
            cv2.imwrite(save_path + name + '.png', res)
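# Note (an assumption, not part of the original script): on a CPU-only machine,
# load the checkpoint with an explicit map_location and drop the .cuda() calls:
#   model.load_state_dict(torch.load('./model/model.pth', map_location='cpu'))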
train_SCRN.py
import os
import argparse
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler

from utils.data import get_loader
from utils.func import label_edge_prediction, AvgMeter
from model.ResNet_models import SCRN
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=30, help='epoch number')
parser.add_argument('--lr', type=float, default=2e-3, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='batch size')
parser.add_argument('--trainsize', type=int, default=352, help='input size')
parser.add_argument('--trainset', type=str, default='DUTS-TRAIN', help='training dataset')
opt = parser.parse_args()
# data preparing, set your own data path here
data_path = './SalientObject/dataset/'
image_root = data_path + opt.trainset + '/DUTS-TR-Image/'
gt_root = data_path + opt.trainset + '/DUTS-TR-Mask/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
# build models
model = SCRN()
model.cuda()
params = model.parameters()
optimizer = torch.optim.SGD(params, opt.lr, momentum=0.9, weight_decay=5e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
CE = torch.nn.BCEWithLogitsLoss()
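# BCEWithLogitsLoss fuses the sigmoid into the loss, so the model outputs raw
# logits; this is why the test script applies .sigmoid() explicitly at inference.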
size_rates = [0.75, 1, 1.25] # multi-scale training
# training
for epoch in range(opt.epoch):
    model.train()
    loss_record1, loss_record2 = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            images, gts = pack
            images = images.cuda()
            gts = gts.cuda()
            # derive the edge ground truth from the saliency mask
            gt_edges = label_edge_prediction(gts)
            # multi-scale training: snap the rescaled side to a multiple of 32
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gt_edges = F.interpolate(gt_edges, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
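            # With the default trainsize=352, the three rates give:
            #   0.75 -> int(round(352 * 0.75 / 32) * 32) = 256
            #   1.00 -> 352 (no resize)
            #   1.25 -> int(round(352 * 1.25 / 32) * 32) = 448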
            # forward: joint saliency and edge supervision
            pred_sal, pred_edge = model(images)
            loss1 = CE(pred_sal, gts)
            loss2 = CE(pred_edge, gt_edges)
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            if rate == 1:
                loss_record1.update(loss1.item(), opt.batchsize)
                loss_record2.update(loss2.item(), opt.batchsize)
        if i % 1000 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f}, Loss2: {:.4f}'.format(
                datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(), loss_record2.show()))
    # step the LR schedule once per epoch, after the optimizer updates
    scheduler.step()
save_path = './models/'
if not os.path.exists(save_path):
os.makedirs(save_path)
torch.save(model.state_dict(), save_path + opt.trainset + '_w.pth')
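# Example invocation (a sketch; the flags mirror the argparse defaults above):
#   python train_SCRN.py --epoch 30 --lr 2e-3 --batchsize 8 --trainsize 352 --trainset DUTS-TRAIN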
ResNet_models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from .ResNet import ResNet50
# The building blocks below are shared by the saliency and edge branches.
class BasicConv2d(nn.Module):
    # conv + BN (no activation), the basic unit used throughout the network
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_planes, out_planes,
                      kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_planes)
        )

    def forward(self, x):
        return self.conv_bn(x)
class Reduction(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Reduction, self).__init__()
        self.reduce = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, 3, padding=1),
            BasicConv2d(out_channel, out_channel, 3, padding=1)
        )

    def forward(self, x):
        return self.reduce(x)
class conv_upsample(nn.Module):
    # 1x1 conv applied after bilinearly upsampling x to the target's spatial size
    def __init__(self, channel):
        super(conv_upsample, self).__init__()
        self.conv = BasicConv2d(channel, channel, 1)

    def forward(self, x, target):
        if x.size()[2:] != target.size()[2:]:
            x = self.conv(F.interpolate(x, size=target.size()[2:], mode='bilinear', align_corners=True))
        return x
class DenseFusion(nn.Module):
    # Cross Refinement Unit (CRU): exchanges information between saliency
    # features (x_s1..x_s4) and edge features (x_e1..x_e4) at four scales
    def __init__(self, channel):
        super(DenseFusion, self).__init__()
        self.conv1 = conv_upsample(channel)
        self.conv2 = conv_upsample(channel)
        self.conv3 = conv_upsample(channel)
        self.conv4 = conv_upsample(channel)
        self.conv5 = conv_upsample(channel)
        self.conv6 = conv_upsample(channel)
        self.conv7 = conv_upsample(channel)
        self.conv8 = conv_upsample(channel)
        self.conv9 = conv_upsample(channel)
        self.conv10 = conv_upsample(channel)
        self.conv11 = conv_upsample(channel)
        self.conv12 = conv_upsample(channel)
        self.conv_f1 = nn.Sequential(
            BasicConv2d(5 * channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f2 = nn.Sequential(
            BasicConv2d(4 * channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f3 = nn.Sequential(
            BasicConv2d(3 * channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f4 = nn.Sequential(
            BasicConv2d(2 * channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f5 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f6 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f7 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.conv_f8 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
    def forward(self, x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4):
        # saliency refinement: f_i = S_i + Conv(Cat(S_i, E_i, CU(E_j) for j > i))
        x_sf1 = x_s1 + self.conv_f1(torch.cat((x_s1, x_e1,
                                               self.conv1(x_e2, x_s1),
                                               self.conv2(x_e3, x_s1),
                                               self.conv3(x_e4, x_s1)), 1))
        x_sf2 = x_s2 + self.conv_f2(torch.cat((x_s2, x_e2,
                                               self.conv4(x_e3, x_s2),
                                               self.conv5(x_e4, x_s2)), 1))
        x_sf3 = x_s3 + self.conv_f3(torch.cat((x_s3, x_e3,
                                               self.conv6(x_e4, x_s3)), 1))
        x_sf4 = x_s4 + self.conv_f4(torch.cat((x_s4, x_e4), 1))
        # edge refinement: g_i = E_i + Conv(E_i * S_i * prod(CU(S_j) for j > i))
        x_ef1 = x_e1 + self.conv_f5(x_e1 * x_s1 *
                                    self.conv7(x_s2, x_e1) *
                                    self.conv8(x_s3, x_e1) *
                                    self.conv9(x_s4, x_e1))
        x_ef2 = x_e2 + self.conv_f6(x_e2 * x_s2 *
                                    self.conv10(x_s3, x_e2) *
                                    self.conv11(x_s4, x_e2))
        x_ef3 = x_e3 + self.conv_f7(x_e3 * x_s3 *
                                    self.conv12(x_s4, x_e3))
        x_ef4 = x_e4 + self.conv_f8(x_e4 * x_s4)
        return x_sf1, x_sf2, x_sf3, x_sf4, x_ef1, x_ef2, x_ef3, x_ef4
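# Shape sketch (an assumption based on a 352x352 input and ResNet-50 strides
# 4/8/16/32): x_*1 is B x C x 88 x 88, x_*2 is B x C x 44 x 44, x_*3 is
# B x C x 22 x 22, x_*4 is B x C x 11 x 11. DenseFusion preserves every shape;
# it only mixes information between the two branches via conv_upsample (CU).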
class ConcatOutput(nn.Module):
    def __init__(self, channel):
        super(ConcatOutput, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_cat1 = nn.Sequential(
            BasicConv2d(2 * channel, 2 * channel, 3, padding=1),
            BasicConv2d(2 * channel, channel, 1)
        )
        self.conv_cat2 = nn.Sequential(
            BasicConv2d(2 * channel, 2 * channel, 3, padding=1),
            BasicConv2d(2 * channel, channel, 1)
        )
        self.conv_cat3 = nn.Sequential(
            BasicConv2d(2 * channel, 2 * channel, 3, padding=1),
            BasicConv2d(2 * channel, channel, 1)
        )
        self.output = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )

    def forward(self, x1, x2, x3, x4):
        x3 = torch.cat((x3, self.conv_upsample1(self.upsample(x4))), 1)
        x3 = self.conv_cat1(x3)
        x2 = torch.cat((x2, self.conv_upsample2(self.upsample(x3))), 1)
        x2 = self.conv_cat2(x2)
        x1 = torch.cat((x1, self.conv_upsample3(self.upsample(x2))), 1)
        x1 = self.conv_cat3(x1)
        return self.output(x1)
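# Aggregation runs coarse-to-fine: x4 (1/32 resolution) is upsampled into x3,
# then into x2 and x1, and a single-channel map is predicted at 1/4 resolution;
# SCRN.forward upsamples it to the input size.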
class SCRN(nn.Module):
    # Stacked Cross Refinement Network
    def __init__(self, channel=32):
        super(SCRN, self).__init__()
        self.resnet = ResNet50()
        # saliency features
        self.reduce_s1 = Reduction(256, channel)
        self.reduce_s2 = Reduction(512, channel)
        self.reduce_s3 = Reduction(1024, channel)
        self.reduce_s4 = Reduction(2048, channel)
        # edge features
        self.reduce_e1 = Reduction(256, channel)
        self.reduce_e2 = Reduction(512, channel)
        self.reduce_e3 = Reduction(1024, channel)
        self.reduce_e4 = Reduction(2048, channel)
        # four stacked CRUs
        self.df1 = DenseFusion(channel)
        self.df2 = DenseFusion(channel)
        self.df3 = DenseFusion(channel)
        self.df4 = DenseFusion(channel)
        self.output_s = ConcatOutput(channel)
        self.output_e = ConcatOutput(channel)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(std=0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        # the random init above is then overwritten for the backbone
        # by ImageNet-pretrained weights
        self.initialize_weights()
    def forward(self, x):
        size = x.size()[2:]
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)   # 256 channels, stride 4
        x2 = self.resnet.layer2(x1)  # 512 channels, stride 8
        x3 = self.resnet.layer3(x2)  # 1024 channels, stride 16
        x4 = self.resnet.layer4(x3)  # 2048 channels, stride 32
        # channel reduction for both branches
        x_s1 = self.reduce_s1(x1)
        x_s2 = self.reduce_s2(x2)
        x_s3 = self.reduce_s3(x3)
        x_s4 = self.reduce_s4(x4)
        x_e1 = self.reduce_e1(x1)
        x_e2 = self.reduce_e2(x2)
        x_e3 = self.reduce_e3(x3)
        x_e4 = self.reduce_e4(x4)
        # four stacked cross refinement units
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df1(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df2(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df3(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4 = self.df4(x_s1, x_s2, x_s3, x_s4, x_e1, x_e2, x_e3, x_e4)
        # U-Net style aggregation, then upsample to the input resolution
        pred_s = self.output_s(x_s1, x_s2, x_s3, x_s4)
        pred_e = self.output_e(x_e1, x_e2, x_e3, x_e4)
        pred_s = F.interpolate(pred_s, size=size, mode='bilinear', align_corners=True)
        pred_e = F.interpolate(pred_e, size=size, mode='bilinear', align_corners=True)
        return pred_s, pred_e
    def initialize_weights(self):
        # load ImageNet-pretrained ResNet-50 weights into the backbone;
        # strict=False skips keys the trimmed ResNet50 does not define
        res50 = models.resnet50(pretrained=True)
        self.resnet.load_state_dict(res50.state_dict(), strict=False)