Caffe2 C++ 预测Demo

Caffe2 预测代码（从 init_net.pb / predict_net.pb 加载模型并执行前向推理）：

bool FingerprintDNNLocating::initCaffe2(void)
{
	string mdPath = map_mgr_.getCaff2ModelPath();
	if ("" == mdPath)
		return false;
	string initNetPath = mdPath + "/init_net.pb";
	string predictNetPath = mdPath + "/predict_net.pb";
	if (!std::ifstream(initNetPath).good() ||
		!std::ifstream(predictNetPath).good()) 
	{
		return false;
	}
	CAFFE_ENFORCE(ReadProtoFromFile(initNetPath, &initNet_));
	// >>> with open(path_to_PREDICT_NET) as f:
	CAFFE_ENFORCE(ReadProtoFromFile(predictNetPath, &predictNet_));
    predictor_ = new caffe2::Predictor(initNet_, predictNet_);
	return true;
}

bool FingerprintDNNLocating::runCaffe2(vector &weights)
{
    if(!predictor_)
        return false;
    caffe2::TensorCPU input;
    input.Resize(std::vector({1, AP_COUNT}));
	input.ShareExternalPointer(model_input_data_.data());
    caffe2::Predictor::TensorVector input_vec{&input};
    caffe2::Predictor::TensorVector output_vec;
    predictor_->run(input_vec, &output_vec);
    for (auto output : output_vec) {
		for(auto i = 0; i < output->size(); ++i) {
			float val = output->template data()[i];
			rss_debug_log_<<"probs: "<

你可能感兴趣的：(C++, 算法, deep learning)