# train_SCRN.py
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import os, argparse
from datetime import datetime
from utils.data import get_loader
from utils.func import label_edge_prediction, AvgMeter
from model.ResNet_models import SCRN
# ---------------------------------------------------------------
# Training script for SCRN saliency detection.
# Trains with multi-scale inputs and a joint saliency + edge loss,
# then saves the final weights under ./models/.
# ---------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=30, help='epoch number')
parser.add_argument('--lr', type=float, default=2e-3, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='batch size')
parser.add_argument('--trainsize', type=int, default=352, help='input size')
parser.add_argument('--trainset', type=str, default='DUTS-TRAIN', help='training dataset')
opt = parser.parse_args()

# data preparing, set your own data path here
data_path = './SalientObject/dataset/'
image_root = data_path + opt.trainset + '/DUTS-TR-Image/'
gt_root = data_path + opt.trainset + '/DUTS-TR-Mask/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)

# build models
model = SCRN()
model.cuda()
optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, weight_decay=5e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
CE = torch.nn.BCEWithLogitsLoss()
size_rates = [0.75, 1, 1.25]  # multi-scale training

# training
for epoch in range(opt.epoch):
    model.train()
    loss_record1, loss_record2 = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        if i % 100 == 0:
            # use the configured epoch count instead of the hard-coded "30"
            print('=>epoch: [%2d/%2d], iter: [%5d/%5d]' % (epoch, opt.epoch, i, total_step))
        for rate in size_rates:
            optimizer.zero_grad()
            images, gts = pack
            # Variable wrapping is a no-op since PyTorch 0.4; .cuda() suffices
            images = images.cuda()
            gts = gts.cuda()
            # edge supervision derived from the ground-truth mask
            gt_edges = label_edge_prediction(gts)
            # multi-scale training: snap the rescaled size to a multiple of 32
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                # F.interpolate replaces the deprecated F.upsample
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gt_edges = F.interpolate(gt_edges, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # forward: joint saliency and edge prediction
            pred_sal, pred_edge = model(images)
            loss1 = CE(pred_sal, gts)
            loss2 = CE(pred_edge, gt_edges)
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            # record losses only at the native scale for comparability
            if rate == 1:
                loss_record1.update(loss1.item(), opt.batchsize)
                loss_record2.update(loss2.item(), opt.batchsize)
        if i % 1000 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f}, Loss2: {:.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(), loss_record2.show()))
    # step the LR schedule AFTER the epoch's optimizer updates
    # (required ordering since PyTorch 1.1; stepping first skips the initial LR)
    scheduler.step()

save_path = './models/'
if not os.path.exists(save_path):
    os.makedirs(save_path)
torch.save(model.state_dict(), save_path + opt.trainset + '_w.pth')
# test_SCRN.py
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import os
from scipy import misc
from datetime import datetime
from utils.data import test_dataset
from model.ResNet_models import SCRN
import imageio
import cv2
# ---------------------------------------------------------------
# Inference script for SCRN: runs the trained model on several
# benchmark datasets and writes per-image saliency maps as PNGs.
# ---------------------------------------------------------------
model = SCRN()
model.load_state_dict(torch.load('./model/model.pth'))
model.cuda()
model.eval()

# data_path = '/backup/materials/Dataset/SalientObject/dataset/'
data_path = '/home/nk/zjc/PycharmProjects/nk_dataset/'
# valset = ['ECSSD', 'HKUIS', 'PASCAL', 'DUT-OMRON', 'THUR15K', 'DUTS-TEST']
valset = ['PASCALS', 'HKU-IS', 'SOD', 'THUR15K', 'DUTS-TE']

for dataset in valset:
    save_path = './saliency_maps/' + dataset + '/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    image_root = data_path + dataset + '/Imgs/'
    gt_root = data_path + dataset + '/GT/'
    test_loader = test_dataset(image_root, gt_root, testsize=352)
    with torch.no_grad():
        for i in range(test_loader.size):
            if i % 100 == 0:
                print(dataset, 'is running in num', i, 'image')
            image, gt, name = test_loader.load_data()
            # the ground truth is only used here for its spatial size
            gt = np.array(gt).astype('float')
            gt = gt / (gt.max() + 1e-8)
            # Variable wrapping is a no-op since PyTorch 0.4; .cuda() suffices
            image = image.cuda()
            res, edge = model(image)
            # resize the prediction to ground-truth resolution
            # (F.interpolate replaces the deprecated F.upsample)
            res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=True)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            # map [0, 1] probabilities to an 8-bit grayscale image
            res = (255 * res).astype(np.uint8)
            cv2.imwrite(save_path + name + '.png', res)