PyTorch实现RNN模型
代码
import torch
import torch.nn as nn
class RNN(nn.Module):
    """Stacked Elman RNN followed by a Linear head over the flattened sequence.

    The RNN output of every time step is flattened into one vector per sample
    and fed to a single ``nn.Linear``. The Linear input width is
    ``seq_len * hidden_size``, so inputs must have exactly ``seq_len`` steps.

    Args:
        seq_len: Number of time steps per sample; fixes the Linear input width.
        input_size: Number of features per time step.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of output features produced by the Linear head.
        num_layers: Number of stacked RNN layers.
        device: Stored as ``self._device``; not used by the module itself.
    """

    def __init__(self, seq_len, input_size, hidden_size, output_size, num_layers, device):
        super(RNN, self).__init__()
        self._seq_len = seq_len
        self._input_size = input_size
        self._output_size = output_size
        self._hidden_size = hidden_size
        self._device = device
        self._num_layers = num_layers
        # batch_first=True: x is (batch, seq, feature) rather than (seq, batch, feature).
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=self._hidden_size,
            num_layers=self._num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(self._seq_len * self._hidden_size, self._output_size)

    def forward(self, x, hidden_prev=None):
        """Run the RNN over ``x`` and project the flattened outputs.

        Args:
            x: Batched input of shape (batch, seq_len, input_size).
            hidden_prev: Initial hidden state of shape
                (num_layers, batch, hidden_size). ``None`` (the default) lets
                ``nn.RNN`` start from an all-zero state.

        Returns:
            Tuple ``(out, hidden)``: ``out`` has shape (batch, output_size) and
            ``hidden`` is the final hidden state, (num_layers, batch, hidden_size).
        """
        out, hidden_prev = self.rnn(x, hidden_prev)
        # Flatten (batch, seq_len, hidden) -> (batch, seq_len * hidden).
        # reshape() copies only when needed, like contiguous().view().
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out, hidden_prev
# Export the RNN model to ONNX with a dynamic batch dimension, then reload and
# validate the exported file.
seq_len = 10
batch_size = 20
input_size = 10
output_size = 10
hidden_size = 32
num_layers = 2

model = RNN(seq_len, input_size, hidden_size, output_size, num_layers, "cpu")
model.eval()

# The initial hidden state is passed as an explicit graph input; the exporter
# warns that RNN_TANH exported with batch_size != 1 needs h0 as an input to
# run correctly with a different batch size.
hidden_prev = torch.zeros(num_layers, batch_size, hidden_size).to("cpu")
x = torch.randn((batch_size, seq_len, input_size))

input_names = ["input", "hidden_prev_in"]
output_names = ["output", "hidden_prev_out"]

# Keep the input state and the returned state in separate names so the shapes
# printed below are really the input's and the output's.
with torch.no_grad():
    y, hidden_out = model(x, hidden_prev)
print(x.shape)
print(hidden_prev.shape)
print(y.shape)
print(hidden_out.shape)

# Named dynamic axes: bare index lists make the exporter warn
# "No names were found for specified dynamic axes" and invent names.
# Using one symbol everywhere records that all of them are the same batch size.
dynamic_axes = {
    "input": {0: "batch_size"},
    "hidden_prev_in": {1: "batch_size"},
    "output": {0: "batch_size"},
    "hidden_prev_out": {1: "batch_size"},
}
torch.onnx.export(
    model,
    (x, hidden_prev),
    "RNN.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

import onnx

# Separate name so the PyTorch `model` above is not shadowed by the ONNX proto.
onnx_model = onnx.load("RNN.onnx")
print("load model done.")
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print("check model done.")
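运行结果（注：下面日志中隐藏状态的输入/输出名为 `hidden_prev.1` / `hidden_prev`，与上面代码中的 `hidden_prev_in` / `hidden_prev_out` 不一致；C++ 端使用的输入/输出名必须与实际导出模型中的名称一致）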
torch.Size([20, 10, 10])
torch.Size([2, 20, 32])
torch.Size([20, 10])
torch.Size([2, 20, 32])
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
"No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input hidden_prev
"No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
"No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with RNN_TANH can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
+ "or define the initial states (h0/c0) as inputs of the model. "
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
graph, params_dict, GLOBALS.export_onnx_opset_version
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
graph, params_dict, GLOBALS.export_onnx_opset_version
Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
%hidden_prev.1 : Float(2, *, 32, strides=[640, 32, 1], requires_grad=1, device=cpu),
%fc.weight : Float(10, 320, strides=[320, 1], requires_grad=1, device=cpu),
%fc.bias : Float(10, strides=[1], requires_grad=1, device=cpu),
%onnx::RNN_58 : Float(1, 32, 10, strides=[320, 10, 1], requires_grad=0, device=cpu),
%onnx::RNN_59 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
%onnx::RNN_60 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu),
%onnx::RNN_62 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
%onnx::RNN_63 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
%onnx::RNN_64 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu)):
%/rnn/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/rnn/Transpose"](%input), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%onnx::RNN_13 : Tensor? = prim::Constant(), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant_1"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_2"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Slice_output_0 : Float(1, *, 32, device=cpu) = onnx::Slice[onnx_name="/rnn/Slice"](%hidden_prev.1, %/rnn/Constant_1_output_0, %/rnn/Constant_2_output_0, %/rnn/Constant_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/RNN_output_0 : Float(10, 1, *, 32, device=cpu), %/rnn/RNN_output_1 : Float(1, *, 32, device=cpu) = onnx::RNN[activations=["Tanh"], hidden_size=32, onnx_name="/rnn/RNN"](%/rnn/Transpose_output_0, %onnx::RNN_58, %onnx::RNN_59, %onnx::RNN_60, %onnx::RNN_13, %/rnn/Slice_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_3"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Squeeze_output_0 : Float(10, *, 32, device=cpu) = onnx::Squeeze[onnx_name="/rnn/Squeeze"](%/rnn/RNN_output_0, %/rnn/Constant_3_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant_4"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_5"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/rnn/Constant_6"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Slice_1_output_0 : Float(1, *, 32, device=cpu) = onnx::Slice[onnx_name="/rnn/Slice_1"](%hidden_prev.1, %/rnn/Constant_5_output_0, %/rnn/Constant_6_output_0, %/rnn/Constant_4_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/RNN_1_output_0 : Float(10, 1, *, 32, device=cpu), %/rnn/RNN_1_output_1 : Float(1, *, 32, device=cpu) = onnx::RNN[activations=["Tanh"], hidden_size=32, onnx_name="/rnn/RNN_1"](%/rnn/Squeeze_output_0, %onnx::RNN_62, %onnx::RNN_63, %onnx::RNN_64, %onnx::RNN_13, %/rnn/Slice_1_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_7"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Squeeze_1_output_0 : Float(10, *, 32, device=cpu) = onnx::Squeeze[onnx_name="/rnn/Squeeze_1"](%/rnn/RNN_1_output_0, %/rnn/Constant_7_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/rnn/Transpose_1_output_0 : Float(*, 10, 32, strides=[320, 32, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/rnn/Transpose_1"](%/rnn/Squeeze_1_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%hidden_prev : Float(2, *, 32, strides=[640, 32, 1], requires_grad=1, device=cpu) = onnx::Concat[axis=0, onnx_name="/rnn/Concat"](%/rnn/RNN_output_1, %/rnn/RNN_1_output_1), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
%/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%/rnn/Transpose_1_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
%/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
%/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
%onnx::Unsqueeze_50 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
%/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_50), scope: __main__.RNN::
%/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_1"](), scope: __main__.RNN::
%/Concat_output_0 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Unsqueeze_output_0, %/Constant_1_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
%/Reshape_output_0 : Float(*, *, strides=[320, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/rnn/Transpose_1_output_0, %/Concat_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
%output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/fc/Gemm"](%/Reshape_output_0, %fc.weight, %fc.bias), scope: __main__.RNN::/torch.nn.modules.linear.Linear::fc # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/linear.py:114:0
return (%output, %hidden_prev)
load model done.
graph torch_jit (
%input[FLOAT, input_dynamic_axes_1x10x10]
%hidden_prev.1[FLOAT, 2xhidden_prev.1_dim_1x32]
) initializers (
%fc.weight[FLOAT, 10x320]
%fc.bias[FLOAT, 10]
%onnx::RNN_58[FLOAT, 1x32x10]
%onnx::RNN_59[FLOAT, 1x32x32]
%onnx::RNN_60[FLOAT, 1x64]
%onnx::RNN_62[FLOAT, 1x32x32]
%onnx::RNN_63[FLOAT, 1x32x32]
%onnx::RNN_64[FLOAT, 1x64]
) {
%/rnn/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
%/rnn/Constant_output_0 = Constant[value = ]()
%/rnn/Constant_1_output_0 = Constant[value = ]()
%/rnn/Constant_2_output_0 = Constant[value = ]()
%/rnn/Slice_output_0 = Slice(%hidden_prev.1, %/rnn/Constant_1_output_0, %/rnn/Constant_2_output_0, %/rnn/Constant_output_0)
%/rnn/RNN_output_0, %/rnn/RNN_output_1 = RNN[activations = ['Tanh'], hidden_size = 32](%/rnn/Transpose_output_0, %onnx::RNN_58, %onnx::RNN_59, %onnx::RNN_60, %, %/rnn/Slice_output_0)
%/rnn/Constant_3_output_0 = Constant[value = ]()
%/rnn/Squeeze_output_0 = Squeeze(%/rnn/RNN_output_0, %/rnn/Constant_3_output_0)
%/rnn/Constant_4_output_0 = Constant[value = ]()
%/rnn/Constant_5_output_0 = Constant[value = ]()
%/rnn/Constant_6_output_0 = Constant[value = ]()
%/rnn/Slice_1_output_0 = Slice(%hidden_prev.1, %/rnn/Constant_5_output_0, %/rnn/Constant_6_output_0, %/rnn/Constant_4_output_0)
%/rnn/RNN_1_output_0, %/rnn/RNN_1_output_1 = RNN[activations = ['Tanh'], hidden_size = 32](%/rnn/Squeeze_output_0, %onnx::RNN_62, %onnx::RNN_63, %onnx::RNN_64, %, %/rnn/Slice_1_output_0)
%/rnn/Constant_7_output_0 = Constant[value = ]()
%/rnn/Squeeze_1_output_0 = Squeeze(%/rnn/RNN_1_output_0, %/rnn/Constant_7_output_0)
%/rnn/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/rnn/Squeeze_1_output_0)
%hidden_prev = Concat[axis = 0](%/rnn/RNN_output_1, %/rnn/RNN_1_output_1)
%/Shape_output_0 = Shape(%/rnn/Transpose_1_output_0)
%/Constant_output_0 = Constant[value = ]()
%/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
%onnx::Unsqueeze_50 = Constant[value = ]()
%/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_50)
%/Constant_1_output_0 = Constant[value = ]()
%/Concat_output_0 = Concat[axis = 0](%/Unsqueeze_output_0, %/Constant_1_output_0)
%/Reshape_output_0 = Reshape[allowzero = 0](%/rnn/Transpose_1_output_0, %/Concat_output_0)
%output = Gemm[alpha = 1, beta = 1, transB = 1](%/Reshape_output_0, %fc.weight, %fc.bias)
return %output, %hidden_prev
}
check model done.
C++调用ONNX
代码
vector<float> testOnnxRNN() {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(5);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
#ifdef _WIN32
const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\RNN.onnx";
#else
const char* model_path = "C:\\Users\\xxx\\Desktop\\RNN.onnx";
#endif
wprintf(L"%s\n", model_path);
Ort::Session session(env, model_path, session_options);
Ort::AllocatorWithDefaultOptions allocator;
size_t num_input_nodes = session.GetInputCount();
size_t num_output_nodes = session.GetOutputCount();
std::vector<const char*> input_node_names = { "input" , "hidden_prev_in" };
std::vector<const char*> output_node_names = { "output" , "hidden_prev_out" };
const int input_size = 10;
const int output_size = 10;
const int batch_size = 1;
const int seq_len = 10;
const int num_layers = 2;
const int hidden_size = 32;
std::vector<int64_t> input_node_dims = { batch_size, seq_len, input_size };
size_t input_tensor_size = batch_size * seq_len * input_size;
std::vector<float> input_tensor_values(input_tensor_size);
for (unsigned int i = 0; i < input_tensor_size; i++) {
input_tensor_values[i] = (float)i / (input_tensor_size + 1);
}
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), 3);
assert(input_tensor.IsTensor());
std::vector<int64_t> hidden_prev_in_node_dims = { num_layers, batch_size, hidden_size };
size_t hidden_prev_in_tensor_size = num_layers * batch_size * hidden_size;
std::vector<float> hidden_prev_in_tensor_values(hidden_prev_in_tensor_size);
for (unsigned int i = 0; i < hidden_prev_in_tensor_size; i++) {
hidden_prev_in_tensor_values[i] = (float)i / (hidden_prev_in_tensor_size + 1);
}
auto mask_memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value hidden_prev_in_tensor = Ort::Value::CreateTensor<float>(mask_memory_info, hidden_prev_in_tensor_values.data(), hidden_prev_in_tensor_size, hidden_prev_in_node_dims.data(), 3);
assert(hidden_prev_in_tensor.IsTensor());
std::vector<Ort::Value> ort_inputs;
ort_inputs.push_back(std::move(input_tensor));
ort_inputs.push_back(std::move(hidden_prev_in_tensor));
vector<float> ret;
try
{
auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), ort_inputs.data(), ort_inputs.size(), output_node_names.data(), 2);
float* output = output_tensors[0].GetTensorMutableData<float>();
float* hidden_prev_out = output_tensors[1].GetTensorMutableData<float>();
for (int i = 0; i < output_size; i++) {
ret.emplace_back(output[i]);
std::cout << output[i] << " ";
}
std::cout << "\n";
}
catch (const std::exception& e)
{
std::cout << e.what() << std::endl;
}
return ret;
}
运行结果
C:\Users\xxx\Desktop\RNN.onnx
0.00296116 0.104443 -0.104239 0.249864 -0.155839 0.019295 0.0458037 -0.0596341 -0.129019 -0.014682