多GPU训练保存的参数,单GPU的环境,KeyError: ‘base.conv1.weight‘

def load_param(self, model_path):
        param_dict = torch.load(model_path)
# =============================================================================
#         for i in param_dict:
#             if 'fc' in i:
#                 continue
#             self.state_dict()[i].copy_(param_dict[i])
# =============================================================================
     # 源代码是多GPU训练,单gpu时出问题       
        for i in param_dict:
            j = i.replace("base.","")
            if 'fc' in i:
                continue
            if j in self.state_dict().keys():
                self.state_dict()[j].copy_(param_dict[i])

注释内的在运行时报错,KeyError: ‘base.conv1.weight’,是由于多GPU训练保存的参数单GPU的环境无法直接用。改为注释下面的

你可能感兴趣的:(解决程序bug,#,Ubuntu,深度优先,算法,leetcode)