加载模型参数到使用 DataParallel(model) 的模型的两种方法的代码

load model weights into DataParallel(model) 的两种方法的代码

方法一
G = Generator().to(device)
G.load_state_dict(torch.load(args.model_path))
G = nn.DataParallel(G) # 这样在传递时 G 的参数会被重置吗?经测试,没有明显差别
方法二
G = Generator().to(device)
G = nn.DataParallel(G)
G.module.load_state_dict(torch.load(args.model_path))

你可能感兴趣的:(深度学习,pytorch)