PyTorch实现RNN（LSTM）模型并导出ONNX
代码
import torch
import torch.nn as nn
class LSTM(nn.Module):
    """Multi-layer LSTM wrapper that returns the raw LSTM features.

    Args:
        input_size: number of features per time step of the input.
        output_size: out_features of ``self.fc``.
            NOTE(review): ``self.fc`` is created but never applied in
            ``forward``, so this value does not affect the output today.
        out_channels: if 1, ``forward`` returns only the last time step,
            shape ``(batch, hidden_size)``; otherwise the full sequence,
            shape ``(batch, seq_len, hidden_size)``.
        num_layers: number of stacked LSTM layers.
        device: stored as ``self.device`` for backward compatibility; the
            initial states are allocated on the input's device instead.
        hidden_size: LSTM hidden width. Defaults to ``input_size``, which was
            previously the only supported value.
    """

    def __init__(self, input_size, output_size, out_channels, num_layers, device, hidden_size=None):
        super().__init__()
        self.device = device
        self.input_size = input_size
        self.hidden_size = input_size if hidden_size is None else hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.out_channels = out_channels
        # Kept so existing checkpoints (state_dict keys "fc.*") still load;
        # see the NOTE(review) in the class docstring.
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        """Run the LSTM over ``x`` starting from all-zero hidden and cell states.

        Args:
            x: input of shape ``(batch, seq_len, input_size)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            ``(batch, hidden_size)`` when ``out_channels == 1``, otherwise
            ``(batch, seq_len, hidden_size)``.
        """
        # Allocate the initial states on x's own device and dtype so the module
        # keeps working after .to(...)/.double(), instead of depending on the
        # device string captured in __init__.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (h0, c0))
        if self.out_channels == 1:
            # Keep only the last time step.
            out = out[:, -1, :]
        return out
# Export settings. The LSTM uses batch_first=True, so the dummy input must be
# shaped (batch_size, seq_len, input_size).
batch_size = 20
seq_len = 10
input_size = 10
output_size = 10
num_layers = 2
out_channels = 1
model = LSTM(input_size, output_size, out_channels, num_layers, "cpu")
model.eval()
input_names = ["input"]
output_names = ["output"]
# seq_len and input_size are distinct axes; they only happen to share a value here.
x = torch.randn((batch_size, seq_len, input_size))
print(x.shape)
with torch.no_grad():
    y = model(x)
print(y.shape)
# Give the dynamic batch axis an explicit name. A bare list such as [0] makes
# the exporter warn and invent names like "input_dynamic_axes_1".
torch.onnx.export(
    model,
    x,
    "LSTM.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

import onnx

# Reload the exported file and validate it. Use a separate name so `model`
# keeps referring to the PyTorch module.
onnx_model = onnx.load("LSTM.onnx")
print("load model done.")
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print("check model done.")
运行结果
torch.Size([20, 10, 10])
torch.Size([20, 10])
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
"No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
"No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
+ "or define the initial states (h0/c0) as inputs of the model. "
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
graph, params_dict, GLOBALS.export_onnx_opset_version
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
graph, params_dict, GLOBALS.export_onnx_opset_version
Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
%onnx::LSTM_193 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
%onnx::LSTM_194 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
%onnx::LSTM_195 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu),
%onnx::LSTM_213 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
%onnx::LSTM_214 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
%onnx::LSTM_215 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu)):
%/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.LSTM::
%/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.LSTM::
%/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.LSTM::
%/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name="/Constant_1"](), scope: __main__.LSTM::
%onnx::Unsqueeze_18 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
%/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_18), scope: __main__.LSTM::
%/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={10}, onnx_name="/Constant_2"](), scope: __main__.LSTM::
%/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0), scope: __main__.LSTM::
%/ConstantOfShape_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name="/ConstantOfShape"](%/Concat_output_0), scope: __main__.LSTM::
%/Cast_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="/Cast"](%/ConstantOfShape_output_0), scope: __main__.LSTM::
%/lstm/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose"](%input), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%onnx::LSTM_26 : Tensor? = prim::Constant(), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_1"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_2"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Slice_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice"](%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_3"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_4"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_5"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Slice_1_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_1"](%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/LSTM_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM"](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %onnx::LSTM_26, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_6"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Squeeze_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze"](%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_7"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_8_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_8"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_9_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_9"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Slice_2_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_2"](%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_10_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_10"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_11_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_11"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_12_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_12"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Slice_3_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_3"](%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/LSTM_1_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_1_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_1_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM_1"](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %onnx::LSTM_26, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Constant_13_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_13"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Squeeze_1_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze_1"](%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/lstm/Transpose_1_output_0 : Float(*, 10, 10, strides=[10, 200, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose_1"](%/lstm/Squeeze_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm
%/Constant_3_output_0 : Long(device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_3"](), scope: __main__.LSTM::
%output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gather[axis=1, onnx_name="/Gather_1"](%/lstm/Transpose_1_output_0, %/Constant_3_output_0), scope: __main__.LSTM::
return (%output)
load model done.
graph torch_jit (
%input[FLOAT, input_dynamic_axes_1x10x10]
) initializers (
%onnx::LSTM_193[FLOAT, 1x40x10]
%onnx::LSTM_194[FLOAT, 1x40x10]
%onnx::LSTM_195[FLOAT, 1x80]
%onnx::LSTM_213[FLOAT, 1x40x10]
%onnx::LSTM_214[FLOAT, 1x40x10]
%onnx::LSTM_215[FLOAT, 1x80]
) {
%/Shape_output_0 = Shape(%input)
%/Constant_output_0 = Constant[value = ]()
%/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
%/Constant_1_output_0 = Constant[value = ]()
%onnx::Unsqueeze_18 = Constant[value = ]()
%/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_18)
%/Constant_2_output_0 = Constant[value = ]()
%/Concat_output_0 = Concat[axis = 0](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0)
%/ConstantOfShape_output_0 = ConstantOfShape[value = ](%/Concat_output_0)
%/Cast_output_0 = Cast[to = 1](%/ConstantOfShape_output_0)
%/lstm/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
%/lstm/Constant_output_0 = Constant[value = ]()
%/lstm/Constant_1_output_0 = Constant[value = ]()
%/lstm/Constant_2_output_0 = Constant[value = ]()
%/lstm/Slice_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0)
%/lstm/Constant_3_output_0 = Constant[value = ]()
%/lstm/Constant_4_output_0 = Constant[value = ]()
%/lstm/Constant_5_output_0 = Constant[value = ]()
%/lstm/Slice_1_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0)
%/lstm/LSTM_output_0, %/lstm/LSTM_output_1, %/lstm/LSTM_output_2 = LSTM[hidden_size = 10](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0)
%/lstm/Constant_6_output_0 = Constant[value = ]()
%/lstm/Squeeze_output_0 = Squeeze(%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0)
%/lstm/Constant_7_output_0 = Constant[value = ]()
%/lstm/Constant_8_output_0 = Constant[value = ]()
%/lstm/Constant_9_output_0 = Constant[value = ]()
%/lstm/Slice_2_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0)
%/lstm/Constant_10_output_0 = Constant[value = ]()
%/lstm/Constant_11_output_0 = Constant[value = ]()
%/lstm/Constant_12_output_0 = Constant[value = ]()
%/lstm/Slice_3_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0)
%/lstm/LSTM_1_output_0, %/lstm/LSTM_1_output_1, %/lstm/LSTM_1_output_2 = LSTM[hidden_size = 10](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0)
%/lstm/Constant_13_output_0 = Constant[value = ]()
%/lstm/Squeeze_1_output_0 = Squeeze(%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0)
%/lstm/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/lstm/Squeeze_1_output_0)
%/Constant_3_output_0 = Constant[value = ]()
%output = Gather[axis = 1](%/lstm/Transpose_1_output_0, %/Constant_3_output_0)
return %output
}
check model done.
C++调用ONNX模型（ONNX Runtime）
实现代码
vector<float> testOnnxLSTM(std::vector<std::vector<std::vector<float>>>& inputs)
{
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(5);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#ifdef _WIN32
const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\LSTM.onnx";
#else
const char* model_path = "C:\\Users\\xxx\\Desktop\\LSTM.onnx";
#endif
wprintf(L"%s\n", model_path);
Ort::Session session(env, model_path, session_options);
const char* input_names[] = { "input" };
const char* output_names[] = { "output" };
const int input_size = 10;
const int output_size = 10;
const int batch_size = 1;
const int seq_len = 10;
std::array<float, batch_size* seq_len* input_size> input_matrix;
std::array<float, batch_size* output_size> output_matrix;
std::array<int64_t, 3> input_shape{ batch_size, seq_len, input_size };
std::array<int64_t, 2> output_shape{ batch_size, output_size };
for (int i = 0; i < batch_size; i++)
for (int j = 0; j < seq_len; j++)
for (int k = 0; k < input_size; k++)
input_matrix[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k];
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());
try
{
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());
session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
}
catch (const std::exception& e)
{
std::cout << e.what() << std::endl;
}
std::cout << "get data from LSTM onnx: \n";
vector<float> ret;
for (int i = 0; i < output_size; i++) {
ret.emplace_back(output_matrix[i]);
std::cout << ret[i] << "\t";
}
std::cout << "\n";
return ret;
}
调用代码
// Test input: one batch of 10 time steps x 10 features, where element [j][k]
// is k * j / 20 (computed in double, then stored as float).
std::vector<std::vector<std::vector<float>>> data = [] {
    const int batches = 1, steps = 10, features = 10;
    std::vector<std::vector<std::vector<float>>> grid(
        batches, std::vector<std::vector<float>>(steps, std::vector<float>(features)));
    for (int b = 0; b < batches; ++b)
        for (int j = 0; j < steps; ++j)
            for (int k = 0; k < features; ++k)
                grid[b][j][k] = static_cast<float>(1.0 * k * j / 20);
    return grid;
}();

// Echo the input: one row per time step, a blank line after each batch.
for (const auto& sequence : data) {
    for (const auto& step : sequence) {
        for (float value : step)
            std::cout << value << "\t";
        std::cout << "\n";
    }
    std::cout << "\n";
}
auto ret = testOnnxLSTM(data);
测试结果
0 0 0 0 0 0 0 0 0 0
0 0.05 0.1 0.15 0.2 0.25 0.3 0.35 0.4 0.45
0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9
0 0.15 0.3 0.45 0.6 0.75 0.9 1.05 1.2 1.35
0 0.2 0.4 0.6 0.8 1 1.2 1.4 1.6 1.8
0 0.25 0.5 0.75 1 1.25 1.5 1.75 2 2.25
0 0.3 0.6 0.9 1.2 1.5 1.8 2.1 2.4 2.7
0 0.35 0.7 1.05 1.4 1.75 2.1 2.45 2.8 3.15
0 0.4 0.8 1.2 1.6 2 2.4 2.8 3.2 3.6
0 0.45 0.9 1.35 1.8 2.25 2.7 3.15 3.6 4.05
C:\Users\xxx\Desktop\LSTM.onnx
get data from LSTM onnx:
0.000401703 0.00102207 0.0011015 -0.000503412 -0.000911839 -0.0011367 -0.000309185 0.000591398 -0.000362981 -4.81475e-05