pytorch model.load_state_dict报错

pytorch加载模型的时候如果模型里边使用了一些判断,判断作为选择执行条件,但是也保存到模型里面了,但是调用的时候不选择判断条件里边的网络并且使用load_state_dict,会报错,有些算子找不到名称。如:

if backbone == "mobilenet":
    self.backbone = mobilenet()
    flat_shape = 1024
    elif backbone == "inception_resnetv1":
    self.backbone = inception_resnet()
else:
    raise ValueError('Unsupported backbone - `{}`, Use mobilenet, inception_resnetv1.'.format(backbone))
    self.avg = nn.AdaptiveAvgPool2d((1,1))
    self.Bottleneck = nn.Linear(flat_shape, embedding_size,bias=False)
    self.last_bn = nn.BatchNorm1d(embedding_size, eps=0.001, momentum=0.1, affine=True)
    if mode == "train": # 判断条件,测试时,不加载全连接
        self.classifier = nn.Linear(embedding_size, num_classes)

可以加入strict=False选项,规避网络中没有调用的算子:

model2.load_state_dict(state_dict2, strict=False)

你可能感兴趣的:(编程语言,深度学习,pytorch,python,人工智能)