.pth转.pt 【学习记录】

学习自用,欢迎指正

pytorch中训练的.pth模型不能在C++中直接调用,需要转换成.pt格式的文件,故有此文。

在转换之前需要知道的东西:

1.要转换的pth文件是从那个网络训练出来的?

2.训练这个网络时的输入参数是哪些,各是多少

了解之后利用如下代码进行转换:

import torch
from models.network_swinir import SwinIR #你训练的网络

model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')  # 训练此网络时往里输入的参数


state_dict = torch.load("S_x4.pth") # 要转换的文件路径
model.load_state_dict(state_dict, False)
model.eval()  # 切换到eval()

example = torch.rand(1, 3, 320, 480)  # 生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

你可能感兴趣的:(学习,深度学习,pytorch)