Bug解决:RuntimeError: Error(s) in loading state_dict for LadderNetv6: Missing key(s) in state_dict:

运行网络的时候出现了这个错误,虽然之前已经解决过了,还是出现的问题
在这里插入图片描述
下面是原来的代码:
别人说在load这个checkpoint之前要加上
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
cudnn.benchmark = True

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/' + check_path, map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']

于是我就试了一下,发现这个位置非常重要:
修改后的代码:
这两句不加在这个位置,怎么样都会出错

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/' + check_path, map_location='cpu')
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']

关于torch.nn.DataParallel()的使用,可以参考下文:
基于PyTorch的深度学习入门教程(六)——数据并行化
pytorch 多GPU训练总结(DataParallel的使用)

你可能感兴趣的:(人工智能,学术)