PyTorch使用ResNet18提取图像特征并进行相似度计算

模型部分我参考的是https://blog.csdn.net/sunqiande88/article/details/80100891这篇文章,同样是在Cifar-10上训练。

一、不使用PyTorch中的预训练模型

将训练的模型保存下来接后面使用,保存方式:

torch.save(net.state_dict(), 'path')

加载方式

model = ResNet18()
model.load_state_dict(torch.load('path'))
model.eval() # 测试时候要加上

由于不是使用预训练模型所以输出特征层还是很好输出的,只需要将输出从fc层改为前一层即可:

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        features = out.detach()
        out = self.fc(out)
        return features # [1,512]

保存ONNX模型:

	model = ResNet18()
    model.load_state_dict(torch.load(from_path))
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
    torch.onnx.export(model, dummy_input, to_path, export_params=True, opset_version=10, do_constant_folding=True
                      , input_names=["input"], output_names=['output'])

二、使用PyTorch中的预训练模型

模型的保存与加载与上面相同,但是由于使用预训练模型无法修改输出,所以需要使用其他方式修改模型输出。

将训练的模型保存下来接后面使用,保存方式:

torch.save(net.state_dict(), 'path')

加载方式:

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, 10)
model.load_state_dict(torch.load(from_path))
model.eval()
# 获取全连接层的输入
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

model = models.resnet18(pretrained=False)
model.fc = Identity()
x = torch.randn(1, 3, 32, 32)
output = model(x)

保存ONNX模型:

# 初始化并加载模型
    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, 10)
    model.load_state_dict(torch.load('path'))
    # 修改fc层输出
    model.fc = Identity()  # 修改FC层输出
    # 调试模式
    model.eval()
    # 输出到onnx
    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
    torch.onnx.export(model, dummy_input, to_path, export_params=True, opset_version=10, do_constant_folding=True
                      , input_names=['input'], output_names=['output'])

三、测试

测试.pth模型部分代码

# 加载图像
def load_image(img_path, transform=None):
    imgs = []
    for name in sorted(os.listdir(img_path)):
        img = Image.open(img_path + name).convert('RGB')
        if transform is not None:
            img = transform(img)
        else:
            img = transforms.ToTensor()(img)
        imgs.append(img)
    return imgs

def predict(imgs):
    model = ResNet18()
    model.load_state_dict(torch.load('path'))
    model.to(device)
    model.eval()
    imgs = torch.stack(imgs, 0).to(device)
    with torch.no_grad():
        predicts = model(imgs)
        print(predicted)
    return predicted

测试ONNX部分分为以下几种方式,
1、使用Python中的onnxruntime库

session = onnxruntime.InferenceSession('ONNX Path')
    inputs = {session.get_inputs()[0].name: img.numpy()}
    outs = session.run(None, inputs)
    return outs

2、使用Python中OpenCV(Version:4.2.0)中的接口

# 加载图像
def load_image_cv(img_path, mean=None, std=None):
    img = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR)
    if img.shape[2] > 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_ = (img / 255. - mean) / std # 这里根据训练时的格式进行处理
    return img_.astype(np.float32)

# 使用opencv运行onnx
def run_onnx_cv(from_path, img):
    net = cv2.dnn.readNetFromONNX(from_path)
    input = cv2.dnn.blobFromImage(img)
    net.setInput(input)
    outs = net.forward()
    return outs

3、使用C++版本的OpenCV(Version:4.2.0)

//加载图像
cv::Mat load_image_cv(const std::string& fileName, cv::Scalar mean, cv::Scalar std)
{
	cv::Mat img = cv::imread(fileName, cv::IMREAD_ANYCOLOR);
	if (img.empty()) {
		return cv::Mat();
	}
	img.convertTo(img, CV_32F, 1 / 255.);
	cv::subtract(img, mean, img);
	cv::divide(img, std, img);
	return img;
}

cv::Mat img = load_image_cv("path", cv::Scalar(0.4914, 0.4822, 0.4465), cv::Scalar(0.2023, 0.1994, 0.2010));
cv::dnn::Net net = cv::dnn::readNetFromONNX("ONNX Path");
net.setPreferableBackend(cv::dnn::Backend::DNN_BACKEND_OPENCV);
net.setPreferableTarget(cv::dnn::Target::DNN_TARGET_CPU);
cv::Mat input = cv::dnn::blobFromImage(img);
net.setInput(input);
cv::Mat predicted = net.forward();

效果图片

PyTorch使用ResNet18提取图像特征并进行相似度计算_第1张图片
PyTorch使用ResNet18提取图像特征并进行相似度计算_第2张图片
PyTorch使用ResNet18提取图像特征并进行相似度计算_第3张图片

你可能感兴趣的:(笔记,深度学习,机器学习,opencv)