Common PyTorch parallelization bugs

state_dict = torch.load(opts.checkpoint)
try:
    trainer.net.load_state_dict(state_dict['net_param'])
except RuntimeError:
    # Key mismatch: the checkpoint keys carry a "module." prefix because
    # the model was saved while wrapped in DataParallel. Wrap the net the
    # same way so the keys line up, then load again.
    trainer.net = torch.nn.DataParallel(trainer.net)
    trainer.net.load_state_dict(state_dict['net_param'])

This handles loading a checkpoint that was saved from a model trained with DataParallel.
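An alternative to wrapping the model is to strip the "module." prefix from the checkpoint keys so the plain (unwrapped) model can load them directly. A minimal sketch, using a hypothetical small Linear model to stand in for trainer.net:

```python
import torch.nn as nn

# Hypothetical small network standing in for trainer.net.
net = nn.Linear(4, 2)

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# its keys carry a "module." prefix.
parallel_state = {"module." + k: v for k, v in net.state_dict().items()}

# Strip the prefix so the plain model can load the weights directly,
# without wrapping the model just to make the keys match.
clean_state = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in parallel_state.items()
}
net.load_state_dict(clean_state)
```

This avoids changing the type of trainer.net at load time, which can matter if later code does not expect a DataParallel wrapper.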

try:
    out = trainer.net.forward(inputs)
except AttributeError:
    # After DataParallel wrapping, the original network (and any custom
    # methods or attributes) lives at trainer.net.module.
    out = trainer.net.module.forward(inputs)

Similarly, the net must be accessed through .module to stay compatible with a model that was trained in parallel.
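To make the .module indirection concrete, here is a small self-contained sketch (the Net class and its name attribute are hypothetical): calling the wrapper still dispatches to forward, but custom attributes are only reachable through .module. DataParallel runs on CPU as well when no GPUs are visible, so this is runnable anywhere.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.name = "my_net"  # custom attribute, not part of nn.Module

    def forward(self, x):
        return self.fc(x)

net = nn.DataParallel(Net())
x = torch.randn(3, 4)

out = net(x)             # forward passes through the wrapper as usual
name = net.module.name   # custom attributes require .module
```

Accessing net.name directly would raise AttributeError, which is exactly the failure mode the try/except above works around.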
