RuntimeError: Error(s) in loading state_dict for ResNet


Author :Horizon Max

编程技巧篇:各种操作小结

机器视觉篇:会变魔术 OpenCV

深度学习篇:简单入门 PyTorch

神经网络篇:经典网络模型

算法篇:再忙也别忘了 LeetCode


错误提示

RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var"

后面接着一大堆网络结构相关的东西


错误原因

load_state_dict 中 strict 参数默认是 True :

表示预训练模型的层和自己定义的网络结构层严格对应相等(如层名和维度)

所以当我们修改了网络结构后,如果strict为True的时候就会报错


    def load_state_dict(self, state_dict, strict=True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.
        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

解决方案

load_state_dict 中 strict 参数设置为 False

修改前:

model.load_state_dict(torch.load(model_path))

修改后:

model.load_state_dict(torch.load(model_path), False)

喜欢的 留个 关注 、 加 点赞 哦 ~



你可能感兴趣的:(又来一个,Bug,python,深度学习,人工智能)