siris Saliency Ranking Network Code Walkthrough (Training) — Inferring Attention Shift Ranks of Objects for Image Saliency

Notes Before Reading
Code that has already appeared earlier is abbreviated with ….
This post only covers the training code (the inference part is covered in a follow-up post).
The network architecture is not explained in depth; familiarity with the mrcnn architecture and with the paper is assumed.

Also: the inference part is now available, see: siris Saliency Ranking Network Code Walkthrough (Inference)

Table of Contents

  • Part 1: Training the mrcnn network
    • obj_sal_seg_branch/train.py
    • obj_sal_seg_branch.SOSNet
      • model.load_weights
      • model.train and the methods it calls
  • Part 2: Pre-processing and feature extraction
    • pre_process/pre_process_obj_feat_GT.py
    • pre_process.PreProcNet
      • model.detect
    • pre_process.Dataset
    • pre_process.DataGenerator
  • Part 3: Training the saliency ranking network
  • Closing thoughts

Part 1: Training the mrcnn network


obj_sal_seg_branch/train.py

According to README.md, the first script to run is obj_sal_seg_branch/train.py. Its contents:

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location

if __name__ == '__main__':
    command = "train"

    config = ObjSegMaskConfig()
    config.display()
    log_path = "logs/"

    model = SOSNet(mode="training", config=config, model_dir=log_path)

First, a SOSNet instance, model, is created.

if __name__ == '__main__':
    ...
    # Start from pre-trained weights
    # Load weights
    model_weights = "../weights/mask_rcnn_coco.h5"  # Make sure this is correct or change to location of weight path

    # Exclude layers - since we change the number of classes
    exclude_layers = ["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]
    print("Exclude Layers: ", exclude_layers)

    print("Loading weights ", model_weights)
    model.load_weights(model_weights, by_name=True, exclude=exclude_layers)

The layers whose pre-trained weights should be skipped are listed in exclude_layers; here the mrcnn output heads are discarded. (As later code makes clear, saliency ranking does not care about the category of a salient object, so there are only two classes: salient object or background, hence the comment "since we change the number of classes".)

The load_weights method of model is then called; it skips loading the weights of the excluded layers.

if __name__ == '__main__':
    ...
    if command == "train":
        print("Start Training...")

        # Train Dataset
        dataset_train = Obj_Sal_Seg_Dataset(DATASET_ROOT, "train")

        # Val Dataset
        dataset_val = Obj_Sal_Seg_Dataset(DATASET_ROOT, "val")

        # ********** Training  **********
        # Image Augmentation
        # Right/Left flip 50% of the time
        augmentation = imgaug.augmenters.Fliplr(0.5)

        # Training - Stage 1
        print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='heads',
                    augmentation=augmentation)

        # Training - Stage 2
        # Fine tune all layers
        print("Fine tune all layers")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE / 10,
                    epochs=200,
                    layers='all',
                    augmentation=augmentation)

command == "train" always holds, since it is assigned at the top of the script.
Then obj_sal_seg_branch/Obj_Sal_Seg_Dataset is used to build the training and validation datasets, augmentation is set up, and training begins.
Training has two stages: stage 1 trains the network heads, stage 2 fine-tunes all layers.

Next, let's look at the two important classes involved in obj_sal_seg_branch/train.py: obj_sal_seg_branch.SOSNet and obj_sal_seg_branch.Obj_Sal_Seg_Dataset.

obj_sal_seg_branch.SOSNet

First, the __init__ method:

class SOSNet():
    def __init__(self, mode, config, model_dir):
        self.mode = mode
        self.config = config
        self.model_dir = model_dir
        self.set_log_dir()

        self.keras_model = Model_Sal_Seg.build_saliency_seg_model(config, mode)

The first few lines are unremarkable; then self.keras_model is created, an object of type Model. Looking inside, in training mode this model's inputs and outputs are:

inputs = [input_image, input_image_meta,
                  input_rpn_match, input_rpn_bbox, input_gt_class_ids, input_gt_boxes, input_gt_masks]
if not config.USE_RPN_ROIS:
    inputs.append(input_rois)
    
outputs = [rpn_class_logits, rpn_class, rpn_bbox,
           feat_pyr_net_class_logits, feat_pyr_net_class, feat_pyr_net_bbox, obj_seg_masks,
           rpn_rois, output_rois,
           rpn_class_loss, rpn_bbox_loss,
           obj_sal_seg_class_loss, obj_sal_seg_bbox_loss,
           obj_sal_seg_mask_loss]

This is the familiar mrcnn setup: the intermediate results (including the various rois) and the losses are all emitted as outputs.

SOSNet is essentially a wrapper class around this model. Some of its methods exist to make it convenient to train a subset of the model's layers, including:

  • train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
    augmentation=None, custom_callbacks=None)
  • set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1)
  • compile(self, learning_rate, momentum)
  • load_weights(self, filepath, by_name=False, exclude=None)

Other methods make it convenient to use the trained model:

  • detect(self, images, verbose=0)
  • mold_inputs(self, images)
  • unmold_detections(self, detections, mrcnn_mask, original_image_shape, image_shape, window)
  • get_anchors(self, image_shape)

model.load_weights

Starting with the most important one: as noted above, obj_sal_seg_branch/train first calls model.load_weights.

The call site passes:

model.load_weights(model_weights, by_name=True, exclude=exclude_layers)

Now, the source of load_weights.
First the imports:

def load_weights(self, filepath, by_name=False, exclude=None):
    """Modified version of the corresponding Keras function with
    the addition of multi-GPU support and the ability to exclude
    some layers from loading.
    exclude: list of layer names to exclude
    """
    import h5py
    # Conditional import to support versions of Keras before 2.2
    # TODO: remove in about 6 months (end of 2018)
    try:
        from keras.engine import saving
    except ImportError:
        # Keras before 2.2 used the 'topology' namespace.
        from keras.engine import topology as saving
   
    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')

Next it checks whether any layers should be excluded. Since exclude is non-empty here, by_name is forced to True. (The caller already passes by_name=True anyway; this line just makes the code robust, and could be omitted.)

def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    if exclude:
        by_name = True

Then the weights file is opened via h5py. (If you are unfamiliar with the h5py file format and its operations, consult the h5py documentation.)

def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    f = h5py.File(filepath, mode='r')
    if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']

The next part supports multi-GPU training:

def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    # In multi-GPU training, we wrap the model. Get layers
    # of the inner model because they have the weights.
    keras_model = self.keras_model
    layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \
        else keras_model.layers

Then the excluded layers are filtered out. filter is a Python built-in: its first argument is a function, its second an iterable; filter feeds each element of the iterable to the function and keeps only the elements for which the function returns True.

def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    # Exclude some layers
    if exclude:
        layers = filter(lambda l: l.name not in exclude, layers)
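As a toy illustration of this filter call (the layer names here are made up, not from the real model):

```python
# Toy version of the exclusion step above; layer names are invented.
layers = ["conv1", "rpn_conv_shared", "mrcnn_class_logits", "mrcnn_bbox"]
exclude = ["mrcnn_class_logits", "mrcnn_bbox", "mrcnn_bbox_fc", "mrcnn_mask"]

# filter() is lazy in Python 3, so materialize it with list()
kept = list(filter(lambda l: l not in exclude, layers))
print(kept)  # ['conv1', 'rpn_conv_shared']
```

In the real code each element is a Keras layer object and the lambda tests l.name; plain strings stand in for layers here.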

Finally, the Keras API is called to load the weights, the file f is closed, and the log directory is updated:

def load_weights(self, filepath, by_name=False, exclude=None):
    ...
    if by_name:
        saving.load_weights_from_hdf5_group_by_name(f, layers)
    else:
        saving.load_weights_from_hdf5_group(f, layers)

    if hasattr(f, 'close'):
        f.close()

    # Update the log directory
    self.set_log_dir(filepath)

model.train and the methods it calls

After loading the weights, there are two rounds of training. Here are the arguments of the two calls:

# Training - Stage 1
print("Training network heads")
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE,
            epochs=40,
            layers='heads',
            augmentation=augmentation)

# Training - Stage 2
# Fine tune all layers
print("Fine tune all layers")
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE / 10,
            epochs=200,
            layers='all',
            augmentation=augmentation)

Now the source of this method:

It first defines regular expressions for groups of layers in a dict, then uses the layers argument as the key to look up the corresponding regex:

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
              augmentation=None, custom_callbacks=None):
    assert self.mode == "training", "Create model in training mode."

    # TODO: Update
    # Pre-defined layer regular expressions
    layer_regex = {
        # Only Heads
        "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # From a specific Res-Net stage and up
        "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # All layers
        "all": ".*",
    }
    if layers in layer_regex.keys():  # look up the regex by key
        layers = layer_regex[layers]
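To see how the looked-up regex selects layers later in set_trainable via re.fullmatch, here is a small sketch; the layer names are illustrative, not taken from the actual model:

```python
import re

# Same "heads" regex as in the dict above
heads_regex = r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)"

# Hypothetical layer names in the Mask R-CNN naming style
names = ["conv1", "res4a_branch2a", "fpn_c5p5", "rpn_class_raw", "mrcnn_mask"]

selected = [n for n in names if re.fullmatch(heads_regex, n)]
print(selected)  # ['fpn_c5p5', 'rpn_class_raw', 'mrcnn_mask']
```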

Then the training and validation datasets are passed to DataGenerator.data_generator to obtain train_generator and val_generator:

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
              augmentation=None, custom_callbacks=None):
    ...
    # Data generators
    train_generator = DataGenerator.data_generator(train_dataset, self.config, shuffle=True,
                                                   augmentation=augmentation,
                                                   batch_size=self.config.BATCH_SIZE)
    val_generator = DataGenerator.data_generator(val_dataset, self.config, shuffle=True,
                                                 batch_size=self.config.BATCH_SIZE)

    logs_path = self.log_dir + "/training.log"

Next the callbacks are configured. If you are not familiar with callbacks, see chapter 7 of Deep Learning with Python, the book by the author of Keras.

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
              augmentation=None, custom_callbacks=None):
    ...
    # Callbacks
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                    histogram_freq=0, write_graph=True, write_images=False),
        keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                        verbose=1, save_weights_only=True),
        keras.callbacks.CSVLogger(logs_path, separator=",", append=True),
    ]

    # Add custom callbacks to the list
    if custom_callbacks:
        callbacks += custom_callbacks

Now training starts in earnest: the target layers are set trainable, the model is compiled, and fit is called, the standard flow. Here set_trainable and compile are custom methods, while fit_generator is the Keras API.

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
              augmentation=None, custom_callbacks=None):
    ...
    # Train
    log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
    log("Checkpoint Path: {}".format(self.checkpoint_path))
    self.set_trainable(layers)
    self.compile(learning_rate, self.config.LEARNING_MOMENTUM)

    # Work-around for Windows: Keras fails on Windows when using
    # multiprocessing workers. See discussion here:
    # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
    if os.name is 'nt':
        workers = 0
    else:
        workers = multiprocessing.cpu_count()

    self.keras_model.fit_generator(
        train_generator,
        initial_epoch=self.epoch,
        epochs=epochs,
        steps_per_epoch=self.config.STEPS_PER_EPOCH,
        callbacks=callbacks,
        validation_data=val_generator,
        validation_steps=self.config.VALIDATION_STEPS,
        max_queue_size=100,
        workers=workers,
        use_multiprocessing=True,
    )
    self.epoch = max(self.epoch, epochs)
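The worker selection above can be written more idiomatically (note that `os.name is 'nt'` happens to work only because of CPython string interning; `==` is the correct comparison):

```python
import multiprocessing
import os

# 0 workers on Windows (Keras multiprocessing workers fail there),
# otherwise one worker per CPU core
workers = 0 if os.name == "nt" else multiprocessing.cpu_count()
print(workers >= 0)  # True
```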

Now the two custom methods:

def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
    """Sets model layers as trainable if their names match
    the given regular expression.
    """
    # Print message on the first call (but not on recursive calls)
    if verbose > 0 and keras_model is None:
        log("Selecting layers to train")

    keras_model = keras_model or self.keras_model

    # In multi-GPU training, we wrap the model. Get layers
    # of the inner model because they have the weights.
    layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model") \
        else keras_model.layers

    for layer in layers:
        # Is the layer a model?
        if layer.__class__.__name__ == 'Model':
            print("In model: ", layer.name)
            self.set_trainable(
                layer_regex, keras_model=layer, indent=indent + 4)
            continue

        if not layer.weights:
            continue
        # Is it trainable?
        trainable = bool(re.fullmatch(layer_regex, layer.name))
        # Update layer. If layer is a container, update inner layer.
        if layer.__class__.__name__ == 'TimeDistributed':
            layer.layer.trainable = trainable
        else:
            layer.trainable = trainable
        # Print trainable layer names
        if trainable and verbose > 0:
            log("{}{:20}   ({})".format(" " * indent, layer.name,
                                        layer.__class__.__name__))

def compile(self, learning_rate, momentum):
    """Gets the model ready for training. Adds losses, regularization, and
    metrics. Then calls the Keras compile() function.
    """

    # Compile SGD
    optimizer = keras.optimizers.SGD(
        lr=learning_rate, momentum=momentum,
        clipnorm=self.config.GRADIENT_CLIP_NORM)

    # Add Losses
    # First, clear previously set losses to avoid duplication
    self.keras_model._losses = []
    self.keras_model._per_input_losses = {}
    loss_names = [
        "rpn_class_loss", "rpn_bbox_loss",
        "obj_sal_seg_class_loss", "obj_sal_seg_bbox_loss", "obj_sal_seg_mask_loss"]
    for name in loss_names:
        layer = self.keras_model.get_layer(name)
        if layer.output in self.keras_model.losses:
            continue
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * self.config.LOSS_WEIGHTS.get(name, 1.))
        self.keras_model.add_loss(loss)

    # Add L2 Regularization
    # Skip gamma and beta weights of batch normalization layers.
    reg_losses = [
        keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
        for w in self.keras_model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    self.keras_model.add_loss(tf.add_n(reg_losses))

    # Compile
    self.keras_model.compile(
        optimizer=optimizer,
        loss=[None] * len(self.keras_model.outputs))

    # Add metrics for losses
    for name in loss_names:
        if name in self.keras_model.metrics_names:
            continue
        layer = self.keras_model.get_layer(name)
        self.keras_model.metrics_names.append(name)
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * self.config.LOSS_WEIGHTS.get(name, 1.))
        self.keras_model.metrics_tensors.append(loss)
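The loss assembly in compile boils down to: take the mean of each loss layer's output, scale it by its entry in LOSS_WEIGHTS (defaulting to 1.0), and sum everything. A numpy sketch of just that arithmetic, with made-up loss values and weights:

```python
import numpy as np

# Hypothetical per-layer loss outputs
loss_outputs = {
    "rpn_class_loss": np.array([0.30, 0.10]),        # mean 0.2
    "rpn_bbox_loss": np.array([0.50]),               # mean 0.5
    "obj_sal_seg_mask_loss": np.array([0.20, 0.40])  # mean 0.3
}
# Hypothetical LOSS_WEIGHTS; names not listed default to 1.0
loss_weights = {"rpn_bbox_loss": 2.0}

total = sum(out.mean() * loss_weights.get(name, 1.0)
            for name, out in loss_outputs.items())
print(round(total, 2))  # 0.2*1 + 0.5*2 + 0.3*1 = 1.5
```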

That completes the walkthrough of obj_sal_seg_branch/train.py, the first script that README.md asks you to run.

In summary, the task accomplished here is: strip the mrcnn_ heads from the original Mask R-CNN, first train selected layers on the new dataset, then fine-tune all layers.

Part 2: Pre-processing and feature extraction


pre_process/pre_process_obj_feat_GT.py

According to README.md, the second step is to run pre_process/pre_process_obj_feat_GT.py. Its source:

It starts with a series of path settings. This script must be run twice; on the second run, comment out data_split = "train" and use data_split = "val" instead.

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your location

if __name__ == '__main__':
    # add pre-trained weight path - backbone pre-trained on salient objects (binary, no rank)
    weight_path = ""

    # Run Script twice to generate pre-processed object features of GT objects for "train" and "val" data_splits
    data_split = "train"
    # data_split = "val"

    out_path = PRE_PROC_DATA_ROOT + "pre_process_feat/" + data_split + "/"

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    mode = "inference"
    config = RankModelConfig()
    log_path = "logs/"

Then Model_Obj_Feat.build_obj_feat_model(config) returns a Model instance:

if __name__ == '__main__':
    ...
    keras_model = Model_Obj_Feat.build_obj_feat_model(config)
    model_name = "Obj_Feat_Net"

A quick look at the structure of this Model instance (keras_model):

def build_obj_feat_model(config):
    # *********************** INPUTS ***********************
    input_image = Input(shape=(config.NET_IMAGE_SIZE, config.NET_IMAGE_SIZE, 3), name="input_image")
    input_image_meta = Input(shape=[config.IMAGE_META_SIZE], name="input_image_meta")

    input_obj_rois = Input(shape=(config.SAL_OBJ_NUM, 4), name="input_obj_rois")
    # Normalize coordinates
    obj_rois = Lambda(lambda x: fpn_model_utils.norm_boxes_graph(x, K.shape(input_image)[1:3]))(input_obj_rois)

    # *********************** BACKBONE FEATURES ***********************
    # Generate Backbone features
    # backbone_feat = [P2, P3, P4, P5]
    # rpn_features = [P2, P3, P4, P5, P6]
    # P2: (?, 256, 256, 256)
    # P3: (?, 128, 128, 256)
    # P4: (?, 64, 64, 256)
    # P5: (?, 32, 32, 256)
    backbone_feat = generate_backbone_features(input_image, config)

    P2, P3, P4, P5 = backbone_feat

    # *********************** SALIENT OBJECT MASK BRANCH ***********************
    # Produce Object Segment Masks
    obj_seg_masks = ObjectSegmentationMaskBranch.build_fpn_mask_graph(obj_rois, backbone_feat,
                                                                      input_image_meta,
                                                                      config.MASK_POOL_SIZE,
                                                                      config.NUM_CLASSES,
                                                                      train_bn=config.TRAIN_BN)

    # ROIAlign ed Object Features
    obj_features = pyr_roi_align_graph(obj_rois, backbone_feat, input_image_meta,
                                       config.POOL_SIZE,
                                       train_bn=config.TRAIN_BN,
                                       fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)

    # Model
    inputs = [input_image, input_image_meta, input_obj_rois]
    outputs = [obj_seg_masks, obj_features, P5]
    model = Model(inputs=inputs, outputs=outputs, name="obj_feat_model")

    return model

This first obtains the [P2, P3, P4, P5] feature maps from the backbone, then feeds those features:

  • into build_fpn_mask_graph to get the masks (stored in the obj_seg_masks variable);
  • into pyr_roi_align_graph to get the ROIAligned features (stored in the obj_features variable).

So the outputs are outputs = [obj_seg_masks, obj_features, P5]. Why output these three? Because the saliency ranking network later needs exactly these features.
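The Lambda applied to input_obj_rois above converts pixel-space boxes into normalized [0, 1] coordinates. Assuming norm_boxes_graph follows the matterport Mask R-CNN convention, the arithmetic looks like this in numpy:

```python
import numpy as np

def norm_boxes(boxes, shape):
    # Matterport-style normalization: pixel [y1, x1, y2, x2] -> [0, 1]
    h, w = shape
    scale = np.array([h - 1, w - 1, h - 1, w - 1], dtype=np.float32)
    shift = np.array([0.0, 0.0, 1.0, 1.0], dtype=np.float32)
    return (boxes.astype(np.float32) - shift) / scale

# A box covering the whole of a 1024x1024 image maps to [0, 0, 1, 1]
full_image = np.array([[0, 0, 1024, 1024]])
print(norm_boxes(full_image, (1024, 1024)).tolist())  # [[0.0, 0.0, 1.0, 1.0]]
```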

Back in pre_process/pre_process_obj_feat_GT.py, continuing from where we left off:

if __name__ == '__main__':
    ...
    model = PreProcNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

Right after one keras_model comes another model. Just as SOSNet in Part 1 wraps the model built by build_saliency_seg_model, the class behind this model, PreProcNet, wraps keras_model (the model built by build_obj_feat_model), and likewise bundles some methods that are handy at detect time.

Let's first finish the rest of pre_process/pre_process_obj_feat_GT.py, then come back to the methods of PreProcNet (even without reading them, the SOSNet example lets you guess there will be a weight-loading method and a detect method).

Continuing with pre_process/pre_process_obj_feat_GT.py:
the model's weight-loading method is called.

if __name__ == '__main__':
    ...
    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)

mode is assigned at the very top, so this branch is always taken.
pre_process.Dataset provides the data, which is then passed into pre_process.DataGenerator. Both classes are explained in detail below.

if __name__ == '__main__':
    ...
    if mode == "inference":
        # ********** Create Datasets
        # Train/Val Dataset
        dataset = Dataset(DATASET_ROOT, data_split)

        predictions = []

        num = len(dataset.img_ids)
        for i in range(num):

            image_id = dataset.img_ids[i]
            print(i + 1, " / ", num, " - ", image_id)

            input_data, gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = DataGenerator.load_inference_data_obj_feat_gt(dataset, image_id, config)

Then the detect method of model is called. (This is also why the mode is set to 'inference': this stage is pre-processing that produces features for the later saliency ranking network, and the features come from detect.)

if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            result = model.detect(input_data, verbose=1)

The other data obtained earlier from DataGenerator is stored into result as well:

if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            result["gt_ranks"] = gt_ranks
            result["sel_not_sal_obj_idx_list"] = sel_not_sal_obj_idx_list
            result["shuffled_indices"] = shuffled_indices
            result["chosen_obj_idx_order_list"] = chosen_obj_idx_order_list

Finally, result is written to a local file. pickle is a Python standard-library module that serializes arbitrary Python objects to a byte stream and restores them; in other words, pickle can persist and later recover Python objects.

if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
            ...
            o_p = out_path + image_id
            with open(o_p, "wb") as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
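A minimal round-trip in the same style (the keys mirror the result dict above, the values are dummies, and the path is a temp file):

```python
import os
import pickle
import tempfile

result = {"gt_ranks": [3, 1, 2], "P5": "dummy-feature"}

# Dump with the same call as the script, then load it back
path = os.path.join(tempfile.mkdtemp(), "result.pkl")
with open(path, "wb") as f:
    pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored == result)  # True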

That concludes pre_process_obj_feat_GT.py. To summarize: it builds a model object for feature extraction,

then obtains part of the data via Dataset and DataGenerator (gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list),

then calls detect to obtain the remaining data, i.e. the features (obj_masks, obj_feat, P5). Finally all of this data is pickled to disk.

Next, let's go back and look at the actual code of the detect method involved in this process, as well as Dataset and DataGenerator.

pre_process.PreProcNet

load_weights is essentially the same as before, so it is not repeated here.
Let's look at detect(self, input_data, verbose=0).

model.detect

The source is short, really short.
It returns a dict carrying the prediction results, namely:

  • obj_masks: [batch, roi_count, height, width, num_classes]
  • obj_feat: pooled features
  • P5: [batch, 32, 32, 256]

# Detection performed per single image
def detect(self, input_data, verbose=0):

    assert self.mode == "inference", "Create model in inference mode."

    if verbose:
        log("Processing image")
        log("image", input_data[0])

    detections = self.keras_model.predict(input_data, verbose=0)

    # Process detection
    obj_masks, obj_feat, P5 = detections
    result = {}
    result["obj_masks"] = obj_masks
    result["obj_feat"] = obj_feat
    result["P5"] = P5

    return result

pre_process.Dataset

First, __init__:

class Dataset(object):
    def __init__(self, dataset_root, data_split):

        self.dataset_root = dataset_root                    # Root folder of Dataset
        self.data_split = data_split

        self.load_dataset()

__init__ calls self.load_dataset(); here is its source:

This code is also short and tidy. It first loads the image ids into self.img_ids, then loads the ground-truth rank orders into self.gt_rank_orders.

Finally it loads the salient-object segmentation data, which is stored in a json file; from it we get obj_bboxes, obj_seg, sal_obj_idx_list and not_sal_obj_idx_list.

def load_dataset(self):
    print("\nLoading Dataset...")

    image_file = self.data_split + "_images.txt"

    # Get list of image ids
    image_path = os.path.join(self.dataset_root, image_file)
    with open(image_path, "r") as f:
        image_names = [line.strip() for line in f.readlines()]

    self.img_ids = image_names
    print(self.img_ids)

    # Load Rank Order
    rank_order_root = self.dataset_root + "rank_order/" + self.data_split + "/"
    self.gt_rank_orders = self.load_rank_order_data(rank_order_root)

    # Load Object Data
    obj_seg_data_path = self.dataset_root + "obj_seg_data_" + self.data_split + ".json"

    self.obj_bboxes, self.obj_seg, self.sal_obj_idx_list, self.not_sal_obj_idx_list = self.load_object_seg_data(
        obj_seg_data_path)
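The image-id loading is just a strip-per-line read of a text file. A self-contained sketch with a made-up file (the COCO-style file names are an assumption, consistent with the int(img_id[-12:]) parsing seen later):

```python
import os
import tempfile

# Fake a train_images.txt with two ids, then read it as load_dataset does
root = tempfile.mkdtemp()
path = os.path.join(root, "train_images.txt")
with open(path, "w") as f:
    f.write("COCO_train2014_000000000009\nCOCO_train2014_000000000025\n")

with open(path, "r") as f:
    img_ids = [line.strip() for line in f.readlines()]
print(img_ids[0])  # COCO_train2014_000000000009
```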

The details of the helper calls, e.g. how self.load_rank_order_data actually loads its data, are not discussed further here.

pre_process.DataGenerator

This module contains two methods:

  • load_inference_data_obj_feat
  • load_inference_data_obj_feat_gt

pre_process/pre_process_obj_feat_GT.py calls the one with the _gt suffix, so for now only load_inference_data_obj_feat_gt is discussed.
The source:

First a series of data is obtained from dataset, including: ① image, ② gt_ranks, ③ sel_not_sal_obj_idx_list, ④ shuffled_indices, ⑤ chosen_obj_idx_order_list, ⑥ object_roi_masks.

def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)

    gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list = dataset.load_gt_rank_order(image_id)

    object_roi_masks = dataset.load_object_roi_masks(image_id, sel_not_sal_obj_idx_list)

Then utility methods from fpn_network.utils perform a series of processing steps:

  • crop and resize the image
  • resize the masks accordingly
  • derive obj_bbox from obj_mask

def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    original_shape = image.shape
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    obj_mask = utils.resize_mask(object_roi_masks, scale, padding, crop)

    # bbox: [num_instances, (y1, x1, y2, x2)]
    obj_bbox = utils.extract_bboxes(obj_mask)
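utils.resize_image scales the image so the short side reaches IMAGE_MIN_DIM without letting the long side exceed IMAGE_MAX_DIM, then pads to a square. Assuming matterport's "square" resize mode, the scale computation alone looks like this (the dimensions are made up):

```python
# Hypothetical input size and the usual matterport config values
h, w = 600, 800
min_dim, max_dim = 800, 1024

scale = max(1, min_dim / min(h, w))   # 800 / 600 = 1.333...
if max(h, w) * scale > max_dim:       # 800 * 1.333... > 1024, so clamp
    scale = max_dim / max(h, w)       # 1024 / 800 = 1.28
print(scale)  # 1.28
```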

chosen_obj_idx_order_list is returned by dataset: a fixed number of objects (how many is set in config) chosen according to the ground-truth saliency order. The code below places these chosen salient objects into a batch:

def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    # *********************** FILL REST, SHUFFLE ORDER ***********************
    # order is in salient objects then non-salient objects
    batch_obj_roi = np.zeros(shape=(config.SAL_OBJ_NUM, 4), dtype=np.int32)
    for i in range(len(chosen_obj_idx_order_list)):
        _idx = chosen_obj_idx_order_list[i]
        batch_obj_roi[_idx] = obj_bbox[i]
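A toy numpy sketch of this fill step (SAL_OBJ_NUM and the boxes are made up):

```python
import numpy as np

SAL_OBJ_NUM = 5  # hypothetical config value
# Two ground-truth boxes [y1, x1, y2, x2] and their target slots
obj_bbox = np.array([[0, 0, 10, 10], [5, 5, 20, 20]], dtype=np.int32)
chosen_obj_idx_order_list = [3, 0]

# Scatter each box into the slot given by the chosen order list
batch_obj_roi = np.zeros((SAL_OBJ_NUM, 4), dtype=np.int32)
for i in range(len(chosen_obj_idx_order_list)):
    _idx = chosen_obj_idx_order_list[i]
    batch_obj_roi[_idx] = obj_bbox[i]

print(batch_obj_roi[3].tolist())  # [0, 0, 10, 10]
```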

Then the image is normalized and image_meta is generated, much as in mask r-cnn.

Finally the batch dimension is added and the data is returned.

def load_inference_data_obj_feat_gt(dataset, image_id, config):
    image = dataset.load_image(image_id)
    ...
    # Normalize image
    image = model_utils.mold_image(image.astype(np.float32), config)

    # Active classes
    active_class_ids = np.ones([config.NUM_CLASSES], dtype=np.int32)
    img_id = image_id
    img_id = int(img_id[-12:])
    # Image meta data
    image_meta = model_utils.compose_image_meta(img_id, original_shape, image.shape,
                                                window, scale, active_class_ids)

    # Expand input dimensions to consider batch
    image = np.expand_dims(image, axis=0)
    image_meta = np.expand_dims(image_meta, axis=0)
    batch_obj_roi = np.expand_dims(batch_obj_roi, axis=0)

    return [image, image_meta, batch_obj_roi], gt_ranks, sel_not_sal_obj_idx_list, shuffled_indices, chosen_obj_idx_order_list
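One detail worth noting above is img_id = int(img_id[-12:]): the last 12 characters of a COCO-style image name are the zero-padded numeric id. A toy check (the name is made up):

```python
# Hypothetical COCO-style image id; the numeric part is the last 12 chars
img_id = "COCO_train2014_000000000009"
print(int(img_id[-12:]))  # 9
```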

This concludes the code for Part 2. In summary, the task of Part 2 is to extract features and pre-process the images so that they can later be fed into the saliency ranking network.

Part 3: Training the saliency ranking network

The final step of the training process is to run the train.py file.

Let's look at this file's source.
It begins with some settings, then calls Model_SAM_SMM.build_saliency_rank_model to obtain keras_model (the structure of this model is shown below). The keras_model is then passed in to construct an ASRNet object. Clearly the same pattern as before: ASRNet is a wrapper class around the model built by Model_SAM_SMM.build_saliency_rank_model, again bundling methods for loading weights, training, detection and so on. Since it mirrors the earlier classes, it is not discussed in detail here.

# Path to dataset
DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location

# Path to pre-processed data - object features
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your location

if __name__ == '__main__':
    weight_path = ""    # add pre-trained weight path

    command = "train"
    config = RankModelConfig()
    log_path = "logs/"
    mode = "training"

    print("Loading Rank Model")
    keras_model = Model_SAM_SMM.build_saliency_rank_model(config, mode)
    model_name = "Rank_Model_SAM_SMM"
    model = ASRNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

Then the ASRNet methods are called: load the weights, build the dataset and data generator, and train. The logic is much the same as before.

if __name__ == '__main__':
    ...
    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)

    # Train/Evaluate Model
    if command == "train":
        print("Start Training...")

        # ********** Create Datasets
        # Train Dataset
        train_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "train")

        # Val Dataset
        val_dataset = Dataset(DATASET_ROOT, PRE_PROC_DATA_ROOT, "val")

        # ********** Parameters
        # Image Augmentation
        # Right/Left flip 50% of the time
        # augmentation = imgaug.augmenters.Fliplr(0.5)
        augmentation = None

        # ********** Create Data generators
        train_generator = DataGenerator.data_generator(train_dataset, config, shuffle=True,
                                                       augmentation=augmentation,
                                                       batch_size=config.BATCH_NUM)
        val_generator = DataGenerator.data_generator(val_dataset, config, shuffle=True,
                                                     batch_size=config.BATCH_NUM)

        # ********** Training  **********
        model.train(train_generator, val_generator,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='all')

Below is the structure of the model built by Model_SAM_SMM.build_saliency_rank_model (if it is hard to read, download the original image (extraction code: 1111)),
or generate the same diagram yourself with:

utils.plot_model(model, 'model.png', show_shapes=True)


The structure of the final classifier:
And that wraps up this part. The code of the saliency classifier itself is fairly simple and is not analyzed here. (It may be covered in a future update.)

Closing thoughts

The network uses a remarkable number of Dense layers for what is ultimately a 6-way classifier (5 saliency ranks plus 1 background), which feels unnecessary, though I have not run experiments to verify this.

Also, neural networks are arguably not well suited to ranking (a view found in Deep Learning with Python); here the ranking problem is recast as a classification task, which does achieve the goal.

That's all. Discussion is welcome in the comments (if anyone is reading...).

The inference part is now available, see: siris Saliency Ranking Network Code Walkthrough (Inference).
