PyTorch踩坑记录之模型载入

def load_state_dict(self, state_dict):
    """Copy matching tensors from ``state_dict`` into ``self._channel_dict``.

    A hand-written, non-strict replacement for ``nn.Module.load_state_dict``:

    * Entries of ``state_dict`` whose name is not a key of
      ``self._channel_dict`` are skipped silently.
    * Entries of ``self._channel_dict`` that are absent from ``state_dict``
      are left unchanged.
    * For any name containing the substring ``'bn'``, the checkpoint tensor
      gets a leading dimension of size 1 (``unsqueeze(0)``) before copying.
      Presumably the model stores batch-norm tensors as ``(1, C)`` while the
      checkpoint stores them as ``(C,)`` -- TODO confirm against the model.

    Values are written in place with ``Tensor.copy_`` (which broadcasts the
    source to the destination's shape), so the destination tensors keep
    their identity.

    Args:
        state_dict: Mapping from parameter/buffer name to tensor, e.g. the
            object returned by ``torch.load``.

    Raises:
        RuntimeError: If copying an entry fails (typically a checkpoint shape
            that cannot be broadcast to the model's shape). The underlying
            error is chained as ``__cause__``.
    """
    import torch  # local import: no module-level imports are visible here

    # Copy under no_grad: an in-place copy into a leaf tensor that requires
    # grad (e.g. an nn.Parameter) is rejected by autograd otherwise, and the
    # except clause below would misreport that as a shape mismatch.
    with torch.no_grad():
        for name, param in state_dict.items():
            if name not in self._channel_dict:
                continue  # non-strict: ignore checkpoint-only entries

            # detach() is the supported replacement for the legacy `.data`.
            src = param.detach()
            if 'bn' in name:
                src = src.unsqueeze(0)

            target = self._channel_dict[name]
            try:
                target.copy_(src)
            except Exception as err:
                raise RuntimeError('While copying the parameter named {}, '
                                   'whose dimensions in the model are {} and '
                                   'whose dimensions in the checkpoint are {}.'
                                   .format(name, target.size(), src.size())) from err

Reference: https://blog.csdn.net/a137376864/article/details/78654618

你可能感兴趣的:(深度学习)