pytorch模型转tensorflow pb模型

#1.配置安装

  • conda create -n pth_pb python=3.7 && conda activate pth_pb #(创建环境后需先激活,再执行下面的安装命令)

  • pip install tensorflow==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install tensorflow-addons==0.9.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install onnx-tf==1.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • pip install onnx==1.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

  • conda install pytorch torchvision #(这里我用的版本为pytorch1.4, torchvision0.5.0)

  • 注意:如果你的python版本为3.8,安装onnx时会出现以下错误,降低python版本后解决

    问题:python setup.py egg_info Check the logs for full command output.

    解决:conda install python=3.7

    问题:No module named 'pip._internal.cli.main'

    解决:easy_install pip

#2.代码
pth转onnx.py

import torchvision
import torch.onnx
import torch.nn as nn
def resnet50(num_classes=2):
    """Build a torchvision ResNet-50 with a replaced classification head.

    The network is created with ``pretrained=False``, so no pretrained
    weights are loaded here.

    Args:
        num_classes: Number of outputs of the final fully connected layer.
            Defaults to 2, the value this script originally hard-coded.

    Returns:
        The ``torchvision.models.resnet50`` instance whose ``fc`` layer has
        been replaced by a new ``nn.Linear`` with ``num_classes`` outputs.
    """
    model = torchvision.models.resnet50(pretrained=False)
    # Take the input width from the existing head rather than hard-coding 2048.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
# Load a trained checkpoint into ResNet-50 and export it as an ONNX file.
model = resnet50()
# The model is wrapped before the weights are loaded; presumably the
# checkpoint was saved from a DataParallel model (keys prefixed with
# "module.") -- confirm against your own checkpoint.
model = torch.nn.DataParallel(model)
pthfile = r'my.pth'
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(pthfile, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
# Make inference mode explicit so BatchNorm uses its running statistics.
model.eval()
# Dummy input used to trace the graph: batch of 1, 3 channels, 200x200.
dummy_input = torch.randn(1, 3, 200, 200)
# These tensor names are looked up again by pb_predict.py
# ("head_input:0" / "output:0").
input_names = ["head_input"]
output_names = ["output"]
onnx_filename = "my.onnx"
# Export the underlying network (model.module), not the DataParallel wrapper.
torch.onnx.export(
    model.module,
    dummy_input,
    onnx_filename,
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)

onnx转pb.py

import onnx
from onnx_tf.backend import prepare

def onnx2pb(onnx_input_path, pb_output_path):
    """Convert an ONNX model file into a TensorFlow graph file.

    Args:
        onnx_input_path: Path of the ONNX model to read.
        pb_output_path: Path the exported TensorFlow graph is written to.
    """
    # Load the ONNX model and build its TensorFlow representation with
    # onnx-tf's ``prepare`` in one step, then write the graph to disk.
    tf_representation = prepare(onnx.load(onnx_input_path))
    tf_representation.export_graph(pb_output_path)

if __name__ == "__main__":
    # Use the file names shared with the other two scripts: the export
    # script writes 'my.onnx' and pb_predict.py reads 'my.pb'.
    onnx_input_path = 'my.onnx'
    pb_output_path = 'my.pb'
    onnx2pb(onnx_input_path, pb_output_path)

pb_predict.py

import tensorflow as tf
from torchvision import transforms
import numpy as np
from PIL import Image

# Preprocessing: resize to 200x200 (the size of the dummy input used for the
# ONNX export), convert to a tensor, then normalize each of the 3 channels
# with the given mean / std.
transform = transforms.Compose([
    transforms.Resize((200, 200)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


with tf.Graph().as_default():
    # TensorFlow 1.4 version: output_graph_def = tf.GraphDef()
    output_graph_def = tf.compat.v1.GraphDef()
    output_graph_path = 'my.pb'
    with open(output_graph_path, 'rb') as f:
        output_graph_def.ParseFromString(f.read())
    # name="" imports the nodes without a name prefix, so the tensors can be
    # fetched below by the names chosen at export time.
    tf.import_graph_def(output_graph_def, name="")
    # TensorFlow 1.4 version: with tf.Session() as sess:
    with tf.compat.v1.Session() as sess:
        image = "demo.jpg"
        # Force 3 channels: Normalize above is given 3-channel statistics,
        # so grayscale or RGBA files would otherwise fail.
        image_np = Image.open(image).convert("RGB")
        # Add the batch dimension, then hand TensorFlow a NumPy array.
        img_input = transform(image_np).unsqueeze(0)
        image_np_expanded = img_input.numpy()
        # TensorFlow 1.4 version: sess.run(tf.global_variables_initializer())
        sess.run(tf.compat.v1.global_variables_initializer())
        input_tensor = sess.graph.get_tensor_by_name("head_input:0")
        output_tensor = sess.graph.get_tensor_by_name("output:0")
        predictions = sess.run(
            output_tensor, feed_dict={input_tensor: image_np_expanded}
        )
        # Position of the largest score in the prediction array.
        index = np.argmax(predictions)
        print("predictions:", predictions)
        print("index:", index)

  • 注意:其中onnx转pb中会出现警告,这里并没有评测警告的影响。
    如需避免警告,可使用第二种方法(从源码安装onnx-tf后用命令行转换):
    git clone https://github.com/onnx/onnx-tensorflow.git
    cd onnx-tensorflow
    pip install -e .
    onnx-tf convert -i /Users/haobing1/myMac/model/1/resnet50_epoch_100.onnx -o /Users/haobing1/myMac/model/1/res.pb
    按此方法将onnx转换为pb时无警告。

你可能感兴趣的:(算法,pytorch,项目环境配置,pytorch,tensorflow,深度学习)