# mmdetection 修改预训练模型权重类别数 (mmdetection: change the number of classes in pretrained model weights)

# 修改预训练权重类别数 (modify the class count of pretrained weights)
import os
import torch
import argparse
def init_args(argv=None):
    """Parse the command-line options for the checkpoint conversion.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse parse ``sys.argv[1:]``, so calling ``init_args()`` with
            no arguments behaves like a normal CLI entry point.

    Returns:
        argparse.Namespace with ``org_path`` (str) and ``num_classes``
        (int, default 26).
    """
    parser = argparse.ArgumentParser()
    # required=True: without it a missing path only surfaces later as an
    # obscure failure inside torch.load(None).
    parser.add_argument("--org_path", type=str, required=True,
                        help="the path of pretrained model")
    parser.add_argument("--num_classes", type=int, default=26,
                        help="number of classes")
    return parser.parse_args(argv)
def _resize_rows(tensor, num_rows):
    """Return a copy of *tensor* whose first dimension has size *num_rows*."""
    current = tensor.shape[0]
    if num_rows <= current:
        # Keeps the leading rows, matching what `resize_` did when shrinking
        # a contiguous tensor, but without mutating the original storage.
        return tensor[:num_rows].clone()
    # Growing: `resize_` would leave the extra rows uninitialized (arbitrary
    # memory). Pad with zeros instead so the result is deterministic.
    pad = tensor.new_zeros((num_rows - current,) + tuple(tensor.shape[1:]))
    return torch.cat([tensor, pad], dim=0)


def modify_cascade_rcnn(model_coco, num_classes, num_stages=3):
    """Resize the classification layers of a Cascade R-CNN checkpoint in place.

    For every stage ``i`` in ``range(num_stages)``, the entries
    ``bbox_head.{i}.fc_cls.weight`` and ``bbox_head.{i}.fc_cls.bias`` of
    ``model_coco["state_dict"]`` are replaced by tensors whose first
    dimension is ``num_classes``. When shrinking, the first ``num_classes``
    rows are kept. When growing, the existing rows are kept and the added
    rows are zero-filled. The weight's second dimension is taken from the
    checkpoint itself rather than assumed.

    Args:
        model_coco: checkpoint dict (as loaded by ``torch.load``) containing
            a ``"state_dict"`` mapping with the keys above.
        num_classes: new output size of the classifier.
            NOTE(review): in mmdetection v1.x this count presumably includes
            the background class. Confirm against the target config.
        num_stages: number of cascade stages (3 for the standard model).

    Raises:
        ValueError: if ``num_classes`` is smaller than 1.
        KeyError: if an expected ``fc_cls`` key is missing.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1, got %r" % (num_classes,))
    state_dict = model_coco["state_dict"]
    for stage in range(num_stages):
        for param in ("weight", "bias"):
            key = "bbox_head.%d.fc_cls.%s" % (stage, param)
            state_dict[key] = _resize_rows(state_dict[key], num_classes)
def main(args):
    """Convert a COCO-pretrained mmdetection checkpoint to a new class count.

    The model family is taken from the first two ``_``-separated tokens of
    the checkpoint's file name (``faster_rcnn``, ``mask_rcnn`` or
    ``cascade_rcnn``). The matching ``modify_*`` function rewrites the
    classification head in place. The result is saved in the output
    directory under the original file name with its last ``_``-separated
    token replaced by ``coco_pretrained_weights_classes_<num_classes>.pth``.

    Args:
        args: namespace with ``org_path`` (input checkpoint path) and
            ``num_classes`` (int). An optional ``save_dir`` attribute
            overrides the default output directory.

    Raises:
        ValueError: if the file name does not identify a supported model.
    """
    pth_dir = args.org_path
    base_name = os.path.basename(pth_dir)
    family = tuple(base_name.split('_')[:2])

    # Choose the converter before loading, so an unsupported checkpoint fails
    # fast instead of being saved unmodified under a misleading name.
    # NOTE(review): modify_faster_rcnn / modify_mask_rcnn are not defined in
    # the visible source. Confirm they exist before converting those models.
    # Names are resolved only inside the matching branch, as before.
    if family == ('faster', 'rcnn'):
        label, convert = 'faster rcnn', modify_faster_rcnn
    elif family == ('mask', 'rcnn'):
        label, convert = 'mask rcnn', modify_mask_rcnn
    elif family == ('cascade', 'rcnn'):
        label, convert = 'cascade rcnn', modify_cascade_rcnn
    else:
        raise ValueError(
            "Unsupported model %r: expected a file name starting with "
            "'faster_rcnn', 'mask_rcnn' or 'cascade_rcnn'." % base_name)

    # map_location='cpu' lets checkpoints saved from GPU load on CPU-only hosts.
    model_coco = torch.load(pth_dir, map_location='cpu')
    print('Current model is %s.' % label)
    print('Converting ...')
    convert(model_coco, args.num_classes)

    # save new model
    save_dir = getattr(args, 'save_dir', None) or (
        "/mnt/hdd/tangwei/autodrive/data/checkpoints/num%s/" % args.num_classes)
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create dirs
    # Replace only the final '_'-separated token; str.replace would also
    # rewrite any earlier occurrence of the same text in the name.
    head, sep, _ = base_name.rpartition('_')
    model_name = os.path.join(
        save_dir,
        head + sep + 'coco_pretrained_weights_classes_'
        + str(args.num_classes) + ".pth")
    torch.save(model_coco, model_name)
    print("Convert successful.")
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the conversion.
    main(init_args())

# 你可能感兴趣的 (related topics): pytorch, mmdetection