Pytorch转TensorRT范例代码

Pytorch转TensorRT范例代码

2018年10月25日 15:56:56 柳鲲鹏 阅读数:1009

  TensorRT官方文档称,/usr/src/tensorrt/samples/python/network_api_pytorch_mnist 目录下有示例代码,但实际上该路径下根本没有。这里提供一份示例代码,供参考。

  这个范例的具体位置是:/usr/local/lib/python3.5/site-packages/tensorrt/examples/pytorch_to_trt

 
  1. #!/usr/bin/python

  2. import os

  3. from random import randint

  4. import numpy as np

  5.  
  6. try:

  7. import pycuda.driver as cuda

  8. import pycuda.gpuarray as gpuarray

  9. import pycuda.autoinit

  10. except ImportError as err:

  11. raise ImportError("""ERROR: Failed to import module({})

  12. Please make sure you have pycuda and the example dependencies installed.

  13. sudo apt-get install python(3)-pycuda

  14. pip install tensorrt[examples]""".format(err))

  15.  
  16. try:

  17. from PIL import Image

  18. except ImportError as err:

  19. raise ImportError("""ERROR: Failed to import module ({})

  20. Please make sure you have Pillow installed.

  21. For installation instructions, see:

  22. http://pillow.readthedocs.io/en/stable/installation.html""".format(err))

  23.  
  24. import mnist

  25.  
  26. try:

  27. import torch

  28. except ImportError as err:

  29. raise ImportError("""ERROR: Failed to import module ({})

  30. Please make sure you have PyTorch installed.

  31. For installation instructions, see:

  32. http://pytorch.org/""".format(err))

  33.  
  34. # TensorRT must be imported after any frameworks in the case where

  35. # the framework has incorrect dependencies setup and is not updated

  36. # to use the versions of libraries that TensorRT imports.

  37. try:

  38. import tensorrt as trt

  39. except ImportError as err:

  40. raise ImportError("""ERROR: Failed to import module ({})

  41. Please make sure you have the TensorRT Library installed

  42. and accessible in your LD_LIBRARY_PATH""".format(err))

  43.  
  # Console logger handed to the TensorRT builder and runtime; constructed
  # with INFO severity.
  G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)

  # Number of test cases main() pushes through the engine.
  ITERATIONS = 10

  # Tensor names used for the network input and for the marked output.
  INPUT_LAYERS = ["data"]
  OUTPUT_LAYERS = ["prob"]

  # Input height/width used for the network's input tensor, and the number
  # of values produced by the final fully connected layer.
  INPUT_H, INPUT_W = 28, 28
  OUTPUT_SIZE = 10

  52.  
  def create_pytorch_engine(max_batch_size, builder, dt, model):
      """Build a TensorRT engine from a trained PyTorch MNIST state dict.

      The network is assembled layer by layer through the TensorRT network
      API: conv1 -> max-pool -> conv2 -> max-pool -> fc1 -> ReLU -> fc2.
      The weights for each layer are read from ``model``.

      Args:
          max_batch_size: Value passed to ``builder.set_max_batch_size``.
          builder: TensorRT builder used to create the network and engine.
          dt: TensorRT data type of the input tensor.
          model: Mapping from parameter name (e.g. ``'conv1.weight'``) to a
              tensor supporting ``.cpu().numpy()``.

      Returns:
          Whatever ``builder.build_cuda_engine`` returns for the network.
      """

      def flat_param(key):
          # TensorRT layers take their weights as flat 1-D arrays.
          return model[key].cpu().numpy().reshape(-1)

      net = builder.create_network()

      inp = net.add_input(INPUT_LAYERS[0], dt, (1, INPUT_H, INPUT_W))
      assert inp

      # ---- conv1: 20 output maps, 5x5 kernel, stride 1
      w_conv1 = flat_param('conv1.weight')
      b_conv1 = flat_param('conv1.bias')
      conv_a = net.add_convolution(inp, 20, (5, 5), w_conv1, b_conv1)
      assert conv_a
      conv_a.set_stride((1, 1))

      # ---- pool1: 2x2 max pooling, stride 2
      pool_a = net.add_pooling(conv_a.get_output(0), trt.infer.PoolingType.MAX, (2, 2))
      assert pool_a
      pool_a.set_stride((2, 2))

      # ---- conv2: 50 output maps, 5x5 kernel, stride 1
      w_conv2 = flat_param('conv2.weight')
      b_conv2 = flat_param('conv2.bias')
      conv_b = net.add_convolution(pool_a.get_output(0), 50, (5, 5), w_conv2, b_conv2)
      assert conv_b
      conv_b.set_stride((1, 1))

      # ---- pool2: 2x2 max pooling, stride 2
      pool_b = net.add_pooling(conv_b.get_output(0), trt.infer.PoolingType.MAX, (2, 2))
      assert pool_b
      pool_b.set_stride((2, 2))

      # ---- fc1: 500 outputs
      w_fc1 = flat_param('fc1.weight')
      b_fc1 = flat_param('fc1.bias')
      dense_a = net.add_fully_connected(pool_b.get_output(0), 500, w_fc1, b_fc1)
      assert dense_a

      # ---- ReLU on fc1
      act = net.add_activation(dense_a.get_output(0), trt.infer.ActivationType.RELU)
      assert act

      # ---- fc2: OUTPUT_SIZE outputs
      w_fc2 = flat_param('fc2.weight')
      b_fc2 = flat_param('fc2.bias')
      dense_b = net.add_fully_connected(act.get_output(0), OUTPUT_SIZE, w_fc2, b_fc2)
      assert dense_b

      # Training used log_softmax; it is left out here because TRT has no
      # log-softmax layer, so the raw fc2 output is the network output.
      dense_b.get_output(0).set_name(OUTPUT_LAYERS[0])
      net.mark_output(dense_b.get_output(0))

      builder.set_max_batch_size(max_batch_size)
      builder.set_max_workspace_size(1 << 20)

      #builder.set_fp16_mode(True)

      built = builder.build_cuda_engine(net)
      net.destroy()

      return built

  114.  
  def model_to_engine(model, max_batch_size):
      """Build an engine for ``model`` and return it in serialized form.

      The builder and the intermediate engine are always destroyed, even
      when building or serializing fails, so no TensorRT objects are leaked
      on the error path.

      Args:
          model: PyTorch state dictionary forwarded to
              ``create_pytorch_engine``.
          max_batch_size: Maximum batch size forwarded to
              ``create_pytorch_engine``.

      Returns:
          The object returned by ``engine.serialize()``; the caller is
          responsible for destroying it.

      Raises:
          AssertionError: If no engine could be built.
      """
      builder = trt.infer.create_infer_builder(G_LOGGER)
      try:
          engine = create_pytorch_engine(max_batch_size, builder, trt.infer.DataType.FLOAT, model)
          # Explicit raise instead of a bare ``assert`` so the check still
          # runs under ``python -O``; the exception type is unchanged.
          if not engine:
              raise AssertionError("failed to build the TensorRT engine")
          try:
              return engine.serialize()
          finally:
              engine.destroy()
      finally:
          builder.destroy()

  124.  
  # Run inference on device
  def infer(context, input_img, output_size, batch_size):
      """Run one inference pass on the GPU and return the network output.

      Args:
          context: TensorRT execution context. Its engine must expose exactly
              two bindings; they are bound here as [input, output].
          input_img: NumPy array holding the input data; it is converted to
              float32 before being uploaded.
          output_size: Number of float32 elements in the returned array.
          batch_size: Batch size passed to ``context.enqueue``; it also
              scales the size of the device allocations.

      Returns:
          A float32 NumPy array with ``output_size`` elements, fully written
          by the time this function returns.

      Raises:
          AssertionError: If the engine does not have exactly two bindings.
      """
      # Load engine
      engine = context.get_engine()
      assert engine.get_nb_bindings() == 2
      # Convert input data to Float32
      input_img = input_img.astype(np.float32)
      # Create output array to receive data
      output = np.empty(output_size, dtype=np.float32)

      # Allocate device memory
      # NOTE(review): the buffers are sized for batch_size items, but only one
      # input_img is uploaded and only output_size values are read back, so
      # this looks correct for batch_size == 1 only -- confirm before using
      # larger batches.
      d_input = cuda.mem_alloc(batch_size * input_img.nbytes)
      d_output = cuda.mem_alloc(batch_size * output.nbytes)

      bindings = [int(d_input), int(d_output)]

      stream = cuda.Stream()

      # Transfer input data to device
      cuda.memcpy_htod_async(d_input, input_img, stream)
      # Execute model
      context.enqueue(batch_size, bindings, stream.handle, None)
      # Transfer predictions back
      cuda.memcpy_dtoh_async(output, d_output, stream)
      # Everything above was only queued on the stream; wait for it to finish
      # so the caller never reads `output` before the copy has completed.
      stream.synchronize()

      # Return predictions
      return output

  152.  
  def main():
      """Train the MNIST model, convert it to TensorRT and print predictions.

      Steps: train with the ``mnist`` helper module, build and serialize a
      TensorRT engine from the trained state dictionary, deserialize it with
      a runtime, then run ``ITERATIONS`` test images through the engine and
      print each target label next to the predicted class.
      """
      # The mnist package is a simple PyTorch mnist example. mnist.learn() trains a network for
      # PyTorch's provided mnist dataset. mnist.get_trained_model() returns the state dictionary
      # of the trained model. We use this to demonstrate the full training to inference pipeline
      mnist.learn()
      model = mnist.get_trained_model()

      # Typically training and inference are separated, so you would save the model's
      # state dictionary with torch.save() and load it back with torch.load()
      #
      # e.g:
      # model = torch.load(os.path.dirname(os.path.realpath(__file__)) + "/trained_mnist.pyt")
      modelstream = model_to_engine(model, 1)

      runtime = trt.infer.create_infer_runtime(G_LOGGER)
      engine = runtime.deserialize_cuda_engine(modelstream.data(), modelstream.size(), None)

      # The serialized stream has already been consumed above, so it can be
      # released unconditionally.
      modelstream.destroy()

      # One execution context is reused for every test case. Creating a new
      # one per iteration and destroying only the last would leak the rest.
      context = engine.create_execution_context()
      try:
          img, target = mnist.get_testcase()
          img = img.numpy()
          target = target.numpy()
          print("\n| TEST CASE | PREDICTION |")
          for i in range(ITERATIONS):
              img_in = img[i].ravel()
              target_in = target[i]
              out = infer(context, img_in, OUTPUT_SIZE, 1)
              print("|-----------|------------|")
              print("| " + str(target_in) + " | " + str(np.argmax(out)) + " |")

          print('')
      finally:
          # Release TensorRT objects in reverse order of creation, even if
          # inference failed part-way through.
          context.destroy()
          engine.destroy()
          runtime.destroy()


  if __name__ == "__main__":
      main()

你可能感兴趣的:(CNN,卷积神经网络)