run.py
# =>1 get_loader: build the training DataLoader.
# get_loader must be *called*; unpacking the bare function object raises TypeError.
# NOTE(review): config.batch_size / config.num_thread are assumed attribute names --
# confirm against the argument parser.
train_loader, dataset = get_loader(config.batch_size, mode='train', num_thread=config.num_thread)
# =>2 config.save_fold: output directory of this run; in train mode Solver opens
# <save_fold>/logs/log.txt for writing.
config.save_fold = './EGNet/run-nnet'
# =>3 train = Solver(): training loader, no test loader (None), run configuration.
train = Solver(train_loader, None, config)
# =>4 train.train(): run the training loop.
train.train()
首先分析
# =>1 get_loader
from dataset import get_loader
dataset.py
def get_loader(batch_size, mode='train', num_thread=1, test_mode=0, sal_mode='e'):
    """Build a DataLoader over the saliency training set.

    Args:
        batch_size: number of samples per batch.
        mode: only 'train' is handled by this code; training data is shuffled.
        num_thread: number of DataLoader worker processes.
        test_mode, sal_mode: options for a test-set dataset that is not part of
            this code; accepted so existing call sites keep working.

    Returns:
        (data_loader, dataset) tuple.

    Raises:
        ValueError: if ``mode`` is not 'train' (previously ``dataset`` was left
            unbound and an UnboundLocalError was raised instead).
    """
    # Original note: evaluation calls look like get_loader(batch=1, 'test', 4, 1, 'e');
    # presumably the full source builds a test-set dataset for that case.
    if mode != 'train':
        raise ValueError(
            "get_loader: unsupported mode %r; only 'train' is implemented here" % (mode,))
    shuffle = True  # DataLoader reshuffles the training samples every epoch
    dataset = ImageDataTrain()
    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_thread)
    return data_loader, dataset
分析
# =>1 get_loader 中的 ImageDataTrain():读入三类图片——原图(.jpg)、显著性标注(.png)和边缘图(_edge.png)。
具体如下:
sal_root:训练数据集所在目录;
sal_source:.lst 索引文件的路径;
逐行读取 .lst 文件(每行对应一个样本),
并统计样本总数 sal_num。
class ImageDataTrain(data.Dataset):
    """Index of training samples, each an (image, saliency mask, edge map) triple.

    Each non-blank line of the list file holds three whitespace-separated
    paths, relative to ``sal_root``: image, saliency mask and edge map.
    """

    def __init__(self, sal_root='./DUTS-TR', sal_source='./DUTS-TR/train_pair_edge.lst'):
        """Read the list file.

        Args:
            sal_root: directory the paths in the list file are relative to.
            sal_source: path of the list file.

        The defaults reproduce the paths that used to be hard-coded here.
        """
        self.sal_root = sal_root
        self.sal_source = sal_source
        with open(self.sal_source, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line): they would otherwise be
            # counted in sal_num and make ``split()[0]`` fail when that index is read.
            self.sal_list = [line.strip() for line in f if line.strip()]
        self.sal_num = len(self.sal_list)  # number of training samples

    def __len__(self):
        """Number of samples listed in the list file."""
        return self.sal_num
# Excerpt of ImageDataTrain.__init__: the dataset root and its pair-list file,
# both given relative to the current working directory.
self.sal_root = './DUTS-TR'
self.sal_source = './DUTS-TR/train_pair_edge.lst'
train_pair_edge.lst 文件中每一行包含三个以空格分隔的相对路径(相对于 sal_root)。下面为便于说明,把同一行的三个路径分成三行列出:
DUTS-TR-Image/ILSVRC2012_test_00000018.jpg       -> self.sal_list[item].split()[0](原图)
DUTS-TR-Mask/ILSVRC2012_test_00000018.png        -> self.sal_list[item].split()[1](显著性标注)
DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png   -> self.sal_list[item].split()[2](边缘图)
下面是完整路径的拼接方式:os.path.join(self.sal_root, 相对路径),
即 './DUTS-TR/' + self.sal_list[item] 中对应的那一列:
# Excerpt of __getitem__: each .lst line holds three paths relative to
# self.sal_root; `item % self.sal_num` wraps the index back into the list.
# Column 0 -> input image, e.g. './DUTS-TR/DUTS-TR-Image/ILSVRC2012_test_00000018.jpg'
sal_image = load_image(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[0]))
# Column 1 -> saliency mask via load_sal_label, e.g. './DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018.png'
sal_label = load_sal_label(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[1]))
# Column 2 -> edge map via load_edge_label, e.g. './DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png'
sal_edge = load_edge_label(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[2]))
cv_random_flip:把原图、标注和边缘图一起传入,做随机翻转(random flip);
然后把三者转换为张量(torch.Tensor):
# Convert the loaded image, mask and edge map to tensors; torch.Tensor(...)
# produces the default floating-point dtype (float32 unless that default is changed).
sal_image = torch.Tensor(sal_image)
sal_label = torch.Tensor(sal_label)
sal_edge = torch.Tensor(sal_edge)
def __getitem__(self, item):
    """Load one training sample.

    Args:
        item: sample index, taken modulo ``self.sal_num`` so out-of-range
            indices wrap back into the list.

    Returns:
        dict with tensors under 'sal_image', 'sal_label' and 'sal_edge'.
        (This previously returned ``self.sal_num`` by mistake.)
    """
    # One .lst line: "<image.jpg> <mask.png> <mask_edge.png>", relative to sal_root, e.g.
    #   DUTS-TR-Image/ILSVRC2012_test_00000018.jpg
    #   DUTS-TR-Mask/ILSVRC2012_test_00000018.png
    #   DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png
    fields = self.sal_list[item % self.sal_num].split()
    sal_image = load_image(os.path.join(self.sal_root, fields[0]))
    sal_label = load_sal_label(os.path.join(self.sal_root, fields[1]))
    sal_edge = load_edge_label(os.path.join(self.sal_root, fields[2]))
    # Augmentation: all three inputs go through cv_random_flip together
    # (presumably so image and labels stay aligned -- confirm in cv_random_flip).
    sal_image, sal_label, sal_edge = cv_random_flip(sal_image, sal_label, sal_edge)
    sal_image = torch.Tensor(sal_image)
    sal_label = torch.Tensor(sal_label)
    sal_edge = torch.Tensor(sal_edge)
    sample = {'sal_image': sal_image, 'sal_label': sal_label, 'sal_edge': sal_edge}
    return sample


def __len__(self):
    """Number of samples in the list file."""
    return self.sal_num
# =>3 train = Solver(): construct the trainer with the training loader,
# no test loader (None) and the run configuration.
train = Solver(train_loader, None, config)
=> solve.py文件
class Solver(object):
    """Training driver: builds the network and, in train mode, opens the run log."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """
        Args:
            train_loader: DataLoader used for training.
            test_loader: evaluation loader; None when training.
            config: run configuration; ``mode`` and ``save_fold`` are read here.
            save_fold: stored as-is; the log location comes from ``config.save_fold``.
        """
        import os  # local import keeps this block self-contained

        self.train_loader = train_loader
        # Stored so these arguments are not silently discarded; __init__ itself does not use them.
        self.test_loader = test_loader
        self.save_fold = save_fold
        self.config = config
        # Per-channel mean [123.68, 116.779, 103.939] rescaled to [0, 1] and shaped
        # (3, 1, 1) so it broadcasts against channel-first images. These are the
        # commonly used ImageNet RGB means -- presumably inputs are RGB in [0, 1]; confirm.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        self.build_model()
        if config.mode == 'train':
            # open(..., 'w') fails when <save_fold>/logs does not exist yet, so create it first.
            os.makedirs("%s/logs" % config.save_fold, exist_ok=True)
            # Kept open for the solver's lifetime; nothing in this block closes it.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
# Excerpt from Solver.__init__: this dispatches to the Solver.build_model method
# (not to model.build_model), which is examined next.
self.build_model()
调用的是solve.py文件中的def build_model()函数,如下:
# build the network
def build_model(self):
    """Create the backbone network, load its initial weights and build the optimizer.

    Reads ``self.config`` (mode, load_bone, resnet) and the module-level
    ``base_model_cfg`` and hyper-parameter dict ``p``.
    """
    # This is the module-level model.build_model (``from model import build_model``),
    # not this method; base_model_cfg is 'resnet' in these notes.
    self.net_bone = build_model(base_model_cfg)
    # eval() switches BatchNorm/Dropout to inference behaviour (BatchNorm uses its
    # running statistics, like Caffe's use_global_stats=True).
    # NOTE(review): confirm train() does not switch the net back if frozen BN is intended.
    self.net_bone.eval()
    self.net_bone.apply(weights_init)
    if self.config.mode == 'train':
        if self.config.load_bone == '':  # default: start from the pretrained backbone
            if base_model_cfg == 'resnet':
                # map_location keeps loading working on CPU-only machines (TODO: CUDA).
                self.net_bone.base.load_state_dict(
                    torch.load(self.config.resnet, map_location=torch.device('cpu')))
        else:
            # load_bone presumably holds a whole-network checkpoint path; previously a
            # non-empty value was silently ignored.
            self.net_bone.load_state_dict(
                torch.load(self.config.load_bone, map_location=torch.device('cpu')))
    # Hyper-parameters from the module-level dict p, e.g.:
    #   p['lr_bone'] = 5e-5     # learning rate (resnet: 5e-5, vgg: 2e-5)
    #   p['lr_branch'] = 0.025  # learning rate
    #   p['wd'] = 0.0005        # weight decay
    #   p['momentum'] = 0.90    # momentum
    self.lr_bone = p['lr_bone']
    self.lr_branch = p['lr_branch']
    # Only parameters that require gradients are optimized (the lambda's argument is
    # named `param` so it no longer shadows the dict p).
    trainable = filter(lambda param: param.requires_grad, self.net_bone.parameters())
    self.optimizer_bone = Adam(trainable, lr=self.lr_bone, weight_decay=p['wd'])
    self.print_network(self.net_bone, 'trueUnify bone part')
build_model() 方法里面的这一行:
def build_model(self):
    self.net_bone = build_model(base_model_cfg)
其中调用的 build_model(base_model_cfg) 并不是该方法本身,而是从 model.py 导入的同名函数:
from model import build_model
=> model.py文件
# build the whole network
# Builds TUN_bone around the chosen backbone: extra_layer() returns
# (backbone, merge1_layers, merge2_layers), which `*` unpacks into TUN_bone's
# base, merge1_layers and merge2_layers arguments.
def build_model(base_model_cfg='vgg'):
# NOTE(review): excerpt -- the `if` branch this `elif` belongs to (presumably
# `if base_model_cfg == 'vgg':`, the default) is elided, so it cannot run as written.
elif base_model_cfg == 'resnet':
return TUN_bone(base_model_cfg, *extra_layer(base_model_cfg, resnet50()))
->1
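TUN_bone(base_model_cfg, *extra_layer(base_model_cfg, resnet50())) 是关键函数。
其中 *extra_layer(base_model_cfg, resnet50()) 的含义是:调用 extra_layer(),并用 * 把它的返回值——三元组 (vgg, merge1_layers, merge2_layers)——解包,
依次作为 TUN_bone 的 base、merge1_layers、merge2_layers 参数传入(这里的 vgg 实际上就是传入的 resnet50() 主干网络)。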
->2
class TUN_bone(nn.Module):
    """Backbone plus two merge stages producing edge, saliency and fused saliency outputs."""

    def __init__(self, base_model_cfg, base, merge1_layers, merge2_layers):
        """
        Args:
            base_model_cfg: 'vgg' or 'resnet'.
            base: backbone module; its output feeds merge1 (through ConvertLayer for 'resnet').
            merge1_layers: module returning (up_edge, edge_feature, up_sal, sal_feature).
            merge2_layers: module combining edge_feature and sal_feature into the final map.

        Raises:
            ValueError: for any other base_model_cfg (previously the sub-modules were
                never assigned and forward() failed later).
        """
        super(TUN_bone, self).__init__()  # initialise nn.Module bookkeeping
        if base_model_cfg not in ('vgg', 'resnet'):
            raise ValueError("unsupported base_model_cfg: %r" % (base_model_cfg,))
        self.base_model_cfg = base_model_cfg
        if base_model_cfg == 'resnet':
            # Configured by config_resnet['convert']; presumably adapts the ResNet
            # stage channels before merging -- see ConvertLayer.
            self.convert = ConvertLayer(config_resnet['convert'])
        self.base = base
        self.merge1 = merge1_layers
        self.merge2 = merge2_layers

    def forward(self, x):
        """Return (up_edge, up_sal, up_sal_final) for the input batch ``x``."""
        # Trailing dims after (N, C), presumably (H, W); passed on to the merge layers.
        x_size = x.size()[2:]
        conv2merge = self.base(x)
        if self.base_model_cfg == 'resnet':
            conv2merge = self.convert(conv2merge)
        up_edge, edge_feature, up_sal, sal_feature = self.merge1(conv2merge, x_size)
        up_sal_final = self.merge2(edge_feature, sal_feature, x_size)
        return up_edge, up_sal, up_sal_final
# TUN network
class TUN_bone(nn.Module):
    """Backbone plus two merge stages producing edge, saliency and fused saliency outputs."""

    def __init__(self, base_model_cfg, base, merge1_layers, merge2_layers):
        """
        Args:
            base_model_cfg: 'vgg' or 'resnet'.
            base: backbone module; its output feeds merge1 (through ConvertLayer for 'resnet').
            merge1_layers: module returning (up_edge, edge_feature, up_sal, sal_feature).
            merge2_layers: module combining edge_feature and sal_feature into the final map.

        Raises:
            ValueError: for any other base_model_cfg (previously the sub-modules were
                never assigned and forward() failed later).
        """
        super(TUN_bone, self).__init__()  # initialise nn.Module bookkeeping
        if base_model_cfg not in ('vgg', 'resnet'):
            raise ValueError("unsupported base_model_cfg: %r" % (base_model_cfg,))
        self.base_model_cfg = base_model_cfg
        if base_model_cfg == 'resnet':
            # Configured by config_resnet['convert']; presumably adapts the ResNet
            # stage channels before merging -- see ConvertLayer.
            self.convert = ConvertLayer(config_resnet['convert'])
        self.base = base
        self.merge1 = merge1_layers
        self.merge2 = merge2_layers

    def forward(self, x):
        """Return (up_edge, up_sal, up_sal_final) for the input batch ``x``."""
        # Trailing dims after (N, C), presumably (H, W); passed on to the merge layers.
        x_size = x.size()[2:]
        conv2merge = self.base(x)
        if self.base_model_cfg == 'resnet':
            conv2merge = self.convert(conv2merge)
        up_edge, edge_feature, up_sal, sal_feature = self.merge1(conv2merge, x_size)
        up_sal_final = self.merge2(edge_feature, sal_feature, x_size)
        return up_edge, up_sal, up_sal_final
->3
# extra part
def extra_layer(base_model_cfg, vgg):
    """Build the two merge stages for a backbone.

    Args:
        base_model_cfg: 'vgg' or 'resnet'; selects config_vgg or config_resnet.
        vgg: the backbone module, returned unchanged (despite the name,
            build_model also passes resnet50() here).

    Returns:
        (vgg, merge1_layers, merge2_layers), in the order TUN_bone's
        constructor expects after base_model_cfg.

    Raises:
        ValueError: for any other base_model_cfg (previously an
            UnboundLocalError on ``config``).
    """
    if base_model_cfg == 'vgg':
        config = config_vgg
    elif base_model_cfg == 'resnet':
        config = config_resnet
    else:
        raise ValueError("unsupported base_model_cfg: %r" % (base_model_cfg,))
    merge1_layers = MergeLayer1(config['merge1'])
    merge2_layers = MergeLayer2(config['merge2'])
    return vgg, merge1_layers, merge2_layers
# Layer settings for the VGG backbone.
# 'convert' is not read for VGG by the code shown: TUN_bone only builds a
# ConvertLayer for 'resnet'.
# 'merge1' is passed to MergeLayer1 and 'merge2' to MergeLayer2 (via extra_layer);
# the meaning of each number is defined by those layers.
config_vgg = {
    'convert': [
        [128, 256, 512, 512, 512],
        [64, 128, 256, 512, 512],
    ],
    'merge1': [
        [128, 256, 128, 3, 1],
        [256, 512, 256, 3, 1],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 7, 3],
    ],
    'merge2': [
        [128],
        [256, 512, 512, 512],
    ],
}
# no convert layer, no conv6
# NOTE: the remark above describes config_vgg (TUN_bone only builds a ConvertLayer
# for 'resnet'). It had been fused onto the next line, which commented out the
# config_resnet assignment.
# Layer settings for the ResNet backbone. 'convert' configures TUN_bone's
# ConvertLayer; 'merge1'/'merge2' go to MergeLayer1/MergeLayer2 via extra_layer.
# The remaining keys are not read by the code shown in these notes.
config_resnet = {
    'convert': [[64, 256, 512, 1024, 2048], [128, 256, 512, 512, 512]],
    'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]],
    'score': 256,
    'edgeinfo': [[16, 16, 16, 16], 128, [16, 8, 4, 2]],
    'edgeinfoc': [64, 128],
    'block': [[512, [16]], [256, [16]], [256, [16]], [128, [16]]],
    'fuse': [[16, 16, 16, 16], True],
    'fuse_ratio': [[16, 1], [8, 1], [4, 1], [2, 1]],
    'merge1': [[128, 256, 128, 3, 1], [256, 512, 256, 3, 1], [512, 0, 512, 5, 2], [512, 0, 512, 5, 2], [512, 0, 512, 7, 3]],
    'merge2': [[128], [256, 512, 512, 512]],
}
# =>4 train.train(): run training with the Solver built in step 3.
train.train()
这里主要讲了如何设计loss
# Step 3 again: the Solver receives the training loader, no test loader (None) and the config.
train = Solver(train_loader, None, config)
class Solver(object):
    """Training driver: builds the network and, in train mode, opens the run log."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """
        Args:
            train_loader: DataLoader used for training.
            test_loader: evaluation loader; None when training.
            config: run configuration; ``mode`` and ``save_fold`` are read here.
            save_fold: stored as-is; the log location comes from ``config.save_fold``.
        """
        import os  # local import keeps this block self-contained

        self.train_loader = train_loader
        # Stored so these arguments are not silently discarded; __init__ itself does not use them.
        self.test_loader = test_loader
        self.save_fold = save_fold
        self.config = config
        # Per-channel mean [123.68, 116.779, 103.939] rescaled to [0, 1] and shaped
        # (3, 1, 1) so it broadcasts against channel-first images. These are the
        # commonly used ImageNet RGB means -- presumably inputs are RGB in [0, 1]; confirm.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        self.build_model()
        if config.mode == 'train':
            # open(..., 'w') fails when <save_fold>/logs does not exist yet, so create it first.
            os.makedirs("%s/logs" % config.save_fold, exist_ok=True)
            # Kept open for the solver's lifetime; nothing in this block closes it.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
class Solver(object):
    """Training driver: builds the network and, in train mode, opens the run log."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """
        Args:
            train_loader: DataLoader used for training.
            test_loader: evaluation loader; None when training.
            config: run configuration; ``mode`` and ``save_fold`` are read here.
            save_fold: stored as-is; the log location comes from ``config.save_fold``.
        """
        import os  # local import keeps this block self-contained

        self.train_loader = train_loader
        # Stored so these arguments are not silently discarded; __init__ itself does not use them.
        self.test_loader = test_loader
        self.save_fold = save_fold
        self.config = config
        # Per-channel mean [123.68, 116.779, 103.939] rescaled to [0, 1] and shaped
        # (3, 1, 1) so it broadcasts against channel-first images. These are the
        # commonly used ImageNet RGB means -- presumably inputs are RGB in [0, 1]; confirm.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        self.build_model()
        if config.mode == 'train':
            # open(..., 'w') fails when <save_fold>/logs does not exist yet, so create it first.
            os.makedirs("%s/logs" % config.save_fold, exist_ok=True)
            # Kept open for the solver's lifetime; nothing in this block closes it.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')