加载已有的pth模型后为什么会重新训练?

保存和加载模型需要注意:

pytorch保存模型后加载模型遇到的大坑_模型保存参数再加载测试,结果跟保存前差了很多,可能的原因是什么-CSDN博客

加载已有的pth模型后为什么会重新训练?

假如有两个文件

train.py

定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

#生成对象,开始训练
network = Network()

......

#保存参数
torch.save(network,'cnn.pth')

test.py

from train import Network



network=torch.load('./cnn.pth')

在test.py运行后会重新训练train.py的模型,为什么会这样?

from train import Network 看似只导入了网络类,其实会把整个文件都导入了进来,而train.py里面训练模型的代码是全局的变量或对象,导入进来会重新运行,所以这么解决

1. 把train.py里面训练模型的代码封装成函数或者是局部的,直接加

if __name__ == "__main__":
定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

if __name__ == "__main__":

    #生成对象,开始训练
    network = Network()

    ......

    #保存参数
    torch.save(network,'cnn.pth')

2.因为保存的模型文件.pth迁移到别的地方使用需要用到定义的网络模型类,在test.py中不导入模型类了,直接把定义的模型类复制过来。

定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

#加载模型使用
network=torch.load('./cnn.pth')

你可能感兴趣的:(pytorch,保存和加载模型)