ONNX是Facebook AI团队(联合Microsoft)推出的开放神经网络交换格式,可以方便地把用PyTorch定义并训练好的模型转换到Caffe2,然后进行部署,尤其是部署到移动端。想想,刚刚训练好的PyTorch模型马上就可以部署到Android上,是不是很激动~
import io
import numpy as np
from torch import nn
from torch.autograd import Variable
import torch.utils.model_zoo as model_zoo
import torch.onnx
# model definition
import torch.nn as nn
import torch.nn.init as init
class SRnet(nn.Module):
    """Single-channel super-resolution network.

    Four 2-D convolutions (ReLU after the first three) produce
    ``upscale_factor ** 2`` feature maps. ``nn.PixelShuffle`` rearranges
    them into a single channel that is ``upscale_factor`` times larger in
    each spatial dimension: input ``(N, 1, H, W)`` -> output
    ``(N, 1, H * r, W * r)`` with ``r = upscale_factor``.

    Args:
        upscale_factor: spatial upscaling factor ``r``.
        inplace: accepted for signature compatibility but currently unused;
            the ReLU is always constructed with ``inplace=True``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super(SRnet, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        # Conv2d(in, out, kernel, stride, padding): with stride 1 the padding
        # equals kernel // 2, so every conv keeps the spatial size H x W.
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        # r**2 output channels are exactly what PixelShuffle(r) folds into
        # one channel of size (H*r, W*r).
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Upscale ``x`` of shape ``(N, 1, H, W)`` to ``(N, 1, H*r, W*r)``."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        """Orthogonally initialise conv weights (ReLU gain for conv1-conv3)."""
        # Use the in-place `orthogonal_`; the non-underscore `orthogonal`
        # name is a deprecated alias.
        relu_gain = init.calculate_gain('relu')
        init.orthogonal_(self.conv1.weight, relu_gain)
        init.orthogonal_(self.conv2.weight, relu_gain)
        init.orthogonal_(self.conv3.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
# Build the network, then restore the pretrained weights from 'sr.pth'.
torch_model = SRnet(upscale_factor=3)
# Deserialize every tensor onto the CPU.
map_location = 'cpu'
state_dict = torch.load('sr.pth', map_location=map_location)
torch_model.load_state_dict(state_dict)
# Switch to inference mode (eval() is train(False)).
torch_model.eval()
SRnet (
(relu): ReLU (inplace)
(conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(pixel_shuffle): PixelShuffle (upscale_factor=3)
)
至此,PyTorch端的工作就完成了,接下来要把模型导出为ONNX格式,再导入到Caffe2中。导出所用的方式叫做“tracing”,具体做法是:提供一个输入x,让x把整个网络forward一遍,这个过程会记录下用到了哪些torch提供的operator。
# Dummy input used to trace the network: (batch, 1 channel, 244, 244).
# Only its shape matters for the export; its values are random.
batch_size = 1
x = torch.randn(batch_size, 1, 244, 244, requires_grad=True)
# Reference PyTorch output, used later to check the Caffe2 result.
# (The private torch.onnx._export used to return this; the public
# torch.onnx.export does not, so compute it explicitly.)
torch_out = torch_model(x)
# Trace the model with `x` and write the graph together with the trained
# parameters (export_params=True) to super_resolution.onnx.
torch.onnx.export(torch_model,
                  x,
                  "super_resolution.onnx",
                  export_params=True)
torch_out(即PyTorch模型对x的输出)本身没有什么特殊的用途,不过可以用来验证PyTorch和Caffe2得到的结果是否相同。导出的模型保存为文件“super_resolution.onnx”。
import onnx
import onnx_caffe2.backend
# onnx.load returns a protobuf message. The same serialized ONNX model can be
# consumed by different backends (caffe2, cntk, mxnet, tf, ...).
graph = onnx.load("super_resolution.onnx")
prepared_backend = onnx_caffe2.backend.prepare(graph)
# Resolve the name of the first graph input. With a ModelProto the inputs
# live under `.graph.input` and carry a `.name`. The original code indexed
# `graph.input[0]` directly, which presumably matched an early onnx release
# where inputs were plain strings, so both layouts are accepted.
graph_proto = getattr(graph, 'graph', graph)
first_input = graph_proto.input[0]
input_name = getattr(first_input, 'name', first_input)
img_input = {input_name: x.data.numpy()}
c2_out = prepared_backend.run(img_input)[0]
# The Caffe2 result must match PyTorch to 3 decimal places.
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)
到这里,我们就成功地把PyTorch定义的模型以及训练好的参数在Caffe2框架下跑起来了。有没有觉得很6呢?
C++端使用官方提供的speed_benchmark.cc这个例程。我们先生成这个C++程序所需要的模型文件。
# Export the prepared Caffe2 backend as two protobuf nets for the C++
# runtime: init_net holds the parameters, predict_net holds the model.
from caffe2.python.predictor import mobile_exporter

c2_workspace = prepared_backend.workspace
c2_graph = prepared_backend.predict_net
init_net, predict_net = mobile_exporter.Export(
    c2_workspace, c2_graph, c2_graph.external_input)

# NOTE: the init net is written as 'init_net.pd' (not '.pb'); any command
# line that loads it must use this exact file name.
for out_path, out_net in (('init_net.pd', init_net),
                          ('predict_net.pb', predict_net)):
    with open(out_path, 'wb') as f:
        f.write(out_net.SerializeToString())
可以看到,文件夹下面生成了init_net.pd和predict_net.pb(代码中参数文件的扩展名写成了.pd,按惯例一般用.pb;后面运行命令时用的也是init_net.pd)。第一个文件是模型的参数文件,第二个文件是模型的定义文件。为什么要这样做呢?把模型的定义和参数存成protobuf文件后,模型文件就与平台无关了:Python和C++代码都能使用这两个文件,Ubuntu和Android也都能使用这两个文件(同理,ONNX格式的模型文件则可以在PyTorch、Caffe2等框架之间通用)。
# Run on caffe2_python
from caffe2.proto import caffe2_pb2
from caffe2.python import core, net_drawer, net_printer, visualize, workspace,utils
import numpy as np
import os
import subprocess
from PIL import Image
from skimage import io, transform
# Load the low-resolution input and split it into Y / Cb / Cr planes; only
# the luminance plane (img_y) is fed to the network later on.
img = Image.open('./cat_244x244.jpg')
img_ycbcr = img.convert('YCbCr')
(img_y, img_cb, img_cr) = img_ycbcr.split()

# Populate the global workspace from the init net, then run the predict
# net once.
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)

# Pretty-print the operator list of the predict net.
net_listing = net_printer.to_string(predict_net)
print(net_listing)
# net: torch-jit-export
11 = Conv(1, 2, kernels=[5L, 5L], strides=[1L, 1L], pads=[2L, 2L, 2L, 2L], dilations=[1L, 1L], group=1)
12 = Add(11, 3, broadcast=1, axis=1)
13 = Relu(12)
15 = Conv(13, 4, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)
16 = Add(15, 5, broadcast=1, axis=1)
17 = Relu(16)
19 = Conv(17, 6, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)
20 = Add(19, 7, broadcast=1, axis=1)
21 = Relu(20)
23 = Conv(21, 8, kernels=[3L, 3L], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], dilations=[1L, 1L], group=1)
24 = Add(23, 9, broadcast=1, axis=1)
25, _onnx_dummy1 = Reshape(24, shape=[1L, 1L, 3L, 3L, 244L, 244L])
26 = Transpose(25, axes=[0L, 1L, 4L, 2L, 5L, 3L])
27, _onnx_dummy2 = Reshape(26, shape=[1L, 1L, 732L, 732L])
# Feed the luminance plane as blob '1' (the net's input), shaped
# (1, 1, H, W) and converted to float32.
workspace.FeedBlob('1', np.array(img_y)[np.newaxis, np.newaxis, :, :].astype(np.float32))
# Forward pass.
workspace.RunNetOnce(predict_net)
# Blob '27' is the output of the final Reshape in predict_net.
img_out = workspace.FetchBlob('27')
# Clip to [0, 255] BEFORE casting to uint8. Casting first does not saturate
# out-of-range floats, and clipping a uint8 array afterwards is a no-op.
img_out_y = Image.fromarray(np.uint8(img_out[0, 0].clip(0, 255)), mode='L')
# Recombine the super-resolved Y with bicubically upscaled Cb/Cr planes.
final_img = Image.merge(
    'YCbCr', [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert('RGB')
final_img.save('./cat_superres.jpg')
# Serialize the fed input blob '1' so the C++ benchmark can load it via
# --input 1 --input_file input.blobproto.
with open('input.blobproto', 'wb') as blob_out:
    serialized_input = workspace.SerializeBlob('1')
    blob_out.write(serialized_input)
我们需要编译C++代码,编译命令如下:
# Path to the caffe2 source checkout; in this command it is only used to
# locate the bundled Eigen headers.
CAFFE2_ROOT=$HOME/src/caffe2
# Compile speed_benchmark.cc into ./demo (C++11) and link the Caffe2 CPU
# library plus the third-party libraries listed below. Assumes the caffe2
# headers/libraries are on the compiler's default search paths (e.g.
# installed under /usr/local); otherwise add the matching -I / -L flags.
g++ speed_benchmark.cc -o demo -std=c++11 \
-I $CAFFE2_ROOT/third_party/eigen \
-lCaffe2_CPU \
-lglog \
-lgflags \
-lprotobuf \
-lpthread \
-llmdb \
-lleveldb \
-lopencv_core \
-lopencv_highgui \
-lopencv_imgproc
能够直接使用这条命令的前提是:caffe2已经通过sudo make install安装到了/usr/local下。
运行cpp程序:
# Benchmark once: load parameters from init_net.pd, feed blob '1' from
# input.blobproto, run predict_net.pb, and dump blob '27' into the current
# directory.
./demo \
  --init_net init_net.pd \
  --net predict_net.pb \
  --input 1 \
  --input_file input.blobproto \
  --output_folder . \
  --output 27 \
  --iter 1
等一下,运行之前我们先看一眼speed_benchmark.cc的内容:
#include <memory>
#include <string>
#include <vector>

#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
// Command-line flags. Each CAFFE2_DEFINE_* below is read in main() as
// caffe2::FLAGS_<name>.
// Model files: --init_net is run once to set up parameters, --net is the
// net that gets benchmarked.
CAFFE2_DEFINE_string(net, "", "The given net to benchmark.");
CAFFE2_DEFINE_string(init_net, "",
"The given net to initialize any parameters.");
// Inputs: blob names (--input) plus either serialized BlobProto files
// (--input_file) or tensor shapes (--input_dims), one entry per name.
CAFFE2_DEFINE_string(input, "",
"Input that is needed for running the network. If "
"multiple input needed, use comma separated string.");
CAFFE2_DEFINE_string(input_file, "",
"Input file that contain the serialized protobuf for "
"the input blobs. If multiple input needed, use comma "
"separated string. Must have the same number of items "
"as input does.");
CAFFE2_DEFINE_string(input_dims, "",
"Alternate to input_files, if all inputs are simple "
"float TensorCPUs, specify the dimension using comma "
"separated numbers. If multiple input needed, use "
"semicolon to separate the dimension of different "
"tensors.");
// Outputs: which blobs to serialize after the run, and where to write them.
CAFFE2_DEFINE_string(output, "",
"Output that should be dumped after the execution "
"finishes. If multiple outputs are needed, use comma "
"separated string. If you want to dump everything, pass "
"'*' as the output value.");
CAFFE2_DEFINE_string(output_folder, "",
"The folder that the output should be written to. This "
"folder must already exist in the file system.");
// Benchmark controls, passed to NetBase::TEST_Benchmark in main().
CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up.");
CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run.");
CAFFE2_DEFINE_bool(run_individual, false, "Whether to benchmark individual operators.");
using std::string;
using std::unique_ptr;
using std::vector;
int main(int argc, char** argv) {
caffe2::GlobalInit(&argc, &argv);
unique_ptr workspace(new caffe2::Workspace());
// 读取模型参数到工作空间
caffe2::NetDef net_def;
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &net_def));
CAFFE_ENFORCE(workspace->RunNetOnce(net_def));
// 加载输入数据,提供两种方式,--input和--input_dims
if (caffe2::FLAGS_input.size()) {
vector<string> input_names = caffe2::split(',', caffe2::FLAGS_input);
if (caffe2::FLAGS_input_file.size()) {
vector<string> input_files = caffe2::split(',', caffe2::FLAGS_input_file);
CAFFE_ENFORCE_EQ(
input_names.size(), input_files.size(),
"Input name and file should have the same number.");
for (int i = 0; i < input_names.size(); ++i) {
caffe2::BlobProto blob_proto;
CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto));
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
}
} else if (caffe2::FLAGS_input_dims.size()) {
vector<string> input_dims_list = caffe2::split(';', caffe2::FLAGS_input_dims);
CAFFE_ENFORCE_EQ(
input_names.size(), input_dims_list.size(),
"Input name and dims should have the same number of items.");
for (int i = 0; i < input_names.size(); ++i) {
vector<string> input_dims_str = caffe2::split(',', input_dims_list[i]);
vector<int> input_dims;
for (const string& s : input_dims_str) {
input_dims.push_back(caffe2::stoi(s));
}
caffe2::TensorCPU* tensor =
workspace->GetBlob(input_names[i])->GetMutable();
tensor->Resize(input_dims);
tensor->mutable_data<float>();
}
} else {
CAFFE_THROW("You requested input tensors, but neither input_file nor "
"input_dims is set.");
}
}
// 加载模型定义文件,创建模型,
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));
caffe2::NetBase* net = workspace->CreateNet(net_def);
CHECK_NOTNULL(net);
net->TEST_Benchmark(
caffe2::FLAGS_warmup,
caffe2::FLAGS_iter,
caffe2::FLAGS_run_individual);
// 获得输出
string output_prefix = caffe2::FLAGS_output_folder.size()
? caffe2::FLAGS_output_folder + "/"
: "";
if (caffe2::FLAGS_output.size()) {
vector<string> output_names = caffe2::split(',', caffe2::FLAGS_output);
if (caffe2::FLAGS_output == "*") {
output_names = workspace->Blobs();
}
for (const string& name : output_names) {
CAFFE_ENFORCE(
workspace->HasBlob(name),
"You requested a non-existing blob: ",
name);
string serialized = workspace->GetBlob(name)->Serialize(name);
string output_filename = output_prefix + name;
caffe2::WriteStringToFile(serialized, output_filename.c_str());
}
}
return 0;
}
程序运行结束后,会在输出目录下生成一个名为27的文件(即输出blob 27序列化后的结果),我们用Python把这个文件转换为jpg图片。
# Load the blob serialized by the C++ benchmark and turn it back into an RGB
# image, reusing the Cb/Cr planes of the original input.
# NOTE(review): the benchmark writes a file named '27' into --output_folder,
# but this reads './27_mobile'. Presumably it is the same file, renamed after
# copying it off the device; confirm the path.
blob_proto = caffe2_pb2.BlobProto()
# Serialized protobuf is binary: open in 'rb' and close the file afterwards.
with open('./27_mobile', 'rb') as blob_file:
    blob_proto.ParseFromString(blob_file.read())
img_out = utils.Caffe2TensorToNumpyArray(blob_proto.tensor)
# Clip to [0, 255] before casting to uint8.
img_out_y = Image.fromarray(np.uint8(img_out[0, 0].clip(0, 255)), mode='L')
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert('RGB')
final_img.save('./cat_superres_mobile.jpg')
按照原教程的做法,可以顺利运行。这里记录几点值得注意的地方:
- 为Android编译的可执行程序采用的是静态链接,这样可以方便地直接在手机设备上运行。
- Android程序并不一定要用Java写,这个例子就是用C++写的,这也说明android-cmake这个项目真的很方便。
- 不过,如果仅仅在Android上运行控制台程序,那也太不爽了,那我还要这嵌入式设备做啥子呢?所以最终一定要放到一个有界面的App中运行,这部分就要用原生的Java来写了。
- 关于Android环境配置,最快的方式莫过于装一个Android Studio。