SCRN:主要文件(原始代码直接运行需要做一些修改,下面是已经修改好的版本)

train_SCRN.py

import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.autograd import Variable

import numpy as np
import os, argparse
from datetime import datetime

from utils.data import get_loader
from utils.func import label_edge_prediction, AvgMeter

from model.ResNet_models import SCRN

# ---------------------------------------------------------------------------
# Command-line configuration for SCRN training.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=30, help='epoch number')
parser.add_argument('--lr', type=float, default=2e-3, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='batch size')
parser.add_argument('--trainsize', type=int, default=352, help='input size')
parser.add_argument('--trainset', type=str, default='DUTS-TRAIN', help='training  dataset')
# Root directory holding the `--trainset` folder. This used to be hard-coded
# (users had to edit the source); the default keeps the previous value.
parser.add_argument('--data_path', type=str, default='./SalientObject/dataset/',
                    help='root directory that contains the training set folder')
opt = parser.parse_args()

# Data: images are read from <data_path>/<trainset>/DUTS-TR-Image/ and masks
# from <data_path>/<trainset>/DUTS-TR-Mask/. os.path.join tolerates a
# --data_path given with or without a trailing separator; the final '' keeps
# the trailing separator the loader paths had before.
data_path = opt.data_path
image_root = os.path.join(data_path, opt.trainset, 'DUTS-TR-Image', '')
gt_root = os.path.join(data_path, opt.trainset, 'DUTS-TR-Mask', '')
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)  # number of batches per epoch

# Model, optimiser and loss.
model = SCRN()
model.cuda()
params = model.parameters()
optimizer = torch.optim.SGD(params, opt.lr, momentum=0.9, weight_decay=5e-4)
# StepLR multiplies the learning rate by gamma=0.1 every 20 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
# Both the saliency and the edge heads are trained with BCE on raw logits.
CE = torch.nn.BCEWithLogitsLoss()
size_rates = [0.75, 1, 1.25]  # multi-scale training: scale factors of opt.trainsize

# training
for epoch in range(0, opt.epoch):
    model.train()
    # Running averages of the saliency loss (1) and edge loss (2) at scale 1.
    loss_record1, loss_record2 = AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        if i % 100 == 0:
            print('=>epoch:', epoch, ' |all epoch =', opt.epoch, '   ||    =>iter:', i)

        # Move the batch to the GPU and derive the edge labels once per batch;
        # each scale below resamples from these full-resolution tensors
        # instead of re-copying and recomputing them three times.
        # NOTE(review): assumes label_edge_prediction is a pure function of
        # `gts` whose result carries no autograd history - confirm in utils.func.
        images, gts = pack
        images = images.cuda()
        gts = gts.cuda()
        gt_edges = label_edge_prediction(gts)
        batch_size = images.size(0)  # the last batch can be smaller than opt.batchsize

        for rate in size_rates:
            optimizer.zero_grad()

            # multi-scale training samples: side length rounded to a multiple of 32
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                size = (trainsize, trainsize)
                scaled_images = F.interpolate(images, size=size, mode='bilinear', align_corners=True)
                scaled_gts = F.interpolate(gts, size=size, mode='bilinear', align_corners=True)
                scaled_edges = F.interpolate(gt_edges, size=size, mode='bilinear', align_corners=True)
            else:
                scaled_images, scaled_gts, scaled_edges = images, gts, gt_edges

            # forward: the model returns (saliency logits, edge logits)
            pred_sal, pred_edge = model(scaled_images)

            loss1 = CE(pred_sal, scaled_gts)
            loss2 = CE(pred_edge, scaled_edges)
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()

            # Only the native-resolution pass is recorded in the running averages.
            if rate == 1:
                loss_record1.update(loss1.data, batch_size)
                loss_record2.update(loss2.data, batch_size)

        if i % 1000 == 0 or i == total_step:
            # NOTE: the original script also printed loss_record1.show() and
            # loss_record2.show() here; that output is disabled in this version.
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}]'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step))

    # Since PyTorch 1.1 the LR scheduler must be stepped after optimizer.step().
    # Stepping it at the start of the epoch skipped the first LR value, so the
    # decay came one epoch early.
    scheduler.step()

# Save the final weights to ./models/<trainset>_w.pth.
save_path = './models/'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(save_path, exist_ok=True)
torch.save(model.state_dict(), save_path + opt.trainset + '_w.pth')

test_SCRN.py
 

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import os
from scipy import misc
from datetime import datetime

from utils.data import test_dataset
from model.ResNet_models import SCRN

import imageio
import cv2

# Build SCRN and load the trained weights from ./model/model.pth.
# map_location='cpu' lets a checkpoint saved on any device (e.g. a GPU index
# that does not exist on this machine) load before the model is moved to the GPU.
model = SCRN()
model.load_state_dict(torch.load('./model/model.pth', map_location='cpu'))
model.cuda()
model.eval()  # inference behaviour for layers such as BatchNorm/Dropout, if present

# Root of the evaluation datasets - set this to your own location. Each dataset
# folder is read as <data_path>/<dataset>/Imgs/ (inputs) and .../GT/ (masks).
# data_path = '/backup/materials/Dataset/SalientObject/dataset/'
data_path = '/home/nk/zjc/PycharmProjects/nk_dataset/'
# valset = ['ECSSD', 'HKUIS', 'PASCAL', 'DUT-OMRON', 'THUR15K', 'DUTS-TEST']
valset = ['PASCALS', 'HKU-IS', 'SOD', 'THUR15K', 'DUTS-TE']
for dataset in valset:
    # Predicted maps are written to ./saliency_maps/<dataset>/<name>.png
    save_path = './saliency_maps/' + dataset + '/'
    os.makedirs(save_path, exist_ok=True)
    image_root = data_path + dataset + '/Imgs/'
    gt_root = data_path + dataset + '/GT/'
    test_loader = test_dataset(image_root, gt_root, testsize=352)

    with torch.no_grad():
        for i in range(test_loader.size):
            if i % 100 == 0:
                print(dataset, 'is running in num', i, 'image')
            image, gt, name = test_loader.load_data()
            # The ground truth is only needed for its spatial size: the
            # prediction is resized to the GT resolution before saving.
            # [:2] keeps (H, W) even if a mask were ever multi-channel.
            out_size = np.asarray(gt).shape[:2]
            image = image.cuda()

            res, _edge = model(image)  # the edge prediction is not saved

            res = F.interpolate(res, size=out_size, mode='bilinear', align_corners=True)
            res = res.sigmoid().cpu().numpy().squeeze()
            # Quantise [0, 1] probabilities to 8-bit grey levels; round rather
            # than truncate so the saved maps are not biased downwards.
            res = np.round(res * 255).astype(np.uint8)
            cv2.imwrite(save_path + name + '.png', res)



 

 

 

 

 

 

 

 

你可能感兴趣的:(SCRN中科院)