ModuleNotFoundError: No module named ‘models‘ 的解决方法

参考博客:
ModuleNotFoundError: No module named ‘models‘解决torch.load问题【天坑】

保存与加载

使用 torch.save(model, “my_model.pth”) 命令可以保存整个模型。
这个保存/加载过程使用最直观的语法,涉及的代码最少。
以这种方式保存模型将使用Python的 pickle 模块保存整个model。
但是,在进行torch.load(“my_model.pth”)时,加载目录与保存目录要相同,这里的目录不是"my_model.pth",而是项目中定义模型所涉及的目录。举个例子:

load-test工程中有model_1文件夹,model_1文件夹中有yolo.py模块。
ModuleNotFoundError: No module named ‘models‘ 的解决方法_第1张图片
yolo.py:

import torch


class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.max_pooling1 = torch.nn.MaxPool2d(2, 1)

        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
        self.relu2 = torch.nn.ReLU()
        self.max_pooling2 = torch.nn.MaxPool2d(2, 1)

        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

在train.py中导入模型并保存整个模型:

from model_1.yolo import MyNet
import torch

net = MyNet()
torch.save(net, './weights/net.pt')  # 保存

将net.pt保存在weights文件夹中。
然后在train.py中加载模型:

import torch

model = torch.load('./weights/net.pt')

是成功的(也不需要 from model_1.yolo import MyNet )。

如果将model_1文件夹改为model_2文件夹,则出错。
ModuleNotFoundError: No module named ‘models‘ 的解决方法_第2张图片
出错的原因是导入模型时的目录与保存时的目录不一致。

解决方案

有时候需要将训练好的模型导入另一个工程中,但是该工程的文件夹与原文件夹相同且不方便更改,这时可以采取: 先加载模型,然后保存模型参数,再加载模型参数 的方法来正确加载模型。

先保持model_1文件夹不变。
train.py:

import torch

model = torch.load('./weights/net.pt')
torch.save(model.state_dict(), './weights/net_state_dict.pt')  # 保存

将model_1文件夹改为model_2文件夹。
train.py:

from model_2.yolo import MyNet
import torch

model = MyNet()
state_dict = torch.load('./weights/net_state_dict.pt')
model.load_state_dict(state_dict)

加载成功。

你可能感兴趣的:(重要,深度学习,pytorch,python)