Preliminary notes
Code that has already appeared earlier is replaced with "…".
This post only walks through the training code (the inference part is covered in a follow-up).
The network structure itself is not explained in depth; familiarity with the Mask R-CNN architecture and with the paper is assumed.
Update: the inference part is now available, see: SIRIS saliency ranking network code walkthrough (inference).
According to README.md, the first script to run is obj_sal_seg_branch/train.py. Let's look at its contents:
DATASET_ROOT = "D:/Desktop/ASSR/"  # Change to your location

if __name__ == '__main__':
    command = "train"
    config = ObjSegMaskConfig()
    config.display()
    log_path = "logs/"
    model = SOSNet(mode="training", config=config, model_dir=log_path)
First we obtain an instance of SOSNet, called model.
if __name__ == '__main__':
    ...
    # Start from pre-trained weights
    # Load weights
    model_weights = "../weights/mask_rcnn_coco.h5"  # Make sure this is correct or change to location of weight path
    # Exclude layers - since we change the number of classes
    exclude_layers = ["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]
    print("Exclude Layers: ", exclude_layers)
    print("Loading weights ", model_weights)
    model.load_weights(model_weights, by_name=True, exclude=exclude_layers)
The layers whose pre-trained weights should not be loaded are listed in exclude_layers; here the Mask R-CNN classification/box/mask heads are dropped. (As the later code suggests, saliency ranking does not care about the category of a salient object, so there are only two classes: salient object or background. Hence the comment "since we change the number of classes".)
Then model.load_weights is called; this method skips the weights of the excluded layers.
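To make the exclude semantics concrete, here is a minimal pure-Python sketch of by-name loading with excluded layers; the dicts and layer names below are hypothetical, and the real implementation lives in Keras's saving module:

```python
# Hypothetical saved weights, keyed by layer name (as in an HDF5 file).
saved_weights = {"conv1": [0.1, 0.2], "mrcnn_mask": [0.3], "rpn_model": [0.4]}
exclude_layers = ["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]

# Layers of the new model; "new_head" has no saved counterpart.
model_weights = {"conv1": None, "rpn_model": None, "new_head": None}

# By-name loading: copy a layer's weights only when the name matches
# a saved layer and the layer is not excluded.
for name in model_weights:
    if name in saved_weights and name not in exclude_layers:
        model_weights[name] = saved_weights[name]
```

After this loop, conv1 and rpn_model carry the pre-trained values, while new_head (a freshly added layer) and the excluded mrcnn_mask stay untouched, which is exactly what the training stages below rely on.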
if __name__ == '__main__':
    ...
    if command == "train":
        print("Start Training...")
        # Train Dataset
        dataset_train = Obj_Sal_Seg_Dataset(DATASET_ROOT, "train")
        # Val Dataset
        dataset_val = Obj_Sal_Seg_Dataset(DATASET_ROOT, "val")

        # ********** Training **********
        # Image Augmentation
        # Right/Left flip 50% of the time
        augmentation = imgaug.augmenters.Fliplr(0.5)

        # Training - Stage 1
        print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='heads',
                    augmentation=augmentation)

        # Training - Stage 2
        # Fine tune all layers
        print("Fine tune all layers")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE / 10,
                    epochs=200,
                    layers='all',
                    augmentation=augmentation)
command == "train" is always true, since it was assigned at the top.
obj_sal_seg_branch/Obj_Sal_Seg_Dataset then provides the training and validation data; after augmentation is set up, training starts.
Training has two stages: stage 1 trains the network heads, stage 2 fine-tunes all layers.
Next, let's look at the two important classes used by obj_sal_seg_branch/train.py: obj_sal_seg_branch.SOSNet and obj_sal_seg_branch/Obj_Sal_Seg_Dataset.
First, the __init__ method:
class SOSNet():
    def __init__(self, mode, config, model_dir):
        self.mode = mode
        self.config = config
        self.model_dir = model_dir
        self.set_log_dir()
        self.keras_model = Model_Sal_Seg.build_saliency_seg_model(config, mode)
The first few lines are unremarkable; then self.keras_model is created. It is a Keras Model object, and stepping into build_saliency_seg_model shows that in training mode its inputs and outputs are:
inputs = [input_image, input_image_meta,
          input_rpn_match, input_rpn_bbox, input_gt_class_ids, input_gt_boxes, input_gt_masks]
if not config.USE_RPN_ROIS:
    inputs.append(input_rois)

outputs = [rpn_class_logits, rpn_class, rpn_bbox,
           feat_pyr_net_class_logits, feat_pyr_net_class, feat_pyr_net_bbox, obj_seg_masks,
           rpn_rois, output_rois,
           rpn_class_loss, rpn_bbox_loss,
           obj_sal_seg_class_loss, obj_sal_seg_bbox_loss,
           obj_sal_seg_mask_loss]
This is exactly the Mask R-CNN recipe: the intermediate results (including the various ROIs) and the losses are all exposed as model outputs.
SOSNet is essentially a wrapper class around this model. Some of its methods exist to train selected layers of the model; others make it convenient to use the trained result.
Starting from the most important path: as described above, obj_sal_seg_branch/train first calls model.load_weights. The call site is:
model.load_weights(model_weights, by_name=True, exclude=exclude_layers)
Now let's go through the source of load_weights. It begins with imports:
def load_weights(self, filepath, by_name=False, exclude=None):
    """Modified version of the corresponding Keras function with
    the addition of multi-GPU support and the ability to exclude
    some layers from loading.
    exclude: list of layer names to exclude
    """
    import h5py
    # Conditional import to support versions of Keras before 2.2
    # TODO: remove in about 6 months (end of 2018)
    try:
        from keras.engine import saving
    except ImportError:
        # Keras before 2.2 used the 'topology' namespace.
        from keras.engine import topology as saving

    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')
Next it checks whether there are layers to exclude; since our call passed a non-empty exclude, by_name is forced to True. (The caller already passed by_name=True anyway, so this line is purely defensive.)
def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    if exclude:
        by_name = True
Then the weight file at filepath is opened with h5py. (If you are not familiar with HDF5 files and the operations on them, see an introduction to h5py.)
def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    f = h5py.File(filepath, mode='r')
    if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
The next lines handle multi-GPU training:
def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    # In multi-GPU training, we wrap the model. Get layers
    # of the inner model because they have the weights.
    keras_model = self.keras_model
    layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \
        else keras_model.layers
Then the excluded layers are filtered out. filter is a Python built-in: its first argument is a function, its second an iterable; it applies the function to each element and keeps only the elements for which it returns True.
def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    # Exclude some layers
    if exclude:
        layers = filter(lambda l: l.name not in exclude, layers)
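A standalone demonstration of this filter pattern (the layer names here are made up, and plain strings stand in for layer objects):

```python
# filter() keeps only the elements for which the lambda returns True.
exclude = ["mrcnn_class_logits", "mrcnn_mask"]
layer_names = ["conv1", "mrcnn_mask", "rpn_class_logits", "mrcnn_class_logits"]
kept = list(filter(lambda name: name not in exclude, layer_names))
print(kept)  # ['conv1', 'rpn_class_logits']
```

Note that in Python 3 filter returns a lazy iterator, which is why it is wrapped in list() here; the code above simply passes the iterator on to the loading function.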
Finally the Keras API is called to load the weights, the file f is closed, and the log directory is updated:
def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    if by_name:
        saving.load_weights_from_hdf5_group_by_name(f, layers)
    else:
        saving.load_weights_from_hdf5_group(f, layers)
    if hasattr(f, 'close'):
        f.close()

    # Update the log directory
    self.set_log_dir(filepath)
After loading the weights there are two training stages. Here are the two calls again:
# Training - Stage 1
print("Training network heads")
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE,
            epochs=40,
            layers='heads',
            augmentation=augmentation)

# Training - Stage 2
# Fine tune all layers
print("Fine tune all layers")
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE / 10,
            epochs=200,
            layers='all',
            augmentation=augmentation)
Now the source of this method. It first predefines regular expressions for layer groups in a dict, then uses the layers argument as a key to look up the corresponding pattern:
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None):
    assert self.mode == "training", "Create model in training mode."

    # TODO: Update
    # Pre-defined layer regular expressions
    layer_regex = {
        # Only Heads
        "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # From a specific Res-Net stage and up
        "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # All layers
        "all": ".*",
    }
    if layers in layer_regex.keys():  # look up the value by key
        layers = layer_regex[layers]
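The effect of the "heads" pattern can be checked with re.fullmatch, which is exactly what set_trainable later applies to each layer name (the layer names below are illustrative):

```python
import re

heads_regex = r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)"

# Head layers match...
assert re.fullmatch(heads_regex, "rpn_class_logits") is not None
assert re.fullmatch(heads_regex, "fpn_p2") is not None
# ...while backbone layers do not.
assert re.fullmatch(heads_regex, "res3a_branch2a") is None
assert re.fullmatch(heads_regex, "bn_conv1") is None
```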
Then the train and validation datasets are passed to DataGenerator.data_generator to obtain train_generator and val_generator:
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None):
    ...
    # Data generators
    train_generator = DataGenerator.data_generator(train_dataset, self.config, shuffle=True,
                                                   augmentation=augmentation,
                                                   batch_size=self.config.BATCH_SIZE)
    val_generator = DataGenerator.data_generator(val_dataset, self.config, shuffle=True,
                                                 batch_size=self.config.BATCH_SIZE)

    logs_path = self.log_dir + "/training.log"
Next the callbacks are set up. If callbacks are new to you, see chapter 7 of "Deep Learning with Python", the book by the Keras author.
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None):
    ...
    # Callbacks
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                    histogram_freq=0, write_graph=True, write_images=False),
        keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                        verbose=1, save_weights_only=True),
        keras.callbacks.CSVLogger(logs_path, separator=",", append=True),
    ]

    # Add custom callbacks to the list
    if custom_callbacks:
        callbacks += custom_callbacks
Then training proper begins: set the target layers' parameters trainable, compile the model, and call fit to train. This is all standard workflow; here set_trainable and compile are custom methods, while fit_generator is the Keras API.
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None):
    ...
    # Train
    log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
    log("Checkpoint Path: {}".format(self.checkpoint_path))
    self.set_trainable(layers)
    self.compile(learning_rate, self.config.LEARNING_MOMENTUM)

    # Work-around for Windows: Keras fails on Windows when using
    # multiprocessing workers. See discussion here:
    # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
    if os.name == 'nt':  # '==' rather than 'is': identity comparison with a string literal is fragile
        workers = 0
    else:
        workers = multiprocessing.cpu_count()

    self.keras_model.fit_generator(
        train_generator,
        initial_epoch=self.epoch,
        epochs=epochs,
        steps_per_epoch=self.config.STEPS_PER_EPOCH,
        callbacks=callbacks,
        validation_data=val_generator,
        validation_steps=self.config.VALIDATION_STEPS,
        max_queue_size=100,
        workers=workers,
        use_multiprocessing=True,
    )
    self.epoch = max(self.epoch, epochs)
Now let's look at the two custom methods:
def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
    """Sets model layers as trainable if their names match
    the given regular expression.
    """
    # Print message on the first call (but not on recursive calls)
    if verbose > 0 and keras_model is None:
        log("Selecting layers to train")

    keras_model = keras_model or self.keras_model

    # In multi-GPU training, we wrap the model. Get layers
    # of the inner model because they have the weights.
    layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \
        else keras_model.layers

    for layer in layers:
        # Is the layer a model?
        if layer.__class__.__name__ == 'Model':
            print("In model: ", layer.name)
            self.set_trainable(
                layer_regex, keras_model=layer, indent=indent + 4)
            continue

        if not layer.weights:
            continue
        # Is it trainable?
        trainable = bool(re.fullmatch(layer_regex, layer.name))
        # Update layer. If layer is a container, update inner layer.
        if layer.__class__.__name__ == 'TimeDistributed':
            layer.layer.trainable = trainable
        else:
            layer.trainable = trainable
        # Print trainable layer names
        if trainable and verbose > 0:
            log("{}{:20} ({})".format(" " * indent, layer.name,
                                      layer.__class__.__name__))
def compile(self, learning_rate, momentum):
    """Gets the model ready for training. Adds losses, regularization, and
    metrics. Then calls the Keras compile() function.
    """
    # Compile SGD
    optimizer = keras.optimizers.SGD(
        lr=learning_rate, momentum=momentum,
        clipnorm=self.config.GRADIENT_CLIP_NORM)

    # Add Losses
    # First, clear previously set losses to avoid duplication
    self.keras_model._losses = []
    self.keras_model._per_input_losses = {}
    loss_names = [
        "rpn_class_loss", "rpn_bbox_loss",
        "obj_sal_seg_class_loss", "obj_sal_seg_bbox_loss", "obj_sal_seg_mask_loss"]
    for name in loss_names:
        layer = self.keras_model.get_layer(name)
        if layer.output in self.keras_model.losses:
            continue
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * self.config.LOSS_WEIGHTS.get(name, 1.))
        self.keras_model.add_loss(loss)

    # Add L2 Regularization
    # Skip gamma and beta weights of batch normalization layers.
    reg_losses = [
        keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
        for w in self.keras_model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    self.keras_model.add_loss(tf.add_n(reg_losses))

    # Compile
    self.keras_model.compile(
        optimizer=optimizer,
        loss=[None] * len(self.keras_model.outputs))

    # Add metrics for losses
    for name in loss_names:
        if name in self.keras_model.metrics_names:
            continue
        layer = self.keras_model.get_layer(name)
        self.keras_model.metrics_names.append(name)
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * self.config.LOSS_WEIGHTS.get(name, 1.))
        self.keras_model.metrics_tensors.append(loss)
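The weighting scheme in compile() — the mean of each loss tensor times its entry in LOSS_WEIGHTS, defaulting to 1.0 — can be mimicked with plain numbers. The weights and loss values below are hypothetical:

```python
# Hypothetical per-loss weights; missing names default to 1.0 via .get(name, 1.0).
LOSS_WEIGHTS = {"rpn_class_loss": 1.0, "obj_sal_seg_mask_loss": 2.0}

# Hypothetical scalar loss values (in the real code these are tensor means).
raw_losses = {
    "rpn_class_loss": 0.5,
    "rpn_bbox_loss": 0.3,        # no entry in LOSS_WEIGHTS -> weight 1.0
    "obj_sal_seg_mask_loss": 0.2,
}

total_loss = sum(v * LOSS_WEIGHTS.get(n, 1.0) for n, v in raw_losses.items())
print(total_loss)  # 0.5*1.0 + 0.3*1.0 + 0.2*2.0 ≈ 1.2
```

This is why the output list of the model can carry loss tensors: add_loss attaches each weighted mean to the graph, and compile is then called with loss=None for every output.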
That completes the walkthrough of obj_sal_seg_branch/train.py, the first script README.md asks us to run. In summary: take the original Mask R-CNN, strip the mrcnn_ heads, train a few selected layers on the new dataset first, then fine-tune all layers.
According to README.md, the second step is running pre_process/pre_process_obj_feat_GT.py. Let's look at its source.
It starts with a batch of path settings. The script needs to be run twice; on the second run, comment out data_split = "train" and use data_split = "val" instead.
DATASET_ROOT = "D:/Desktop/ASSR/"  # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"  # Change to your location

if __name__ == '__main__':
    # add pre-trained weight path - backbone pre-trained on salient objects (binary, no rank)
    weight_path = ""

    # Run Script twice to generate pre-processed object features of GT objects for "train" and "val" data_splits
    data_split = "train"
    # data_split = "val"

    out_path = PRE_PROC_DATA_ROOT + "pre_process_feat/" + data_split + "/"
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    mode = "inference"
    config = RankModelConfig()
    log_path = "logs/"
Then Model_Obj_Feat.build_obj_feat_model(config) returns a Model instance:
if __name__ == '__main__':
    ...
    keras_model = Model_Obj_Feat.build_obj_feat_model(config)
    model_name = "Obj_Feat_Net"
A quick look at the structure of this Model instance (keras_model):
def build_obj_feat_model(config):
    # *********************** INPUTS ***********************
    input_image = Input(shape=(config.NET_IMAGE_SIZE, config.NET_IMAGE_SIZE, 3), name="input_image")
    input_image_meta = Input(shape=[config.IMAGE_META_SIZE], name="input_image_meta")
    input_obj_rois = Input(shape=(config.SAL_OBJ_NUM, 4), name="input_obj_rois")

    # Normalize coordinates
    obj_rois = Lambda(lambda x: fpn_model_utils.norm_boxes_graph(x, K.shape(input_image)[1:3]))(input_obj_rois)

    # *********************** BACKBONE FEATURES ***********************
    # Generate Backbone features
    # backbone_feat = [P2, P3, P4, P5]
    # rpn_features = [P2, P3, P4, P5, P6]
    # P2: (?, 256, 256, 256)
    # P3: (?, 128, 128, 256)
    # P4: (?, 64, 64, 256)
    # P5: (?, 32, 32, 256)
    backbone_feat = generate_backbone_features(input_image, config)
    P2, P3, P4, P5 = backbone_feat

    # *********************** SALIENT OBJECT MASK BRANCH ***********************
    # Produce Object Segment Masks
    obj_seg_masks = ObjectSegmentationMaskBranch.build_fpn_mask_graph(obj_rois, backbone_feat,
                                                                      input_image_meta,
                                                                      config.MASK_POOL_SIZE,
                                                                      config.NUM_CLASSES,
                                                                      train_bn=config.TRAIN_BN)

    # ROIAlign-ed Object Features
    obj_features = pyr_roi_align_graph(obj_rois, backbone_feat, input_image_meta,
                                       config.POOL_SIZE,
                                       train_bn=config.TRAIN_BN,
                                       fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)

    # Model
    inputs = [input_image, input_image_meta, input_obj_rois]
    outputs = [obj_seg_masks, obj_features, P5]
    model = Model(inputs=inputs, outputs=outputs, name="obj_feat_model")
    return model
This model first obtains the [P2, P3, P4, P5] feature set from the backbone, then feeds those features into two branches: build_fpn_mask_graph produces the masks (stored in obj_seg_masks), and pyr_roi_align_graph produces the ROIAlign-ed features (stored in obj_features). The outputs are therefore outputs = [obj_seg_masks, obj_features, P5]. Why these three? Because the saliency ranking network downstream needs exactly these features.
Back in pre_process/pre_process_obj_feat_GT.py, the code continues:
if __name__ == '__main__':
    ...
    model = PreProcNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)
Having just built one keras_model, we immediately wrap it in a model: just as SOSNet in part one wraps the model built by build_saliency_seg_model, PreProcNet wraps the keras_model built by build_obj_feat_model, and likewise bundles convenience methods for running detection.
Let's finish the rest of pre_process/pre_process_obj_feat_GT.py first and come back to the methods of PreProcNet afterwards (even without looking, the SOSNet example lets you guess there will be a weight-loading method and a detect method).
Continuing in pre_process/pre_process_obj_feat_GT.py, the model's weight-loading method is called:
if __name__ == '__main__':
    ...
    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)
mode was assigned at the top, so execution always enters this branch. pre_process.Dataset loads the data, and the dataset is then passed to pre_process.DataGenerator; both classes are detailed later.
if __name__ == '__main__':
    ...
    if mode == "inference":
        # ********** Create Datasets
        # Train/Val Dataset
        dataset = Dataset(DATASET_ROOT, data_split)

        predictions = []
        num = len(dataset.img_ids)
        for i in range(num):
            image_id = dataset.img_ids[i]
            print(i + 1, " / ", num, " - ", image_id)

            input_data, gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = DataGenerator.load_inference_data_obj_feat_gt(dataset, image_id, config)
Then model.detect is called. (This is also why the mode is set to 'inference': this stage is preprocessing, producing features for the downstream saliency ranking network, and the features come from detect.)
if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            result = model.detect(input_data, verbose=1)
The other data obtained earlier from DataGenerator is then attached to result:
if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            result["gt_ranks"] = gt_ranks
            result["sel_not_sal_obj_idx_list"] = sel_not_sal_obj_idx_list
            result["shuffled_indices"] = shuffled_indices
            result["chosen_obj_idx_order_list"] = chosen_obj_idx_order_list
Finally result is written to a local file. pickle is Python's standard serialization tool: it converts arbitrary Python objects to and from text or binary form, so objects can be stored and later restored.
if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            o_p = out_path + image_id
            with open(o_p, "wb") as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
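A minimal pickle round trip, using an in-memory buffer instead of the on-disk out_path files (the result dict below is a stand-in, with plain lists in place of arrays):

```python
import io
import pickle

# A stand-in for the real result dict.
result = {"obj_feat": [0.1, 0.2], "gt_ranks": [3, 1, 2]}

buf = io.BytesIO()
pickle.dump(result, buf, pickle.HIGHEST_PROTOCOL)

buf.seek(0)
restored = pickle.load(buf)
assert restored == result  # the object survives the round trip unchanged
```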
That wraps up pre_process_obj_feat_GT.py. In short: build a feature-extraction model; obtain part of the data (gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list) through Dataset and DataGenerator; obtain the rest, i.e. the features (obj_masks, obj_feat, P5), by calling detect; and pickle everything to disk.
Next, let's go back to the detect method involved in this process, along with the Dataset and DataGenerator classes.
load_weights is essentially the same as before, so we won't repeat it.
As for detect(self, input_data, verbose=0), the source is short, really short: it returns a dict carrying the prediction results:
# Detection performed per single image
def detect(self, input_data, verbose=0):
    assert self.mode == "inference", "Create model in inference mode."

    if verbose:
        log("Processing image")
        log("image", input_data[0])

    detections = self.keras_model.predict(input_data, verbose=0)

    # Process detection
    obj_masks, obj_feat, P5 = detections

    result = {}
    result["obj_masks"] = obj_masks
    result["obj_feat"] = obj_feat
    result["P5"] = P5
    return result
First, __init__:
class Dataset(object):
    def __init__(self, dataset_root, data_split):
        self.dataset_root = dataset_root  # Root folder of Dataset
        self.data_split = data_split
        self.load_dataset()
__init__ calls self.load_dataset(); its source is also short and tidy. It first loads the image ids into self.img_ids, then loads the rank-order ground truth into self.gt_rank_orders. Finally it loads the salient-object segmentation data, stored in a JSON file, from which it obtains obj_bbox, obj_seg, sal_obj_idx_list, and not_sal_obj_idx_list.
def load_dataset(self):
    print("\nLoading Dataset...")
    image_file = self.data_split + "_images.txt"
    # Get list of image ids
    image_path = os.path.join(self.dataset_root, image_file)
    with open(image_path, "r") as f:
        image_names = [line.strip() for line in f.readlines()]
    self.img_ids = image_names
    print(self.img_ids)

    # Load Rank Order
    rank_order_root = self.dataset_root + "rank_order/" + self.data_split + "/"
    self.gt_rank_orders = self.load_rank_order_data(rank_order_root)

    # Load Object Data
    obj_seg_data_path = self.dataset_root + "obj_seg_data_" + self.data_split + ".json"
    self.obj_bboxes, self.obj_seg, self.sal_obj_idx_list, self.not_sal_obj_idx_list = self.load_object_seg_data(
        obj_seg_data_path)
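The id-list parsing at the top of load_dataset can be exercised without the dataset files by substituting an in-memory file (the ids below are made up):

```python
import io

# Simulated contents of e.g. train_images.txt: one image id per line.
fake_file = io.StringIO("COCO_val2014_000000000042\nCOCO_val2014_000000000073\n")

# Same pattern as load_dataset: strip the trailing newline from each line.
img_ids = [line.strip() for line in fake_file.readlines()]
print(img_ids)  # ['COCO_val2014_000000000042', 'COCO_val2014_000000000073']
```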
We won't dig further into the individual helpers such as self.load_rank_order_data here.
DataGenerator has two relevant methods; pre_process/pre_process_obj_feat_GT.py calls the one with the _gt suffix, so for now we only cover load_inference_data_obj_feat_gt. Looking at the source: it first obtains a series of data from the dataset, namely ① the image, ② gt_ranks, ③ sel_not_sal_obj_idx_list, ④ shuffled_indices, ⑤ chosen_obj_idx_order_list, and ⑥ object_roi_masks.
def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = dataset.load_gt_rank_order(image_id)
    object_roi_masks = dataset.load_object_roi_masks(image_id, sel_not_sal_obj_idx_list)
It then applies a series of utility functions from fpn_network.utils:
def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    original_shape = image.shape
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    obj_mask = utils.resize_mask(object_roi_masks, scale, padding, crop)

    # bbox: [num_instances, (y1, x1, y2, x2)]
    obj_bbox = utils.extract_bboxes(obj_mask)
chosen_obj_idx_order_list comes from the dataset: based on the ground-truth saliency ranking, a fixed number of objects (set by the config) is selected. The following code scatters the chosen salient objects into the batch array:
def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    # *********************** FILL REST, SHUFFLE ORDER ***********************
    # order is in salient objects then non-salient objects
    batch_obj_roi = np.zeros(shape=(config.SAL_OBJ_NUM, 4), dtype=np.int32)
    for i in range(len(chosen_obj_idx_order_list)):
        _idx = chosen_obj_idx_order_list[i]
        batch_obj_roi[_idx] = obj_bbox[i]
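The scatter into batch_obj_roi can be illustrated with plain lists in place of the NumPy array; SAL_OBJ_NUM and the boxes below are hypothetical values:

```python
SAL_OBJ_NUM = 5  # stand-in for config.SAL_OBJ_NUM
obj_bbox = [[0, 0, 10, 10], [5, 5, 20, 20]]   # boxes in GT rank order
chosen_obj_idx_order_list = [3, 1]            # shuffled target slot for each box

# Unused slots stay all-zero, like np.zeros in the real code.
batch_obj_roi = [[0, 0, 0, 0] for _ in range(SAL_OBJ_NUM)]
for i, _idx in enumerate(chosen_obj_idx_order_list):
    batch_obj_roi[_idx] = obj_bbox[i]

print(batch_obj_roi[3])  # [0, 0, 10, 10] -> the first box landed in slot 3
```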
Then the image is normalized and image_meta is generated, much like in Mask R-CNN. Finally the batch dimension is added and everything is returned:
def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    # Normalize image
    image = model_utils.mold_image(image.astype(np.float32), config)

    # Active classes
    active_class_ids = np.ones([config.NUM_CLASSES], dtype=np.int32)

    img_id = image_id
    img_id = int(img_id[-12:])

    # Image meta data
    image_meta = model_utils.compose_image_meta(img_id, original_shape, image.shape,
                                                window, scale, active_class_ids)

    # Expand input dimensions to consider batch
    image = np.expand_dims(image, axis=0)
    image_meta = np.expand_dims(image_meta, axis=0)
    batch_obj_roi = np.expand_dims(batch_obj_roi, axis=0)

    return [image, image_meta, batch_obj_roi], gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list
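One small detail worth noting is img_id = int(img_id[-12:]): COCO-style image ids end in a 12-digit zero-padded number, so slicing the last 12 characters recovers the numeric id (the id string below is an example value):

```python
image_id = "COCO_val2014_000000000139"  # a COCO-style id (example value)
img_id = int(image_id[-12:])  # last 12 chars: "000000000139"
print(img_id)  # 139
```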
That completes the second part. In summary, its task is to extract features and preprocess the images, so they can later be fed to the saliency ranking network.
The final step of the training process is running train.py. Looking at its source: after some setup, it calls Model_SAM_SMM.build_saliency_rank_model to obtain a keras_model (this model's structure is shown further below), which is then passed in to construct an ASRNet object. Clearly the same pattern as before: ASRNet is a wrapper around the model built by Model_SAM_SMM.build_saliency_rank_model, bundling methods for loading weights, training, detection, and so on. Since it is so similar, we won't go through it in detail.
# Path to dataset
DATASET_ROOT = "D:/Desktop/ASSR/"  # Change to your location
# Path to pre-processed data - object features
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"  # Change to your location

if __name__ == '__main__':
    weight_path = ""  # add pre-trained weight path

    command = "train"
    config = RankModelConfig()
    log_path = "logs/"
    mode = "training"

    print("Loading Rank Model")
    keras_model = Model_SAM_SMM.build_saliency_rank_model(config, mode)
    model_name = "Rank_Model_SAM_SMM"
    model = ASRNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)
Then the ASRNet methods are used: load the weights, build the dataset and data generators, and train. The logic mirrors what we have already seen:
if __name__ == '__main__':
    ...
    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)

    # Train/Evaluate Model
    if command == "train":
        print("Start Training...")
        # ********** Create Datasets
        # Train Dataset
        train_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "train")
        # Val Dataset
        val_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "val")

        # ********** Parameters
        # Image Augmentation
        # Right/Left flip 50% of the time
        # augmentation = imgaug.augmenters.Fliplr(0.5)
        augmentation = None

        # ********** Create Data generators
        train_generator = DataGenerator.data_generator(train_dataset, config, shuffle=True,
                                                       augmentation=augmentation,
                                                       batch_size=config.BATCH_NUM)
        val_generator = DataGenerator.data_generator(val_dataset, config, shuffle=True,
                                                     batch_size=config.BATCH_NUM)

        # ********** Training **********
        model.train(train_generator, val_generator,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='all')
Below is the structure of the model produced by Model_SAM_SMM.build_saliency_rank_model (if it is hard to read, you can download the original image; access code 1111), or generate the same figure yourself with:
utils.plot_model(model, 'model.png', show_shapes=True)
And the structure of the final classifier:
And that concludes this part. The saliency classification network's own code is fairly simple, so it is not analyzed here (perhaps in a future update).
The network uses a great many Dense layers, which feels like overkill for a classifier with just 6 classes (5 saliency ranks plus background); but I haven't run experiments, so I can't say for sure.
Also, neural networks are arguably not a natural fit for ranking (a point made in "Deep Learning with Python"); here the ranking problem is converted into a classification task, which does achieve the goal.
That's all. Discussion welcome in the comments (if anyone is reading…).
The inference part has been updated, see: SIRIS saliency ranking network code walkthrough (inference).