PyTorch实现LSTM预测模型，并在C++中调用导出的ONNX模型进行推理

PyTorch实现LSTM模型

代码

import torch
import torch.nn as nn

class LSTM(nn.Module):
    """Stacked LSTM followed by a linear projection head.

    Args:
        input_size: number of features per time step (last dim of ``x``).
        output_size: number of features produced by the linear head ``fc``.
        out_channels: if 1, only the last time step is returned, giving shape
            ``(batch, output_size)``; otherwise every time step is returned,
            giving shape ``(batch, seq_len, output_size)``.
        num_layers: number of stacked LSTM layers.
        device: kept for backward compatibility. The initial hidden/cell
            states are created on the same device (and dtype) as the input.
        hidden_size: LSTM hidden size; defaults to ``input_size``, which was
            the previously hard-coded value.
    """

    def __init__(self, input_size, output_size, out_channels, num_layers, device,
                 hidden_size=None):
        super().__init__()
        self.device = device
        self.input_size = input_size
        self.hidden_size = input_size if hidden_size is None else hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.out_channels = out_channels

        self.lstm = nn.LSTM(input_size=self.input_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)

        # Projection head: maps hidden_size -> output_size. Previously it was
        # constructed but never applied, so output_size had no effect.
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        """Run the model on ``x``.

        ``x`` is indexed as (batch, seq_len, input_size) because the LSTM is
        built with ``batch_first=True``.
        """
        # Zero initial states that follow the input's device/dtype, so the
        # model still works after .to(...) without relying on self.device.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         dtype=x.dtype, device=x.device)
        c0 = torch.zeros_like(h0)

        out, _ = self.lstm(x, (h0, c0))

        if self.out_channels == 1:
            # Keep only the last time step: (batch, hidden_size).
            out = out[:, -1, :]

        return self.fc(out)


batch_size = 20
seq_len = 10        # time steps per sample
input_size = 10     # features per time step
output_size = 10
num_layers = 2
out_channels = 1    # 1 -> the model returns only the last time step

model = LSTM(input_size, output_size, out_channels, num_layers, "cpu")
model.eval()

input_names = ["input"]
output_names = ["output"]

# The LSTM uses batch_first=True, so the dummy input is (batch, seq_len, input_size).
x = torch.randn((batch_size, seq_len, input_size))
print(x.shape)
with torch.no_grad():
    y = model(x)
print(y.shape)

# Name the dynamic batch axis explicitly. This avoids the "No names were found
# for specified dynamic axes" warnings and lets ONNX consumers run any batch size.
torch.onnx.export(model, x, 'LSTM.onnx', verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})

import onnx

# Use a separate name so the PyTorch model is not shadowed by the ONNX proto.
onnx_model = onnx.load("LSTM.onnx")
print("load model done.")
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print("check model done.")

运行结果

torch.Size([20, 10, 10])
torch.Size([20, 10])
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. 
  + "or define the initial states (h0/c0) as inputs of the model. "
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_193 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_194 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_195 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_213 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_214 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_215 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name="/Constant_1"](), scope: __main__.LSTM::
  %onnx::Unsqueeze_18 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
  %/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_18), scope: __main__.LSTM::
  %/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={10}, onnx_name="/Constant_2"](), scope: __main__.LSTM::
  %/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/ConstantOfShape_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name="/ConstantOfShape"](%/Concat_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Cast_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="/Cast"](%/ConstantOfShape_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/lstm/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose"](%input), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %onnx::LSTM_26 : Tensor? = prim::Constant(), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_1"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_2"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice"](%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_3"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_4"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_5"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_1_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_1"](%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM"](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %onnx::LSTM_26, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_6"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze"](%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_7"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_8_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_8"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_9_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_9"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_2_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_2"](%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_10_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_10"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_11_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_11"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_12_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_12"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_3_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_3"](%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_1_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_1_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_1_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM_1"](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %onnx::LSTM_26, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_13_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_13"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_1_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze_1"](%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Transpose_1_output_0 : Float(*, 10, 10, strides=[10, 200, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose_1"](%/lstm/Squeeze_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/Constant_3_output_0 : Long(device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_3"](), scope: __main__.LSTM::
  %output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gather[axis=1, onnx_name="/Gather_1"](%/lstm/Transpose_1_output_0, %/Constant_3_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:29:0
  return (%output)

load model done.
graph torch_jit (
  %input[FLOAT, input_dynamic_axes_1x10x10]
) initializers (
  %onnx::LSTM_193[FLOAT, 1x40x10]
  %onnx::LSTM_194[FLOAT, 1x40x10]
  %onnx::LSTM_195[FLOAT, 1x80]
  %onnx::LSTM_213[FLOAT, 1x40x10]
  %onnx::LSTM_214[FLOAT, 1x40x10]
  %onnx::LSTM_215[FLOAT, 1x80]
) {
  %/Shape_output_0 = Shape(%input)
  %/Constant_output_0 = Constant[value = ]()
  %/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
  %/Constant_1_output_0 = Constant[value = ]()
  %onnx::Unsqueeze_18 = Constant[value = ]()
  %/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_18)
  %/Constant_2_output_0 = Constant[value = ]()
  %/Concat_output_0 = Concat[axis = 0](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0)
  %/ConstantOfShape_output_0 = ConstantOfShape[value = ](%/Concat_output_0)
  %/Cast_output_0 = Cast[to = 1](%/ConstantOfShape_output_0)
  %/lstm/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
  %/lstm/Constant_output_0 = Constant[value = ]()
  %/lstm/Constant_1_output_0 = Constant[value = ]()
  %/lstm/Constant_2_output_0 = Constant[value = ]()
  %/lstm/Slice_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0)
  %/lstm/Constant_3_output_0 = Constant[value = ]()
  %/lstm/Constant_4_output_0 = Constant[value = ]()
  %/lstm/Constant_5_output_0 = Constant[value = ]()
  %/lstm/Slice_1_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0)
  %/lstm/LSTM_output_0, %/lstm/LSTM_output_1, %/lstm/LSTM_output_2 = LSTM[hidden_size = 10](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0)
  %/lstm/Constant_6_output_0 = Constant[value = ]()
  %/lstm/Squeeze_output_0 = Squeeze(%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0)
  %/lstm/Constant_7_output_0 = Constant[value = ]()
  %/lstm/Constant_8_output_0 = Constant[value = ]()
  %/lstm/Constant_9_output_0 = Constant[value = ]()
  %/lstm/Slice_2_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0)
  %/lstm/Constant_10_output_0 = Constant[value = ]()
  %/lstm/Constant_11_output_0 = Constant[value = ]()
  %/lstm/Constant_12_output_0 = Constant[value = ]()
  %/lstm/Slice_3_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0)
  %/lstm/LSTM_1_output_0, %/lstm/LSTM_1_output_1, %/lstm/LSTM_1_output_2 = LSTM[hidden_size = 10](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0)
  %/lstm/Constant_13_output_0 = Constant[value = ]()
  %/lstm/Squeeze_1_output_0 = Squeeze(%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0)
  %/lstm/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/lstm/Squeeze_1_output_0)
  %/Constant_3_output_0 = Constant[value = ]()
  %output = Gather[axis = 1](%/lstm/Transpose_1_output_0, %/Constant_3_output_0)
  return %output
}
check model done.

C++调用ONNX

实现代码

vector<float> testOnnxLSTM(std::vector<std::vector<std::vector<float>>>& inputs) 
{
    //设置为VERBOSE,方便控制台输出时看到是使用了cpu还是gpu执行
    //Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "test");
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
    Ort::SessionOptions session_options;

    session_options.SetIntraOpNumThreads(5); // 使用五个线程执行op,提升速度
    // 第二个参数代表GPU device_id = 0,注释这行就是cpu执行
    //OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    #ifdef _WIN32
        const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\LSTM.onnx";
    #else
        const char* model_path = "C:\\Users\\xxx\\Desktop\\LSTM.onnx";
    #endif

    wprintf(L"%s\n", model_path);

    Ort::Session session(env, model_path, session_options);

    const char* input_names[] = { "input" }; 
    const char* output_names[] = { "output" };

    const int input_size = 10;
    const int output_size = 10;
    const int batch_size = 1;
    const int seq_len = 10;

    std::array<float, batch_size* seq_len* input_size> input_matrix;
    std::array<float, batch_size* output_size> output_matrix;

    std::array<int64_t, 3> input_shape{ batch_size, seq_len, input_size };
    std::array<int64_t, 2> output_shape{ batch_size, output_size };

    for (int i = 0; i < batch_size; i++)
        for (int j = 0; j < seq_len; j++)
            for (int k = 0; k < input_size; k++)
                input_matrix[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k];

    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());

    try
    {
        Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());
        session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1); 
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
    }

    std::cout << "get data from LSTM onnx: \n";
    vector<float> ret;
    for (int i = 0; i < output_size; i++) {
        ret.emplace_back(output_matrix[i]);
        std::cout << ret[i] << "\t";
    }
    std::cout << "\n";

    return ret;
}

调用代码

   std::vector<std::vector<std::vector<float>>> data;
   for (int i = 0; i < 1; i++) {
       std::vector<std::vector<float>> t1;
       for (int j = 0; j < 10; j++) {
           std::vector<float> t2;
           for (int k = 0; k < 10; k++) {
               t2.push_back(1.0 * k * j / 20);
           }
           t1.push_back(t2);
       }
       data.push_back(t1);
   }

   for (auto& i : data) {
       for (auto& j : i) {
           for (auto& k : j) {
               std::cout << k << "\t";
           }
           std::cout << "\n";
       }
       std::cout << "\n";
   }
   auto ret = testOnnxLSTM(data);

测试结果

0       0       0       0       0       0       0       0       0       0
0       0.05    0.1     0.15    0.2     0.25    0.3     0.35    0.4     0.45
0       0.1     0.2     0.3     0.4     0.5     0.6     0.7     0.8     0.9
0       0.15    0.3     0.45    0.6     0.75    0.9     1.05    1.2     1.35
0       0.2     0.4     0.6     0.8     1       1.2     1.4     1.6     1.8
0       0.25    0.5     0.75    1       1.25    1.5     1.75    2       2.25
0       0.3     0.6     0.9     1.2     1.5     1.8     2.1     2.4     2.7
0       0.35    0.7     1.05    1.4     1.75    2.1     2.45    2.8     3.15
0       0.4     0.8     1.2     1.6     2       2.4     2.8     3.2     3.6
0       0.45    0.9     1.35    1.8     2.25    2.7     3.15    3.6     4.05

C:\Users\xxx\Desktop\LSTM.onnx
get data from LSTM onnx:
0.000401703 0.00102207 0.0011015 -0.000503412 -0.000911839 -0.0011367 -0.000309185 0.000591398 -0.000362981 -4.81475e-05

你可能感兴趣的:(pytorch,lstm,c++)