TensorRT(八): TensorRT C++实战推理

#include 
#include 

#include 
#include <../common/logger.h>
#include 
#include "../common/util.h"
#include 
#include 
#include "../volume.h"
#include "../xt_pointer.h"
#include "../volume_resample.h"
#include 
#include 
#include 
// Evaluates `status` exactly once. If the result compares unequal to 0, the
// value is written to std::cerr and the process is terminated with abort().
// The do/while(0) wrapper makes the macro usable as a single statement.
#define CHECK(status)                                                      \
    do {                                                                   \
        const auto check_result_ = (status);                               \
        if (check_result_ != 0) {                                          \
            std::cerr << "Cuda failure: " << check_result_ << std::endl;   \
            abort();                                                       \
        }                                                                  \
    } while (0)

// Bring the TensorRT API namespace and the `sample` namespace into file scope.
// NOTE(review): `sample` presumably comes from ../common/logger.h — confirm.
using namespace nvinfer1;
using namespace sample;

// Small positive epsilon; in this file it is added to Gaussian weights and to a divisor.
static const double M_MIN = 1e-8;
// Tensor names passed to getBindingIndex() in doInference.
const char* IN_NAME = "input";
const char* OUT_NAME = "output";
// Channel count of the network output: sizes the output buffers and is the class count given to argmax.
static const int IN_CLS = 35;
// Batch size used to size host buffers and passed to doInference from main.
static const int BATCH_SIZE = 1;
// Bit flag derived from NetworkDefinitionCreationFlag::kEXPLICIT_BATCH; not referenced in the visible code.
static const int EXPLICIT_BATCH = 1 << (int)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
/*****************************************************************************************/

/// Splits `s` on every occurrence of `delimiter`.
///
/// @param s          Text to split.
/// @param delimiter  Separator string. An empty delimiter cannot match
///                   anything meaningful, so `s` is returned as a single
///                   token (the previous implementation looped forever).
/// @return The non-empty tokens in their original order; consecutive,
///         leading and trailing delimiters never produce empty tokens.
std::vector<std::string> Split(std::string s, std::string delimiter)
{
	std::vector<std::string> ret;
	if (delimiter.empty()) {
		if (!s.empty()) ret.push_back(s);
		return ret;
	}

	// Scan with a moving start offset instead of erasing the consumed prefix,
	// which keeps the split linear in the length of `s`.
	size_t begin = 0;
	while (true) {
		const size_t pos = s.find(delimiter, begin);
		const size_t end = (pos == std::string::npos) ? s.length() : pos;
		if (end > begin) ret.push_back(s.substr(begin, end - begin));
		if (pos == std::string::npos) break;
		begin = pos + delimiter.length();
	}
	return ret;
}
/// Looks up `keyName` in a "key=value" text file and copies its value into `keyValue`.
///
/// Each line is split at the first '='. The text before it must equal
/// `keyName` exactly (including any whitespace before the '='); the text
/// after it, minus the trailing line terminator, is copied to `keyValue`.
/// When several lines match, the last one wins. If the file cannot be opened
/// or no line matches, `keyValue` is left untouched.
///
/// @param config_file  Path of the file to scan.
/// @param keyName      Key to match, e.g. "Offset " for a line "Offset = 0 0 0".
/// @param keyValue     Caller-provided buffer of at least 100 bytes; a single
///                     read never yields more than 99 characters.
static void GetConfigStringValue(std::string config_file, const char* keyName, char* keyValue)
{
	FILE* file = fopen(config_file.c_str(), "r");
	if (!file)
	{
		std::cerr << "GetConfigStringValue: cannot open " << config_file << std::endl;
		return;
	}

	// The read size stays at 100 on purpose: it bounds the copied value so it
	// fits the 100-byte buffers the callers in this file pass in.
	char buff[100] = { 0 };
	while (fgets(buff, sizeof(buff), file))
	{
		char* tempKeyName = strtok(buff, "=");
		if (!tempKeyName) continue;

		// A line without anything after '=' has no value; skip it instead of
		// handing a null pointer to strcpy.
		char* tempKeyValue = strtok(nullptr, "=");
		if (!tempKeyValue) continue;

		if (!strcmp(tempKeyName, keyName))
		{
			size_t len = strlen(tempKeyValue);
			while (len > 0 && (tempKeyValue[len - 1] == '\n' || tempKeyValue[len - 1] == '\r'))
				tempKeyValue[--len] = '\0';
			strcpy(keyValue, tempKeyValue);
		}
	}
	std::cout << keyName << ": " << keyValue << std::endl;
	fclose(file);
}
template 
void load_volume_mhd(std::string dir, std::string name, DLL::PTR> vol)
{
	DLL::Origin origin;
	DLL::Spacing spacing;
	DLL::Size size;
	std::string lfn = dir + "/" + name + ".mhd";
	std::cout << lfn << std::endl;

	char origin_str[100] = { 0 };
	GetConfigStringValue(lfn, "Offset ", origin_str);
	auto origin_vec = Split(origin_str, " ");
	origin[0] = std::stoi(origin_vec[0]);
	origin[1] = std::stoi(origin_vec[1]);
	origin[2] = std::stoi(origin_vec[2]);


	char spacing_str[100] = { 0 };
	GetConfigStringValue(lfn, "ElementSpacing ", spacing_str);
	auto spacing_vec = Split(spacing_str, " ");
	spacing[0] = std::stof(spacing_vec[0]);
	spacing[1] = std::stof(spacing_vec[1]);
	spacing[2] = std::stof(spacing_vec[2]);

	char size_str[100] = { 0 };
	GetConfigStringValue(lfn, "DimSize ", size_str);
	auto size_vec = Split(size_str, " ");
	size[0] = std::stoi(size_vec[0]);
	size[1] = std::stoi(size_vec[1]);
	size[2] = std::stoi(size_vec[2]);
	size_t length = size[2] * size[1] * size[0];

	vol->Create(size, spacing, origin);
	std::string fn = dir + "/" + name + ".raw";
	FILE* pf = fopen(fn.c_str(), "rb");
	fread(vol->get_voxels(), sizeof(T), length, pf);
	fclose(pf);

}
template 
void save_volume_mhd(DLL::PTR> vol, std::string dir, std::string name)
{
	auto origin = vol->get_origin();
	auto spacing = vol->get_spacing();
	auto size = vol->get_size();

	std::string lfn = dir + "/" + name + ".mhd";
	std::cout << lfn << std::endl;

	std::fstream l_File;
	l_File.open(lfn, std::ios::binary | std::ios::out);
	l_File << "ObjectType = Image" << std::endl;
	l_File << "NDims = 3" << std::endl;
	l_File << "BinaryData = True" << std::endl;
	l_File << "BinaryDataByteOrderMSB = False" << std::endl;
	l_File << "CompressedData = False" << std::endl;
	l_File << "TransformMatrix = 1 0 0 0 1 0 0 0 1" << std::endl;
	l_File << "Offset = " << origin[0] << " " << origin[1] << " " << origin[2] << std::endl;
	l_File << "CenterOfRotation = 0 0 0" << std::endl;
	l_File << "ElementSpacing = " << spacing[0] << " " << spacing[1] << " " << spacing[2] << std::endl;
	l_File << "DimSize = " << size[0] << " " << size[1] << " " << size[2] << std::endl;

	if (typeid(T) == typeid(short))
	{
		l_File << "ElementType = MET_SHORT" << std::endl;
	}
	if (typeid(T) == typeid(float))
	{
		l_File << "ElementType = MET_FLOAT" << std::endl;
	}

	l_File << "ElementDataFile =" << name + ".raw" << std::endl;
	if (l_File.is_open())
		l_File.close();
	std::string fn = dir + "/" + name + ".raw";
	FILE* pf = fopen(fn.c_str(), "wb");

	fwrite(vol->get_voxels(), sizeof(T), size[2] * size[1] * size[0], pf);
	fclose(pf);
}

template 
DLL::PTR< DLL::Volume> convert2type(DLL::PTR> vol)
{
	DLL::PTR< DLL::Volume> r_vol = PTR_NEW(DLL::Volume);
	r_vol->Create(vol->get_size(), vol->get_spacing(), vol->get_origin());
	T1* voxels = vol->get_voxels();
	T2* rvoxels = r_vol->get_voxels();
	concurrency::parallel_for(0, r_vol->get_nvox(), [&](DLL::uint32 i)
		{
			rvoxels[i] = (T2)voxels[i];
		});
	return r_vol;
}
/*****************************************************************************************/
/// Limits `value` to the closed interval [minValue, maxValue].
///
/// @return minValue if value < minValue, maxValue if value > maxValue,
///         otherwise value. When minValue > maxValue the result is minValue.
template <typename T>
T clip(T value, T minValue, T maxValue) {
	const T upper_bounded = std::min(value, maxValue);
	return std::max(minValue, upper_bounded);
}
template 
void  normalize(DLL::PTR> vol_in, DLL::PTR< DLL::Volume> vol_out, float mean, float sd, float lb, float ub)
{
	DLL::uint32 out_nvox = vol_out->get_nvox();
	float* out_pointer = vol_out->get_voxels();
	T* in_pointer = vol_in->get_voxels();
	concurrency::parallel_for(DLL::uint32(0), out_nvox, [&](DLL::uint32 out_idx)
		//for (uint32 out_idx = 0; out_idx < out_nvox; ++out_idx)
		{
			out_pointer[out_idx] = (clip((float)in_pointer[out_idx], lb, ub) - mean) / sd;
		}
	);
}

template 
void  argmax(const float* v, T* out, DLL::int32 dim, int N)
{
	for (int i = 0; i < dim; i++)
	{
		int max_cls = 0;
		int max_pro = v[i];
		for (int cls = 1; cls < N; cls++)
		{
			if (max_pro< v[i + cls * dim])
			{
				max_pro = v[i + cls * dim];
				max_cls = cls;
			}
		}
		out[i] = max_cls;
	}
}


template 
void  get_bbox(DLL::PTR> vol, std::vector& i_vec, std::vector& j_vec, std::vector& k_vec)
{
	DLL::uint32 nvox = vol->get_nvox();
	DLL::Size   size = vol->get_size();
	T*        pointer = vol->get_voxels();
	i_vec = {size[0], 0};
	j_vec = {size[1], 0};
	k_vec = {size[2], 0};
	concurrency::parallel_for(DLL::uint32(0), nvox, [&](DLL::uint32 idx)
		{
			if (pointer[idx]>0)
			{
				glm::ivec3 ijk;
				COORDS_FROM_INDEX(ijk, idx, size);
				i_vec[0] = i_vec[0] < ijk[0] ? i_vec[0] : ijk[0];
				i_vec[1] = i_vec[1] > ijk[0] ? i_vec[1] : ijk[0];

				j_vec[0] = j_vec[0] < ijk[1] ? j_vec[0] : ijk[1];
				j_vec[1] = j_vec[1] > ijk[1] ? j_vec[1] : ijk[1];

				k_vec[0] = k_vec[0] < ijk[2] ? k_vec[0] : ijk[2];
				k_vec[1] = k_vec[1] > ijk[2] ? k_vec[1] : ijk[2];
			}
		}
	);
}
template 
void  padding(DLL::PTR> vol, DLL::Size patch_size, DLL::PTR> pad_vol)
{
	DLL::Size    old_size    = vol->get_size();
	DLL::Origin  old_origin  = vol->get_origin();
	DLL::Spacing old_spacing = vol->get_spacing();
	DLL::Size new_shape = old_size;
	new_shape[0] = new_shape[0] > patch_size[0] ? new_shape[0] : patch_size[0];
	new_shape[1] = new_shape[1] > patch_size[1] ? new_shape[1] : patch_size[1];
	new_shape[2] = new_shape[2] > patch_size[2] ? new_shape[2] : patch_size[2];
	DLL::Size difference = new_shape - old_size;
	DLL::Size pad_below = difference / 2;
	DLL::Size pad_above = difference / 2 + difference % 2;

	DLL::Origin new_origin = old_origin + glm::vec3(pad_below) * old_spacing;
	pad_vol->Create(new_shape, old_spacing, new_origin);
	DLL::volume_resample_linear(vol.get(), pad_vol.get());

}

/// Value at `x` of the normal probability density with the given mean and
/// standard deviation.
double gaussianFunction(double x, double mean, double stddev) {
	const double normalizer = 1.0 / (stddev * std::sqrt(2 * M_PI));
	const double exponent = -0.5 * std::pow((x - mean) / stddev, 2);
	return normalizer * std::exp(exponent);
}

void generateGaussianKernel(DLL::PTR> kernel) {

	DLL::Size size = kernel->get_size();

	glm::ivec3 center = size / 2;
	glm::ivec3 sigmas = size / 8;

	DLL::uint32 nvox = kernel->get_nvox();
	float* pointer = kernel->get_voxels();
	concurrency::parallel_for(DLL::uint32(0), nvox, [&](DLL::uint32 idx)
		{
			glm::ivec3 ijk;
			COORDS_FROM_INDEX(ijk, idx, size);
			auto dis = ijk - center;
			pointer[idx]  = gaussianFunction(dis[0], 0, sigmas[0]) *
				gaussianFunction(dis[1], 0, sigmas[1]) *
				gaussianFunction(dis[2], 0, sigmas[2]);
		}
	);
	double max = kernel->get_max_value();
	concurrency::parallel_for(DLL::uint32(0), nvox, [&](DLL::uint32 idx)
		{
			pointer[idx] /= max;
			pointer[idx] += M_MIN;
		}
	);
}

/// Builds the list of integer positions from `start` up to `end`.
///
/// Starting at `start`, each next position is the current one plus
/// `step + 1`, truncated to int, while the position does not exceed `end`.
/// `end` is appended if the last generated position is not already `end`,
/// so the result always finishes exactly at `end`.
///
/// @param start  First position.
/// @param end    Last position (always the final element of the result).
/// @param step   Real-valued stride; the effective stride is trunc(step + 1).
/// @return Positions in visiting order. If start > end the result is {end}.
std::vector<int> range(int start, int end, double step)
{
	std::vector<int> res;
	// `i <= end` is the integer form of the former `i < end + 1e-8` test.
	for (int i = start; i <= end; )
	{
		res.push_back(i);
		const int next = static_cast<int>(i + (step + 1));
		// A stride that does not advance would never terminate; stop here.
		if (next <= i) break;
		i = next;
	}
	// Guarding empty(): back() on an empty vector is undefined behaviour.
	if (res.empty() || res.back() != end)
	{
		res.push_back(end);
	}
	return res;
}
template 
void  fill_patch(DLL::PTR> vol, DLL::PTR> patch)
{
	DLL::uint32  in_nvox = patch->get_nvox();
	DLL::Size    in_dim = patch->get_size();
	T*         in_pointer = patch->get_voxels();

	DLL::uint32 out_nvox = vol->get_nvox();
	DLL::Size   out_dim = vol->get_size();
	T*		  out_pointer = vol->get_voxels();
	T value;
	concurrency::parallel_for(DLL::uint32(0), out_nvox, [&](DLL::uint32 out_idx)
		{
			DLL::uint32 in_idx;
			glm::ivec3 in_ijk;
			glm::ivec3 out_ijk;
			glm::vec3 xyz;
			COORDS_FROM_INDEX(out_ijk, out_idx, out_dim);
			xyz = vol->volume2patient(out_ijk);
			in_ijk = patch->patient2continusvolume(xyz);
			if (patch->is_inside(in_ijk)) {
				value = patch->get_value(in_ijk);
				out_pointer[out_idx] = value;				
			}
		}
	);
}
template 
void  fill_patch(T* out_pointer, T* guass_pointer, T* in_pointer, T* mu_pointer, std::vector out_size, std::vector in_size,  std::vector& i_vec, std::vector& j_vec, std::vector& k_vec)
{

	int inB, outB;
	int inC, outC;
	int inD, outD;
	int inH, outH;
	int inW, outW;
	inB = in_size[0], outB = out_size[0];
	inC = in_size[1], outC = out_size[1];
	inD = in_size[2], outD = out_size[2];
	inH = in_size[3], outH = out_size[3];
	inW = in_size[4], outW = out_size[4];
	DLL::uint32 inDim  = inD * inH * inW;
	DLL::uint32 outDim = outD * outH * outW ;

	//std::cout << "in_size : [" << in_size[0] << ", " << in_size[1] << ", " << in_size[2] << ", " << in_size[3] << ", " << in_size[4] << "]" << std::endl;
	//std::cout << "out_size : [" << out_size[0] << ", " << out_size[1] << ", " << out_size[2] << ", " << out_size[3] << ", " << out_size[4] << "]" << std::endl;

	//for (DLL::uint32 in_idx = 0; in_idx < inB * inC * inDim; in_idx++)
	concurrency::parallel_for(DLL::uint32(0), inB * inC * inDim, [&](DLL::uint32 in_idx)
		{

			int B = in_idx / (inC * inDim);

			int C = (in_idx - B * inC * inDim) / inDim;
			int D = (in_idx - B * inC * inDim - C * inDim) / (inH * inW);
			int H = (in_idx - B * inC * inDim - C * inDim - D * inH * inW) / inW;
			int W = (in_idx - B * inC * inDim - C * inDim - D * inH * inW) % inW;
			//std::cout << "dim: [" << B <<", " << C << ", " << D << ", " << H << ", " << W << "]" << std::endl;
			DLL::uint32 out_idx = B * inC * outDim + C * outDim + (D + k_vec[0]) * (outH * outW) + (H + j_vec[0]) * outW + (W + i_vec[0]);
			DLL::uint32 mu_idx = D * (inH * inW) + H * inW + W;

			out_pointer[out_idx] += in_pointer[in_idx] * mu_pointer[mu_idx];
			guass_pointer[out_idx] += mu_pointer[mu_idx];
			
		});
}

template 
void  patch_normalize(T* out_pointer, T* in_pointer, DLL::uint32 nvox)
{
	concurrency::parallel_for(DLL::uint32(0), nvox, [&](DLL::uint32 idx)
		//for (uint32 out_idx = 0; out_idx < out_nvox; ++out_idx)
		{
			out_pointer[idx] = out_pointer[idx] / in_pointer[idx];
		}
	);
}

/*****************************************************************************************/
void doInference(IExecutionContext& context, DLL::PTR> pad_crop_vol, DLL::PTR> pred_crop_res,  DLL::PTR> kernel, DLL::Size patch_size,  int batchSize)
{
	//step
	int step = 2;
	int DIM = patch_size[0] * patch_size[1] * patch_size[2];

	DLL::Size center_coord_start = patch_size / 2;
	DLL::Size center_coord_end = pad_crop_vol->get_size() - center_coord_start;
	glm::ivec3 num_steps = glm::ceil(glm::dvec3(center_coord_end - center_coord_start) / glm::dvec3(patch_size / step));

	glm::dvec3 step_size = glm::dvec3(center_coord_end - center_coord_start) / (glm::dvec3(num_steps) + M_MIN);

	std::cout << "center_coord_start :[" << center_coord_start[0] << ", " << center_coord_start[1] << ", " << center_coord_start[2] << "]" << std::endl;
	std::cout << "center_coord_end :[" << center_coord_end[0] << ", " << center_coord_end[1] << ", " << center_coord_end[2] << "]" << std::endl;
	std::cout << "num_steps :[" << num_steps[0] << ", " << num_steps[1] << ", " << num_steps[2] << "]" << std::endl;
	std::cout << "step_size :[" << step_size[0] << ", " << step_size[1] << ", " << step_size[2] << "]" << std::endl;

	std::vector ksteps = range(center_coord_start[2], center_coord_end[2], step_size[2]);
	std::vector jsteps = range(center_coord_start[1], center_coord_end[1], step_size[1]);
	std::vector isteps = range(center_coord_start[0], center_coord_end[0], step_size[0]);

	DLL::Size rsp_pad_origin = pad_crop_vol->get_origin();
	DLL::Spacing current_spacing = pad_crop_vol->get_spacing();

	DLL::PTR> patch_vol(new DLL::Volume);
	patch_vol->Create(patch_size, current_spacing, rsp_pad_origin);

	/***********************************************************************/
	int Dim = patch_size[0] * patch_size[1] * patch_size[2];
	float* prob;
	if ((prob = (float*)malloc(BATCH_SIZE * IN_CLS * Dim * sizeof(float))) == NULL)
	{
		printf("malloc error\n");
		return ;
	}
	std::vector p_size = { BATCH_SIZE, IN_CLS , patch_size[2], patch_size[1], patch_size[0] };

	DLL::Size rsp_pad_size = pad_crop_vol->get_size();
	int resDim = rsp_pad_size[0] * rsp_pad_size[1] * rsp_pad_size[2];
	float* result;
	if ((result = (float*)malloc(BATCH_SIZE * IN_CLS * resDim * sizeof(float))) == NULL)
	{
		printf("malloc error\n");
		return ;
	}
	memset(result, 0, BATCH_SIZE * IN_CLS * resDim * sizeof(float)); // 使用 memset 将内存初始化为 0
	std::vector v_size = { BATCH_SIZE, IN_CLS , rsp_pad_size[2], rsp_pad_size[1], rsp_pad_size[0] };

	float* result_numsamples;
	if ((result_numsamples = (float*)malloc(BATCH_SIZE * IN_CLS * resDim * sizeof(float))) == NULL)
	{
		printf("malloc error\n");
		return ;
	}
	memset(result_numsamples, 0, BATCH_SIZE * IN_CLS * resDim * sizeof(float));
	/***********************************************************************/
	const ICudaEngine& engine = context.getEngine();

	// Pointers to input and output device buffers to pass to engine.
	// Engine requires exactly IEngine::getNbBindings() number of buffers.
	assert(engine.getNbBindings() == 2);
	void* buffers[2];

	// In order to bind the buffers, we need to know the names of the input and output tensors.
	// Note that indices are guaranteed to be less than IEngine::getNbBindings()
	const int inputIndex = engine.getBindingIndex(IN_NAME);
	const int outputIndex = engine.getBindingIndex(OUT_NAME);
	printf("inputindex:%d\n", inputIndex);
	printf("outputIndex:%d\n", outputIndex);

	// Create GPU buffers on device
	CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 1 * DIM * sizeof(float)));
	CHECK(cudaMalloc(&buffers[outputIndex], batchSize * IN_CLS * DIM * sizeof(float)));

	// Create stream
	cudaStream_t stream;
	CHECK(cudaStreamCreate(&stream));	

	for (auto k : ksteps)
	{
		int lb_k = k - patch_size[2] / 2;
		int	ub_k = k + patch_size[2] / 2;
		for (auto j : jsteps)
		{
			int lb_j = j - patch_size[1] / 2;
			int	ub_j = j + patch_size[1] / 2;
			for (auto i : isteps)
			{
				int lb_i = i - patch_size[0] / 2;
				int	ub_i = i + patch_size[0] / 2;
				DLL::Origin patch_origin = pad_crop_vol->volume2patient({ lb_i, lb_j, lb_k });
				patch_vol->set_origin(patch_origin);
				patch_vol->RefreshMatrix();
				DLL::volume_resample_linear(pad_crop_vol.get(), patch_vol.get());

				// DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
				CHECK(cudaMemcpyAsync(buffers[inputIndex], patch_vol->get_voxels(), batchSize * 1 * DIM * sizeof(float), cudaMemcpyHostToDevice, stream));
				//context.enqueue(batchSize, buffers, stream, nullptr);
				context.enqueueV2(buffers, stream, nullptr);
				CHECK(cudaMemcpyAsync(prob, buffers[outputIndex], batchSize * IN_CLS * DIM * sizeof(float), cudaMemcpyDeviceToHost, stream));

				fill_patch(result, result_numsamples, prob, kernel->get_voxels(), v_size, p_size, std::vector({ lb_i ,ub_i }), std::vector({ lb_j, ub_j }), std::vector({ lb_k, ub_k }));

				std::cout << "predict: [" << lb_k << " " << lb_j << " " << lb_i << "] - [" << ub_k << " " << ub_j << " " << ub_i << "]" << std::endl;

			}
		}
	}

	cudaStreamSynchronize(stream);

	// Release stream and buffers
	cudaStreamDestroy(stream);
	CHECK(cudaFree(buffers[inputIndex]));
	CHECK(cudaFree(buffers[outputIndex]));

	patch_vol->Clear();
	patch_normalize(result, result_numsamples, BATCH_SIZE * IN_CLS * resDim);
	delete result_numsamples;
	argmax(result, pred_crop_res->get_voxels(), rsp_pad_size[0] * rsp_pad_size[1] * rsp_pad_size[2], IN_CLS);
	delete result;


}
/*****************************************************************************************/
int main(int argc, char** argv)
{

	float mean = 71.05476;
	float sd = 272.33304;
	float lb = -953.0;
	float ub = 1510.0;
	DLL::Spacing current_spacing = {  1., 1., 2.5 };
	DLL::Size patch_size = {  224, 128, 64};
	/***********************************************************************/
	std::string dir = "C:/Users/datu/Desktop/validation/trt/cpp";
	// load ct
	DLL::PTR> ori_vol(new DLL::Volume);
	load_volume_mhd(dir, "ct_volume", ori_vol);
	
	 //load body
	DLL::PTR> body_vol(new DLL::Volume);
	load_volume_mhd(dir, "body_itk", body_vol);	
	//save_volume_mhd(ori_vol, dir, "ct_volume_test");	
	auto ori_origin = ori_vol->get_origin();
	auto ori_spacing = ori_vol->get_spacing();
	auto ori_size = ori_vol->get_size();
	/***********************************************************************/
	auto start = std::chrono::high_resolution_clock::now();

	//get bbox
	std::vector i_vec, j_vec, k_vec;
	get_bbox(body_vol, i_vec, j_vec, k_vec);
	std::cout << "i [" << i_vec[0] << ", " << i_vec[1] + 1 << "]" << std::endl;
	std::cout << "j [" << j_vec[0] << ", " << j_vec[1] + 1 << "]" << std::endl;
	std::cout << "k [" << k_vec[0] << ", " << k_vec[1] + 1 << "]" << std::endl;
	/***********************************************************************/
	//crop volume
	DLL::PTR> crop_vol(new DLL::Volume);

	DLL::Origin crop_origin = ori_vol->volume2patient({ i_vec[0], j_vec[0], k_vec[0] });
	DLL::Size crop_size = { i_vec[1]- i_vec[0]+1, j_vec[1] - j_vec[0] + 1, k_vec[1] - k_vec[0] + 1 };
	DLL::Size rsp_crop_size = glm::round(glm::dvec3(ori_spacing) / glm::dvec3(current_spacing) * glm::dvec3(crop_size));
	crop_vol->Create(rsp_crop_size, current_spacing, crop_origin);

	DLL::volume_resample_linear(ori_vol.get(), crop_vol.get());
	std::cout << "rsp_crop_size :[" << rsp_crop_size[0] << ", " << rsp_crop_size[1] << ", " << rsp_crop_size[2] << "]" << std::endl;
	//save_volume_mhd(rsp_crop_vol, dir, "rsp_crop_vol_test");
	/***********************************************************************/
	//normalize crop volume
	DLL::PTR> nor_crop_vol(new DLL::Volume);
	nor_crop_vol->Create(rsp_crop_size, current_spacing, crop_origin);
	normalize(crop_vol, nor_crop_vol, mean, sd, lb, ub);
	//save_volume_mhd(nor_crop_vol, dir, "nor_crop_vol");
	/***********************************************************************/
	//padding crop volume
	DLL::PTR> pad_crop_vol(new DLL::Volume);
	padding(nor_crop_vol, patch_size, pad_crop_vol);
	//save_volume_mhd(pad_crop_vol, dir, "pad_crop_vol");
	/***********************************************************************/
	// gaussian
	DLL::PTR> kernel(new DLL::Volume);
	kernel->Create(patch_size);
	generateGaussianKernel(kernel);
	//save_volume_mhd(kernel, dir, "kernel_vol");
	/***********************************************************************/
	char* trtModelStream{ nullptr };
	size_t size{ 0 };
	std::ifstream file("F:/localwork/cpp/trtDemo/output/Debug/hn_model.engine", std::ios::binary);
	if (file.good()) {
		file.seekg(0, file.end);
		size = file.tellg();
		file.seekg(0, file.beg);
		trtModelStream = new char[size];
		assert(trtModelStream);
		file.read(trtModelStream, size);
		file.close();
	}
	Logger m_logger;
	IRuntime* runtime = createInferRuntime(m_logger);
	assert(runtime != nullptr);
	initLibNvInferPlugins(&m_logger, "");
	ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr);
	assert(engine != nullptr);
	IExecutionContext* context = engine->createExecutionContext();
	assert(context != nullptr);

	DLL::Size rsp_pad_size = pad_crop_vol->get_size();

	DLL::PTR> pred_crop_res(new DLL::Volume);
	pred_crop_res->Create(rsp_pad_size, current_spacing, pad_crop_vol->get_origin());
	
	doInference(*context, pad_crop_vol, pred_crop_res, kernel, patch_size, BATCH_SIZE);

	DLL::PTR> pred_res(new DLL::Volume);
	pred_res->Create(ori_size, ori_spacing, ori_origin);
	fill_patch(pred_res, pred_crop_res);

	auto end = std::chrono::high_resolution_clock::now();
	std::chrono::duration duration = end - start;
	// 输出时间差(以秒为单位)
	std::cout << "运行时间: " << duration.count() << " 秒" << std::endl;


	save_volume_mhd(pred_res, dir, "pred_res");
	pred_res->Clear();
	

	//std::memcpy(pred_patch->get_voxels(), output, Dim * sizeof(float));
	// Destroy the engine
	context->destroy();
	engine->destroy();
	runtime->destroy();
	return 0;
}

你可能感兴趣的:(c++,算法,开发语言)