通过修改输出层权重来复用预训练模型权重

加载预训练模型时,直接抛弃最后的输出层并非最佳方案:可以按新的类别数截取输出层的权重,从而保留其中已经训练好的那部分参数。下面以在 mmdetection 中使用 COCO 预训练模型为例。

import torch

def faster_rcnn(num_classes, src_path="e:/14_model.pth", dst_path=None):
    """Shrink a Faster R-CNN checkpoint's bbox-head output layers and save it.

    Every ``state_dict`` entry whose key contains ``'bbox_head.fc_cls'`` is
    truncated to its first ``num_classes`` rows, and every entry whose key
    contains ``'bbox_head.fc_reg'`` to its first ``num_classes * 4`` rows.
    All other entries of the checkpoint are written out unchanged.

    Args:
        num_classes: number of classification rows to keep.
            NOTE(review): mmdetection 1.x counts the background class in
            ``num_classes`` -- confirm against the target config.
        src_path: path of the source checkpoint; it must be a dict with a
            ``"state_dict"`` entry.
        dst_path: output path. Defaults to
            ``"../checkpoints/faster_rcnn_r50_fpn_1x_<num_classes>.pth"``.

    Returns:
        The path the new checkpoint was written to.
    """
    model_coco = torch.load(src_path)
    state_dict = model_coco["state_dict"]
    # Snapshot the items so reassigning values never interferes with iteration.
    for key, tensor in list(state_dict.items()):
        if 'bbox_head.fc_cls' in key:
            keep = num_classes
        elif 'bbox_head.fc_reg' in key:
            keep = num_classes * 4
        else:
            continue
        # A slice is a view onto the full tensor, and torch.save serializes
        # the whole storage behind a view; clone() so only the kept rows are
        # written to the new file.
        state_dict[key] = tensor[:keep].clone()
    if dst_path is None:
        dst_path = "../checkpoints/faster_rcnn_r50_fpn_1x_%d.pth" % num_classes
    torch.save(model_coco, dst_path)
    return dst_path

def cascade_rcnn(num_classes,
                 src_path="../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_20190125-aaa877cc.pth",
                 dst_path=None, num_stages=3):
    """Shrink each Cascade R-CNN stage's classification layer and save it.

    For every stage ``i`` in ``range(num_stages)``, the entries
    ``bbox_head.<i>.fc_cls.weight`` and ``bbox_head.<i>.fc_cls.bias`` are
    truncated to their first ``num_classes`` rows. A missing key raises
    ``KeyError``. The ``fc_reg`` layers are left untouched.
    NOTE(review): leaving ``fc_reg`` alone is only correct if the cascade
    heads use class-agnostic box regression -- confirm against the config.

    Args:
        num_classes: number of classification rows to keep per stage.
        src_path: path of the source checkpoint; it must be a dict with a
            ``"state_dict"`` entry.
        dst_path: output path. Defaults to
            ``"../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_<num_classes>.pth"``.
        num_stages: number of cascade stages (``bbox_head.0`` ... ``bbox_head.<n-1>``).

    Returns:
        The path the new checkpoint was written to.
    """
    model_coco = torch.load(src_path)
    state_dict = model_coco["state_dict"]
    for stage in range(num_stages):
        for param in ("weight", "bias"):
            key = "bbox_head.%d.fc_cls.%s" % (stage, param)
            # clone() so torch.save writes only the kept rows, not the whole
            # storage the slice is a view onto.
            state_dict[key] = state_dict[key][:num_classes].clone()
    if dst_path is None:
        dst_path = "../checkpoints/cascade_rcnn_dconv_c3-c5_r101_fpn_1x_%d.pth" % num_classes
    torch.save(model_coco, dst_path)
    return dst_path
    
if __name__ == "__main__":
    # Run the conversion only when executed as a script, never on import.
    # Cascade R-CNN example: cascade_rcnn(num_classes=11)
    faster_rcnn(4)

你可能感兴趣的:(python与人工睿智,机器学习入门与放弃,深度学习,计算机视觉)