PyTorch 判断两个同类模型是否所有参数完全相同

PyTorch 判断两个同类模型是否所有参数完全相同

起因是遇到了一个bug,发现训练中途保存的模型和最终epoch之后的模型是一样的…,经过检查发现是因为model.state_dict() 是浅拷贝,返回的参数仍然会随着网络的训练而变化, 正确做法应该使用 deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

其中用到了两种判断两个同类模型是否所有参数完全相同:

assert any([p1.data.ne(p2.data).sum() > 0 for p1, p2 in zip(net1.parameters(), net2.parameters())])
assert str(net1.state_dict()) != str(net2.state_dict())

你可能感兴趣的:(pytorch,深度学习,人工智能)