Exporting a multi-batch ONNX model from PyTorch

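import torch

# model: the torch.nn.Module to export; x: an example input tensor in NHWC layout.
model.eval()  # put dropout and batch-norm layers in inference mode before export
# Only axis 0 (batch) is dynamic; H, W and C stay fixed at the sizes of x.
# Listing axes 1-3 here as well would make H, W and C dynamic too.
# The output's batch axis is marked so its exported shape follows the input's batch size.
dynamic_axes = {'input': {0: 'batch'}, 'output': {0: 'batch'}}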
torch.onnx.export(model,                      # model being run
                  x,                          # model input (or a tuple for multiple inputs)
                  "model.onnx",               # where to save the model (a path or file-like object)
                  export_params=True,         # store the trained parameter weights inside the model file
                  opset_version=11,           # ONNX opset version to target
                  do_constant_folding=True,   # run constant folding as an optimization
                  input_names=['input'],      # the model's input names
                  output_names=['output'],    # the model's output names
                  dynamic_axes=dynamic_axes)  # axes whose size may change at run time
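
If only the batch size should vary, dynamic_axes needs exactly one entry per named input and output: axis 0 mapped to a shared symbolic name. The sketch below builds that mapping for any number of tensors. batch_dynamic_axes is a hypothetical helper written for this example, not part of torch.onnx, and it assumes the batch sits on axis 0 as in the NHWC input used here.

```python
from typing import Dict, Iterable


def batch_dynamic_axes(input_names: Iterable[str],
                       output_names: Iterable[str],
                       axis: int = 0,
                       symbol: str = "batch") -> Dict[str, Dict[int, str]]:
    """Hypothetical helper: mark only the batch axis of each tensor as dynamic.

    Every input and output gets the same symbolic name on `axis`, so they
    share one batch size; all other axes keep the fixed sizes that the
    exporter takes from the example input.
    """
    axes: Dict[str, Dict[int, str]] = {}
    for name in list(input_names) + list(output_names):
        if name in axes:
            # torch.onnx.export expects unique tensor names, so reject repeats early.
            raise ValueError(f"duplicate tensor name: {name!r}")
        axes[name] = {axis: symbol}
    return axes


if __name__ == "__main__":
    # Same mapping as writing {'input': {0: 'batch'}, 'output': {0: 'batch'}} by hand.
    print(batch_dynamic_axes(["input"], ["output"]))
```

Passing dynamic_axes=batch_dynamic_axes(['input'], ['output']) to torch.onnx.export is equivalent to the hand-written dictionary; the helper mainly saves typing when a model has several inputs or outputs.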
