torch.save() or model.load_state_dict()有问题?都不是,是我数据集写错了

torch.save() or model.load_state_dict()有问题?都不是,是我数据集写错了

今天看pytorch 文档看到了参数或模型的保存,就想动手实践一下。然后神奇的问题出现了。
如果我训练完后立刻保存参数到磁盘,然后马上实例化一个新model并载入刚刚保存的参数,再用这个新的model作预测,预测正确率与训练时一样。但是如果我注释掉训练部分。直接实例化新model并载入参数。预测正确率不足1%。于是我就觉得是不是torch.save或者model.load_state_dict用法有什么要注意的地方。在百度上找了博客,各种有用的无用的修改都试过。发现问题依然存在。打印下参数。发现参数确实已经一摸一样地载入了。最后打印下预测结果,发现输出的分类引索是一模一样的。(我做的是分类问题)这时候我恍然大悟,一定是数据集载入有问题。说明一下错误:
模型任务是:给一个名字,判断是哪个国家的名字。输出的是这个国家在国家字典里的引索。
这个国家字典是这样创建的:先载入所有数据,取出国家列,然后去重。就得到了字典了。
但是,好像每次运行程序,去重后的国家顺序是不一样的(虽然不知道为什么,明明读入的顺序完全一样)。所以每次运行程序,一个国家对应的引索是不同的。
这样造成的问题是:如果训练和测试是在同一次程序运行,那么训练和测试用的国家字典就是一样的,否则,字典就不一样。无论在哪次运行中,对于相同的输入,模型的输出都是一样的。模型输出的引索,在字典a中对于国家1。但在字典b中就不一定了。为了保证字典一致。需要在做出字典后,对字典进行一次按字典序排序。由于字典序是唯一的,所以无论在哪次运行,都能保证所使用的字典是同一个字典。
原本只是测试一下参数保存,却测出了原来代码的隐藏大坑。

你可能感兴趣的:(个人错误记录)